Skip to content

Commit 3bd2c74

Browse files
committed
update
1 parent 35040b2 commit 3bd2c74

File tree

2 files changed

+25
-185
lines changed

2 files changed

+25
-185
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
*Our method can realize **arbitrary face swapping** on images and videos with **one single trained model**.*
66

7-
Training and test code are now available! [Colab demo](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb)
7+
Training and test code are now available!
8+
[ <a href="https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/github/neuralchen/SimSwap/blob/main/train.ipynb)
89

910
We are working with our incoming paper SimSwap++, keeping expecting!
1011

@@ -26,6 +27,8 @@ If you find this project useful, please star it. It is the greatest appreciation
2627

2728
## Top News <img width=8% src="./docs/img/new.gif"/>
2829

30+
**`2022-04-21`**: For resource limited users, we provide the cropped VGGFace2-224 dataset [VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing).
31+
2932
**`2022-04-20`**: Training scripts are now available. We highly recommend that you guys train the simswap model with our released high quality dataset [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).
3033

3134
**`2021-11-24`**: We have trained a beta version of ***SimSwap-HQ*** on [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) and open sourced the checkpoint of this model (if you think the Simswap 512 is cool, please star our [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ) repo). Please don’t forget to go to [Preparation](./docs/guidance/preparation.md) and [Inference for image or video face swapping](./docs/guidance/usage.md) to check the latest set up.
@@ -65,7 +68,7 @@ Download the dataset from [VGGFace2-HQ](https://github.com/NNNNAI/VGGFace2-HQ).
6568
The training script is slightly different from the original version, e.g., we replace the patch discriminator with the projected discriminator, which saves a lot of hardware overhead and achieves slightly better results.
6669
In order to ensure normal training, the batch size must be greater than 1.
6770

68-
- Train 224 models with VGGFace2 224*224 [VGGFace2-224](https://github.com/NNNNAI/VGGFace2-HQ)
71+
- Train 224 models with VGGFace2 224*224 [VGGFace2-224 (10.8G)](https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing)
6972
```
7073
python train.py --name simswap224_test --batchSize 4 --gpu_ids 0 --dataset /path/to/VGGFace2HQ --Gdeep False
7174
```

train.ipynb

Lines changed: 20 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -37,41 +37,10 @@
3737
"!nvidia-smi"
3838
],
3939
"metadata": {
40-
"colab": {
41-
"base_uri": "https://localhost:8080/"
42-
},
43-
"id": "J8WrNaQHuUGC",
44-
"outputId": "afffa0be-92b5-4133-b6d9-6c3e08c6de64"
40+
"id": "J8WrNaQHuUGC"
4541
},
4642
"execution_count": null,
47-
"outputs": [
48-
{
49-
"output_type": "stream",
50-
"name": "stdout",
51-
"text": [
52-
"Thu Apr 21 16:07:35 2022 \n",
53-
"+-----------------------------------------------------------------------------+\n",
54-
"| NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2 |\n",
55-
"|-------------------------------+----------------------+----------------------+\n",
56-
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
57-
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
58-
"| | | MIG M. |\n",
59-
"|===============================+======================+======================|\n",
60-
"| 0 Tesla K80 Off | 00000000:00:04.0 Off | 0 |\n",
61-
"| N/A 67C P8 32W / 149W | 0MiB / 11441MiB | 0% Default |\n",
62-
"| | | N/A |\n",
63-
"+-------------------------------+----------------------+----------------------+\n",
64-
" \n",
65-
"+-----------------------------------------------------------------------------+\n",
66-
"| Processes: |\n",
67-
"| GPU GI CI PID Type Process name GPU Memory |\n",
68-
"| ID ID Usage |\n",
69-
"|=============================================================================|\n",
70-
"| No running processes found |\n",
71-
"+-----------------------------------------------------------------------------+\n"
72-
]
73-
}
74-
]
43+
"outputs": []
7544
},
7645
{
7746
"cell_type": "markdown",
@@ -99,29 +68,10 @@
9968
"!cd SimSwap && git pull"
10069
],
10170
"metadata": {
102-
"colab": {
103-
"base_uri": "https://localhost:8080/"
104-
},
105-
"id": "9jZWwt97uvIe",
106-
"outputId": "42a1bda8-3ca3-46af-fc82-d1af99ce15e1"
71+
"id": "9jZWwt97uvIe"
10772
},
10873
"execution_count": null,
109-
"outputs": [
110-
{
111-
"output_type": "stream",
112-
"name": "stdout",
113-
"text": [
114-
"Cloning into 'SimSwap'...\n",
115-
"remote: Enumerating objects: 1017, done.\u001b[K\n",
116-
"remote: Counting objects: 100% (16/16), done.\u001b[K\n",
117-
"remote: Compressing objects: 100% (13/13), done.\u001b[K\n",
118-
"remote: Total 1017 (delta 5), reused 10 (delta 3), pack-reused 1001\u001b[K\n",
119-
"Receiving objects: 100% (1017/1017), 210.79 MiB | 14.80 MiB/s, done.\n",
120-
"Resolving deltas: 100% (510/510), done.\n",
121-
"Already up to date.\n"
122-
]
123-
}
124-
]
74+
"outputs": []
12575
},
12676
{
12777
"cell_type": "markdown",
@@ -140,32 +90,22 @@
14090
"!wget -P SimSwap/arcface_model https://github.com/neuralchen/SimSwap/releases/download/1.0/arcface_checkpoint.tar"
14191
],
14292
"metadata": {
143-
"colab": {
144-
"base_uri": "https://localhost:8080/"
145-
},
146-
"id": "rwvbPhtOvZAL",
147-
"outputId": "ffa12208-d388-412d-e83b-c54864c4526e"
93+
"id": "rwvbPhtOvZAL"
14894
},
14995
"execution_count": null,
150-
"outputs": [
151-
{
152-
"output_type": "stream",
153-
"name": "stdout",
154-
"text": [
155-
"Requirement already satisfied: googledrivedownloader in /usr/local/lib/python3.7/dist-packages (0.4)\n",
156-
"Requirement already satisfied: imageio==2.4.1 in /usr/local/lib/python3.7/dist-packages (2.4.1)\n",
157-
"Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (7.1.2)\n",
158-
"Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from imageio==2.4.1) (1.21.6)\n"
159-
]
160-
}
161-
]
96+
"outputs": []
16297
},
16398
{
16499
"cell_type": "markdown",
165100
"source": [
166101
"#Download the Training Dataset\n",
167102
"We employ the cropped VGGFace2-224 dataset for this toy training demo.\n",
168-
"You can download the dataset from our google driver "
103+
"\n",
104+
"You can download the dataset from our google driver https://drive.google.com/file/d/19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc/view?usp=sharing\n",
105+
"\n",
106+
"***Please check the dataset in dir /content/TrainingData***\n",
107+
"\n",
108+
"***If dataset already exists in /content/TrainingData, please do not run blow scripts!***\n"
169109
],
170110
"metadata": {
171111
"id": "hleVtHIJ_QUK"
@@ -174,28 +114,16 @@
174114
{
175115
"cell_type": "code",
176116
"source": [
177-
"from google_drive_downloader import GoogleDriveDownloader as gdd\n",
178-
"gdd.download_file_from_google_drive(file_id='1iytA1n2z4go3uVCwE__vIKouTKyIDjEq',dest_path='/content/TrainingData/vggface2_crop_arcfacealign_224.tar',showsize=True)\n",
179-
"!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar"
117+
"!wget --load-cookies /tmp/cookies.txt \"https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=19pWvdEHS-CEG6tW3PdxdtZ5QEymVjImc\" -O /content/TrainingData/vggface2_crop_arcfacealign_224.tar && rm -rf /tmp/cookies.txt\n",
118+
"%%cd /content/\n",
119+
"!tar -xzvf /content/TrainingData/vggface2_crop_arcfacealign_224.tar\n",
120+
"!rm /content/TrainingData/vggface2_crop_arcfacealign_224.tar"
180121
],
181122
"metadata": {
182-
"colab": {
183-
"base_uri": "https://localhost:8080/"
184-
},
185-
"id": "gMVKEej59LX9",
186-
"outputId": "2e508c44-d006-4183-81d9-f9753d08dea7"
123+
"id": "h2tyjBl0Llxp"
187124
},
188125
"execution_count": null,
189-
"outputs": [
190-
{
191-
"output_type": "stream",
192-
"name": "stdout",
193-
"text": [
194-
"Downloading 1iytA1n2z4go3uVCwE__vIKouTKyIDjEq into /content/TrainingData/mnist.zip... \n",
195-
"0.0 B Done.\n"
196-
]
197-
}
198-
]
126+
"outputs": []
199127
},
200128
{
201129
"cell_type": "markdown",
@@ -215,101 +143,10 @@
215143
"!python train.py --name simswap224_test --gpu_ids 0 --dataset /content/TrainingData/vggface2_crop_arcfacealign_224 --Gdeep False"
216144
],
217145
"metadata": {
218-
"colab": {
219-
"base_uri": "https://localhost:8080/"
220-
},
221-
"id": "XCxHa4oW507s",
222-
"outputId": "c84c52d9-0b36-4932-925d-1ae38a3f7bb0"
146+
"id": "XCxHa4oW507s"
223147
},
224148
"execution_count": null,
225-
"outputs": [
226-
{
227-
"output_type": "stream",
228-
"name": "stdout",
229-
"text": [
230-
"/content/SimSwap\n",
231-
" arcface_model\t predict.py\n",
232-
" cog.yaml\t README.md\n",
233-
" crop_224\t 'SimSwap colab.ipynb'\n",
234-
" data\t\t simswaplogo\n",
235-
" demo_file\t test_one_image.py\n",
236-
" docs\t\t test_video_swapmulti.py\n",
237-
" download-weights.sh test_video_swap_multispecific.py\n",
238-
" insightface_func test_video_swapsingle.py\n",
239-
" LICENSE\t test_video_swapspecific.py\n",
240-
" models\t\t test_wholeimage_swapmulti.py\n",
241-
" MultiSpecific.ipynb test_wholeimage_swap_multispecific.py\n",
242-
" options\t test_wholeimage_swapsingle.py\n",
243-
" output\t\t test_wholeimage_swapspecific.py\n",
244-
" parsing_model\t train.py\n",
245-
" pg_modules\t util\n",
246-
"------------ Options -------------\n",
247-
"Arc_path: arcface_model/arcface_checkpoint.tar\n",
248-
"Gdeep: False\n",
249-
"batchSize: 2\n",
250-
"beta1: 0.0\n",
251-
"checkpoints_dir: ./checkpoints\n",
252-
"continue_train: False\n",
253-
"dataset: /path/to/VGGFace2\n",
254-
"gpu_ids: 0\n",
255-
"isTrain: True\n",
256-
"lambda_feat: 10.0\n",
257-
"lambda_id: 30.0\n",
258-
"lambda_rec: 10.0\n",
259-
"load_pretrain: checkpoints\n",
260-
"log_frep: 200\n",
261-
"lr: 0.0004\n",
262-
"model_freq: 10000\n",
263-
"name: simswap\n",
264-
"niter: 10000\n",
265-
"niter_decay: 10000\n",
266-
"phase: train\n",
267-
"sample_freq: 1000\n",
268-
"tag: simswap\n",
269-
"total_step: 1000000\n",
270-
"train_simswap: True\n",
271-
"use_tensorboard: False\n",
272-
"which_epoch: 800000\n",
273-
"-------------- End ----------------\n",
274-
"GPU used : 0\n",
275-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.parallel.data_parallel.DataParallel' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
276-
" warnings.warn(msg, SourceChangeWarning)\n",
277-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.conv.Conv2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
278-
" warnings.warn(msg, SourceChangeWarning)\n",
279-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
280-
" warnings.warn(msg, SourceChangeWarning)\n",
281-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.PReLU' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
282-
" warnings.warn(msg, SourceChangeWarning)\n",
283-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.MaxPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
284-
" warnings.warn(msg, SourceChangeWarning)\n",
285-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.container.Sequential' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
286-
" warnings.warn(msg, SourceChangeWarning)\n",
287-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.pooling.AdaptiveAvgPool2d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
288-
" warnings.warn(msg, SourceChangeWarning)\n",
289-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.linear.Linear' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
290-
" warnings.warn(msg, SourceChangeWarning)\n",
291-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.activation.Sigmoid' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
292-
" warnings.warn(msg, SourceChangeWarning)\n",
293-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.dropout.Dropout' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
294-
" warnings.warn(msg, SourceChangeWarning)\n",
295-
"/usr/local/lib/python3.7/dist-packages/torch/serialization.py:671: SourceChangeWarning: source code of class 'torch.nn.modules.batchnorm.BatchNorm1d' has changed. you can retrieve the original source code by accessing the object's source attribute or set `torch.nn.Module.dump_patches = True` and use the patch tool to revert the changes.\n",
296-
" warnings.warn(msg, SourceChangeWarning)\n",
297-
"Downloading: \"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth\" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_lite0-0aa007d2.pth\n",
298-
"processing Swapping dataset images...\n",
299-
"Finished preprocessing the Swapping dataset, total dirs number: 0...\n",
300-
"Traceback (most recent call last):\n",
301-
" File \"train.py\", line 163, in <module>\n",
302-
" train_loader = GetLoader(opt.dataset,opt.batchSize,8,1234)\n",
303-
" File \"/content/SimSwap/data/data_loader_Swapping.py\", line 119, in GetLoader\n",
304-
" drop_last=True,shuffle=True,num_workers=num_workers,pin_memory=True)\n",
305-
" File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py\", line 268, in __init__\n",
306-
" sampler = RandomSampler(dataset, generator=generator)\n",
307-
" File \"/usr/local/lib/python3.7/dist-packages/torch/utils/data/sampler.py\", line 103, in __init__\n",
308-
" \"value, but got num_samples={}\".format(self.num_samples))\n",
309-
"ValueError: num_samples should be a positive integer value, but got num_samples=0\n"
310-
]
311-
}
312-
]
149+
"outputs": []
313150
}
314151
]
315152
}

0 commit comments

Comments
 (0)