
Commit 393ae33

SAM 2 Update 12/11/2024 -- full model compilation for a major VOS speedup and a new SAM2VideoPredictor to better handle multi-object tracking (#486)
This PR provides new features and updates for SAM 2:

- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
  * In the VOS prediction script `tools/vos_inference.py`, you can enable this option via the `--use_vos_optimized_video_predictor` flag.
  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
  * **PyTorch 2.5.1 is the minimum version for full support of this feature.** (Earlier PyTorch versions might run into compilation errors in some cases.) We have therefore updated the minimum PyTorch version to 2.5.1 in the installation scripts.
- We also update the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
  * **We now handle the inference of each object independently** (as if we were opening a separate session for each object) while sharing their backbone features.
  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in that frame (i.e., the user was telling SAM 2 that those objects don't appear there). Now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining (non-prompted) objects: each object is handled independently and is not affected by how other objects are prompted. As a result, **we now allow adding new objects after tracking starts**, which was previously a usage restriction.
  * We believe the new version is a more natural inference behavior and have therefore made it the default. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
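
As a quick illustration of the first change, the sketch below builds the VOS-optimized video predictor. It follows the usage in the new `sam2/benchmark.py` added in this commit; the config and checkpoint paths are the example ones used there, so substitute the model size you actually run.

```python
import torch

from sam2.build_sam import build_sam2_video_predictor

# Example config/checkpoint paths taken from sam2/benchmark.py.
model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"

# vos_optimized=True routes to the new SAM2VideoPredictorVOS class and enables
# full-model torch.compile (PyTorch >= 2.5.1 is needed for full support).
predictor = build_sam2_video_predictor(
    model_cfg, sam2_checkpoint, device=torch.device("cuda"), vos_optimized=True
)
```
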
1 parent c2ec8e1 commit 393ae33

27 files changed: +1,794 −443 lines

.gitignore (+1)

@@ -8,3 +8,4 @@ build/*
 _C.*
 outputs/*
 checkpoints/*.pt
+demo/backend/checkpoints/*.pt

INSTALL.md (+3 −3)

@@ -2,7 +2,7 @@

 ### Requirements

-- Linux with Python ≥ 3.10, PyTorch ≥ 2.3.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
+- Linux with Python ≥ 3.10, PyTorch ≥ 2.5.1 and [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation. Install them together at https://pytorch.org to ensure this.
  * Note that older versions of Python or PyTorch may also work. However, the versions above are strongly recommended to provide all features such as `torch.compile`.
 - [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) that match the CUDA version for your PyTorch installation. This should typically be CUDA 12.1 if you follow the default installation command.
 - If you are installing on Windows, it's strongly recommended to use [Windows Subsystem for Linux (WSL)](https://learn.microsoft.com/en-us/windows/wsl/install) with Ubuntu.
@@ -121,9 +121,9 @@ I got `undefined symbol: _ZN3c1015SmallVectorBaseIjE8grow_podEPKvmm` (or similar

 This usually happens because you have multiple versions of dependencies (PyTorch or CUDA) in your environment. During installation, the SAM 2 library is compiled against one version of the library, while at run time it links against another version. This might be because you have different versions of PyTorch or CUDA installed separately via `pip` or `conda`. You may delete one of the duplicates to keep only a single PyTorch and CUDA version.

-In particular, if you have a lower PyTorch version than 2.3.1, it's recommended to upgrade to PyTorch 2.3.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.
+In particular, if you have a lower PyTorch version than 2.5.1, it's recommended to upgrade to PyTorch 2.5.1 or higher first. Otherwise, the installation script will try to upgrade to the latest PyTorch using `pip`, which could sometimes lead to duplicated PyTorch installation if you have previously installed another PyTorch version using `conda`.

-We have been building SAM 2 against PyTorch 2.3.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.3.1` to `torch>=2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
+We have been building SAM 2 against PyTorch 2.5.1 internally. However, a few user comments (e.g. https://github.com/facebookresearch/sam2/issues/22, https://github.com/facebookresearch/sam2/issues/14) suggested that downgrading to PyTorch 2.1.0 might resolve this problem. In case the error persists, you may try changing the restriction from `torch>=2.5.1` to `torch==2.1.0` in both [`pyproject.toml`](pyproject.toml) and [`setup.py`](setup.py) to allow PyTorch 2.1.0.
 </details>

 <details>
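
Since the minimum requirement moves from PyTorch 2.3.1 to 2.5.1, upgrading an existing environment before reinstalling SAM 2 would look roughly like the sketch below. The version bounds come from this commit's README/INSTALL changes; the exact command (and wheel index) depends on your CUDA setup, so check https://pytorch.org for the matching installation command.

```bash
# Upgrade PyTorch and torchvision to versions compatible with this release.
# Pick the index/wheel that matches your CUDA version on pytorch.org.
pip install --upgrade "torch>=2.5.1" "torchvision>=0.20.1"
```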

README.md (+17 −12)

@@ -14,6 +14,12 @@

 ## Latest updates

+**12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking**
+
+- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor`, leading to a major speedup for VOS inference.
+- We update the implementation of `SAM2VideoPredictor` to support independent per-object inference, allowing us to relax the assumption of prompting for multi-object tracking and to add new objects after tracking starts.
+- See [`RELEASE_NOTES.md`](RELEASE_NOTES.md) for full details.
+
 **09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released**

 - A new suite of improved model checkpoints (denoted as **SAM 2.1**) is released. See [Model Description](#model-description) for details.
@@ -23,7 +29,7 @@

 ## Installation

-SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.3.1` and `torchvision>=0.18.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:
+SAM 2 needs to be installed first before use. The code requires `python>=3.10`, as well as `torch>=2.5.1` and `torchvision>=0.20.1`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. You can install SAM 2 on a GPU machine using:

 ```bash
 git clone https://github.com/facebookresearch/sam2.git && cd sam2
@@ -39,7 +45,7 @@ pip install -e ".[notebooks]"
 ```

 Note:
-1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.3.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.3.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
+1. It's recommended to create a new Python environment via [Anaconda](https://www.anaconda.com/) for this installation and install PyTorch 2.5.1 (or higher) via `pip` following https://pytorch.org/. If you have a PyTorch version lower than 2.5.1 in your current environment, the installation command above will try to upgrade it to the latest PyTorch version using `pip`.
 2. The step above requires compiling a custom CUDA kernel with the `nvcc` compiler. If it isn't already available on your machine, please install the [CUDA toolkits](https://developer.nvidia.com/cuda-toolkit-archive) with a version that matches your PyTorch CUDA version.
 3. If you see a message like `Failed to build the SAM 2 CUDA extension` during installation, you can ignore it and still use SAM 2 (some post-processing functionality may be limited, but it doesn't affect the results in most cases).

@@ -158,24 +164,23 @@ with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
 The table below shows the improved SAM 2.1 checkpoints released on September 29, 2024.
 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 47.2 | 76.5 | 71.8 | 77.3 |
-| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 76.6 | 73.5 | 78.3 |
-| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 78.2 | 73.7 | 78.2 |
-| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 79.5 | 74.6 | 80.6 |
+| sam2.1_hiera_tiny <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt)) | 38.9 | 91.2 | 76.5 | 71.8 | 77.3 |
+| sam2.1_hiera_small <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt)) | 46 | 84.8 | 76.6 | 73.5 | 78.3 |
+| sam2.1_hiera_base_plus <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_base_plus.pt)) | 80.8 | 64.1 | 78.2 | 73.7 | 78.2 |
+| sam2.1_hiera_large <br /> ([config](sam2/configs/sam2.1/sam2.1_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt)) | 224.4 | 39.5 | 79.5 | 74.6 | 80.6 |

 ### SAM 2 checkpoints

 The previous SAM 2 checkpoints released on July 29, 2024 can be found as follows:

 | **Model** | **Size (M)** | **Speed (FPS)** | **SA-V test (J&F)** | **MOSE val (J&F)** | **LVOS v2 (J&F)** |
 | :------------------: | :----------: | :--------------------: | :-----------------: | :----------------: | :---------------: |
-| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 47.2 | 75.0 | 70.9 | 75.3 |
-| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 43.3 (53.0 compiled\*) | 74.9 | 71.5 | 76.4 |
-| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 34.8 (43.8 compiled\*) | 74.7 | 72.8 | 75.8 |
-| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 24.2 (30.2 compiled\*) | 76.0 | 74.6 | 79.8 |
-
-\* Compile the model by setting `compile_image_encoder: True` in the config.
+| sam2_hiera_tiny <br /> ([config](sam2/configs/sam2/sam2_hiera_t.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt)) | 38.9 | 91.5 | 75.0 | 70.9 | 75.3 |
+| sam2_hiera_small <br /> ([config](sam2/configs/sam2/sam2_hiera_s.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt)) | 46 | 85.6 | 74.9 | 71.5 | 76.4 |
+| sam2_hiera_base_plus <br /> ([config](sam2/configs/sam2/sam2_hiera_b+.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt)) | 80.8 | 64.8 | 74.7 | 72.8 | 75.8 |
+| sam2_hiera_large <br /> ([config](sam2/configs/sam2/sam2_hiera_l.yaml), [checkpoint](https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt)) | 224.4 | 39.7 | 76.0 | 74.6 | 79.8 |

+Speed measured on an A100 with `torch 2.5.1, cuda 12.4`. See `benchmark.py` for an example of benchmarking (compiling all the model components). Compiling only the image encoder can be more flexible and also provides a (smaller) speed-up (set `compile_image_encoder: True` in the config).
 ## Segment Anything Video Dataset

 See [sav_dataset/README.md](sav_dataset/README.md) for details.

RELEASE_NOTES.md (+27)

@@ -0,0 +1,27 @@
+## SAM 2 release notes
+
+### 12/11/2024 -- full model compilation for a major VOS speedup and a new `SAM2VideoPredictor` to better handle multi-object tracking
+
+- We now support `torch.compile` of the entire SAM 2 model on videos, which can be turned on by setting `vos_optimized=True` in `build_sam2_video_predictor` (it uses the new `SAM2VideoPredictorVOS` predictor class in `sam2/sam2_video_predictor.py`).
+  * Compared to the previous setting (which only compiles the image encoder backbone), the new full model compilation gives a major speedup in inference FPS.
+  * In the VOS prediction script `tools/vos_inference.py`, you can enable this option via the `--use_vos_optimized_video_predictor` flag.
+  * Note that turning on this flag might introduce a small variance in the predictions due to numerical differences caused by `torch.compile` of the full model.
+  * **PyTorch 2.5.1 is the minimum version for full support of this feature.** (Earlier PyTorch versions might run into compilation errors in some cases.) We have therefore updated the minimum PyTorch version to 2.5.1 in the installation scripts.
+- We also update the implementation of the `SAM2VideoPredictor` class for SAM 2 video prediction in `sam2/sam2_video_predictor.py`, which now allows for independent per-object inference. Specifically, in the new `SAM2VideoPredictor`:
+  * **We now handle the inference of each object independently** (as if we were opening a separate session for each object) while sharing their backbone features.
+  * This change allows us to relax the assumption of prompting for multi-object tracking. Previously (due to the batching behavior in inference), if a video frame received clicks for only a subset of objects, the remaining (non-prompted) objects were assumed to be non-existent in that frame (i.e., the user was telling SAM 2 that those objects don't appear there). Now, if a frame receives clicks for only a subset of objects, we make no assumptions about the remaining (non-prompted) objects: each object is handled independently and is not affected by how other objects are prompted. As a result, **we now allow adding new objects after tracking starts**, which was previously a usage restriction.
+  * We believe the new version is a more natural inference behavior and have therefore made it the default. The previous implementation of `SAM2VideoPredictor` is backed up in `sam2/sam2_video_predictor_legacy.py`. All VOS inference results using `tools/vos_inference.py` should remain the same after this change to the `SAM2VideoPredictor` class.
+
+### 09/30/2024 -- SAM 2.1 Developer Suite (new checkpoints, training code, web demo) is released
+
+- A new suite of improved model checkpoints (denoted as **SAM 2.1**) is released. See [Model Description](#model-description) for details.
+  * To use the new SAM 2.1 checkpoints, you need the latest model code from this repo. If you have installed an earlier version of this repo, please first uninstall the previous version via `pip uninstall SAM-2`, pull the latest code from this repo (with `git pull`), and then reinstall the repo following [Installation](#installation) below.
+- The training (and fine-tuning) code has been released. See [`training/README.md`](training/README.md) on how to get started.
+- The frontend + backend code for the SAM 2 web demo has been released. See [`demo/README.md`](demo/README.md) for details.
+
+### 07/29/2024 -- SAM 2 is released
+
+- We release Segment Anything Model 2 (SAM 2), a foundation model towards solving promptable visual segmentation in images and videos.
+  * SAM 2 code: https://github.com/facebookresearch/sam2
+  * SAM 2 demo: https://sam2.metademolab.com/
+  * SAM 2 paper: https://arxiv.org/abs/2408.00714
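
To make the per-object change in the 12/11/2024 notes concrete, here is a minimal sketch of prompting only one object, propagating, and then adding a second object after tracking has started (previously disallowed). The predictor calls mirror those used in `sam2/benchmark.py` from this commit; the video path, frame indices, object ids, and click coordinates are illustrative placeholders, and the exact call sequence is an assumption based on the behavior described above rather than a verbatim recipe.

```python
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_b+.yaml",
    "checkpoints/sam2.1_hiera_base_plus.pt",
    device=torch.device("cuda"),
)
inference_state = predictor.init_state(video_path="notebooks/videos/bedroom")

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    # Prompt only object 1 on frame 0; no assumption is made about any other object.
    predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=0,
        obj_id=1,
        points=np.array([[210, 350]], dtype=np.float32),
        labels=np.array([1], np.int32),  # 1 = positive click
    )
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
        pass  # consume masks for object 1

    # With the new per-object SAM2VideoPredictor, another object can be added
    # after tracking has already started.
    predictor.add_new_points_or_box(
        inference_state=inference_state,
        frame_idx=50,  # a later (hypothetical) frame
        obj_id=2,
        points=np.array([[100, 200]], dtype=np.float32),
        labels=np.array([1], np.int32),
    )
    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(inference_state):
        pass  # masks now cover both objects
```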

backend.Dockerfile (+1 −1)

@@ -1,4 +1,4 @@
-ARG BASE_IMAGE=pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime
+ARG BASE_IMAGE=pytorch/pytorch:2.5.1-cuda12.1-cudnn9-runtime
 ARG MODEL_SIZE=base_plus

 FROM ${BASE_IMAGE}

demo/README.md (+1 −1)

@@ -105,7 +105,7 @@ cd demo/backend/server/
 ```bash
 PYTORCH_ENABLE_MPS_FALLBACK=1 \
 APP_ROOT="$(pwd)/../../../" \
-APP_URL=http://localhost:7263 \
+API_URL=http://localhost:7263 \
 MODEL_SIZE=base_plus \
 DATA_PATH="$(pwd)/../../data" \
 DEFAULT_VIDEO_PATH=gallery/05_default_juggle.mp4 \

pyproject.toml (+1 −1)

@@ -1,6 +1,6 @@
 [build-system]
 requires = [
     "setuptools>=61.0",
-    "torch>=2.3.1",
+    "torch>=2.5.1",
 ]
 build-backend = "setuptools.build_meta"

sam2/benchmark.py (+92)

@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import time
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from sam2.build_sam import build_sam2_video_predictor
+
+# Only cuda supported
+assert torch.cuda.is_available()
+device = torch.device("cuda")
+
+torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
+if torch.cuda.get_device_properties(0).major >= 8:
+    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+
+# Config and checkpoint
+sam2_checkpoint = "checkpoints/sam2.1_hiera_base_plus.pt"
+model_cfg = "configs/sam2.1/sam2.1_hiera_b+.yaml"
+
+# Build video predictor with vos_optimized=True setting
+predictor = build_sam2_video_predictor(
+    model_cfg, sam2_checkpoint, device=device, vos_optimized=True
+)
+
+
+# Initialize with video
+video_dir = "notebooks/videos/bedroom"
+# scan all the JPEG frame names in this directory
+frame_names = [
+    p
+    for p in os.listdir(video_dir)
+    if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
+]
+frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
+inference_state = predictor.init_state(video_path=video_dir)
+
+
+# Number of runs, warmup etc
+warm_up, runs = 5, 25
+verbose = True
+num_frames = len(frame_names)
+total, count = 0, 0
+torch.cuda.empty_cache()
+
+# We will select an object with a click.
+# See video_predictor_example.ipynb for more detailed explanation
+ann_frame_idx, ann_obj_id = 0, 1
+# Add a positive click at (x, y) = (210, 350)
+# For labels, `1` means positive click
+points = np.array([[210, 350]], dtype=np.float32)
+labels = np.array([1], np.int32)
+
+_, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
+    inference_state=inference_state,
+    frame_idx=ann_frame_idx,
+    obj_id=ann_obj_id,
+    points=points,
+    labels=labels,
+)
+
+# Warmup and then average FPS over several runs
+with torch.autocast("cuda", torch.bfloat16):
+    with torch.inference_mode():
+        for i in tqdm(range(runs), disable=not verbose, desc="Benchmarking"):
+            start = time.time()
+            # Start tracking
+            for (
+                out_frame_idx,
+                out_obj_ids,
+                out_mask_logits,
+            ) in predictor.propagate_in_video(inference_state):
+                pass
+
+            end = time.time()
+            total += end - start
+            count += 1
+            if i == warm_up - 1:
+                print("Warmup FPS: ", count * num_frames / total)
+                total = 0
+                count = 0
+
+print("FPS: ", count * num_frames / total)

sam2/build_sam.py (+7)

@@ -104,11 +104,18 @@ def build_sam2_video_predictor(
     mode="eval",
     hydra_overrides_extra=[],
     apply_postprocessing=True,
+    vos_optimized=False,
     **kwargs,
 ):
     hydra_overrides = [
         "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
     ]
+    if vos_optimized:
+        hydra_overrides = [
+            "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictorVOS",
+            "++model.compile_image_encoder=True",  # Let sam2_base handle this
+        ]
+
     if apply_postprocessing:
         hydra_overrides_extra = hydra_overrides_extra.copy()
         hydra_overrides_extra += [
