Skip to content

Commit dc93d73

Browse files
committed
replicate
1 parent 7eeebfb commit dc93d73

File tree

3 files changed

+105
-0
lines changed

3 files changed

+105
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
![visitors](https://visitor-badge.glitch.me/badge?page_id=mv-lab/swin2sr)
77
[ <a href="https://colab.research.google.com/drive/1paPrt62ydwLv2U2eZqfcFsePI4X4WRR1?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/drive/1paPrt62ydwLv2U2eZqfcFsePI4X4WRR1?usp=sharing)
88
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/jjourney1125/swin2sr)
9+
[![Replicate](https://replicate.com/cjwbw/japanese-stable-diffusion/badge)](https://replicate.com/cjwbw/japanese-stable-diffusion)
910
[ <a href="https://www.kaggle.com/code/jesucristo/super-resolution-demo-swin2sr-official/"><img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png?20140912155123" alt="kaggle logo" width=50></a>](https://www.kaggle.com/code/jesucristo/super-resolution-demo-swin2sr-official/)
1011

1112

cog.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
build:
2+
gpu: true
3+
cuda: "11.6.2"
4+
python_version: "3.10"
5+
system_packages:
6+
- "libgl1-mesa-glx"
7+
- "libglib2.0-0"
8+
python_packages:
9+
- "ipython==8.4.0"
10+
- "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116"
11+
- "opencv-python==4.6.0.66"
12+
- "timm==0.6.11"
13+
predict: "predict.py:Predictor"

predict.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
import argparse
2+
import cv2
3+
import numpy as np
4+
import torch
5+
from cog import BasePredictor, Input, Path
6+
7+
from main_test_swin2sr import define_model, test
8+
9+
10+
class Predictor(BasePredictor):
    """Cog predictor that serves the x4 Swin2SR super-resolution checkpoints.

    ``setup`` loads one model per task and keeps them in ``self.models``;
    ``predict`` selects the model by task name and upscales a single image.
    """

    # Spatial size is padded up to a multiple of this before the forward pass.
    WINDOW_SIZE = 8
    # Upscaling factor used both to build the models and to crop the output.
    SCALE = 4
    # task name -> (checkpoint path, training patch size handed to define_model)
    CHECKPOINTS = {
        "classical_sr": ("weights/Swin2SR_ClassicalSR_X4_64.pth", 64),
        "compressed_sr": ("weights/Swin2SR_CompressedSR_X4_48.pth", 48),
        "real_sr": ("weights/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth", 128),
    }

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        print("Loading pipeline...")

        self.device = "cuda:0"

        # define_model reads its configuration from an argparse-style namespace;
        # the per-task fields are overwritten on every loop iteration.
        args = argparse.Namespace()
        args.scale = self.SCALE
        args.large_model = False

        self.models = {}
        for task, (path, size) in self.CHECKPOINTS.items():
            args.training_patch_size = size
            args.task, args.model_path = task, path
            model = define_model(args)
            model.eval()
            self.models[task] = model.to(self.device)

    @staticmethod
    def _mirror_pad(tensor, dim, target):
        """Grow ``tensor`` along ``dim`` to exactly ``target`` entries by mirroring.

        The tensor is repeatedly concatenated with its own flip until it is
        long enough, then cut to ``target``. A single concatenation is only
        enough when the padding does not exceed the current extent, so the
        loop also covers inputs smaller than half a window.
        """
        while tensor.size(dim) < target:
            tensor = torch.cat([tensor, torch.flip(tensor, [dim])], dim)
        return tensor.narrow(dim, 0, target)

    def predict(
        self,
        image: Path = Input(description="Input image"),
        task: str = Input(
            description="Choose a task",
            choices=["classical_sr", "real_sr", "compressed_sr"],
            default="real_sr",
        ),
    ) -> Path:
        """Run a single prediction on the model

        Args:
            image: Path of the low-resolution input image.
            task: Key of the model to use; one of the keys of ``self.models``.

        Returns:
            Path of the PNG file holding the x4 upscaled image.

        Raises:
            ValueError: If the input file cannot be decoded as an image.
            RuntimeError: If the result cannot be written to disk.
        """
        model = self.models[task]
        window_size = self.WINDOW_SIZE
        scale = self.SCALE

        img_bgr = cv2.imread(str(image), cv2.IMREAD_COLOR)
        if img_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read input image: {image}")

        # IMREAD_COLOR always produces three channels, so the channel swap is
        # unconditional: HWC-BGR uint8 -> CHW-RGB float32 in [0, 1].
        img_lq = img_bgr.astype(np.float32) / 255.0
        img_lq = np.transpose(img_lq[:, :, [2, 1, 0]], (2, 0, 1))
        img_lq = (
            torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device)
        )  # CHW-RGB to NCHW-RGB

        # inference
        with torch.no_grad():
            # Pad the input up to the next multiple of window_size. This always
            # adds between 1 and window_size rows/columns, even when the size
            # is already aligned.
            _, _, h_old, w_old = img_lq.size()
            h_pad = (h_old // window_size + 1) * window_size - h_old
            w_pad = (w_old // window_size + 1) * window_size - w_old
            img_lq = self._mirror_pad(img_lq, 2, h_old + h_pad)
            img_lq = self._mirror_pad(img_lq, 3, w_old + w_pad)

            output = model(img_lq)

            # The compressed_sr model's result is indexed with [0] before
            # cropping (presumably it returns several outputs - confirm against
            # the model definition). Cropping removes the padded border.
            if task == "compressed_sr":
                output = output[0][..., : h_old * scale, : w_old * scale]
            else:
                output = output[..., : h_old * scale, : w_old * scale]

        # save image
        output = output.detach().squeeze().float().cpu().clamp_(0, 1).numpy()
        if output.ndim == 3:
            output = np.transpose(
                output[[2, 1, 0], :, :], (1, 2, 0)
            )  # CHW-RGB to HWC-BGR
        output = (output * 255.0).round().astype(np.uint8)  # float32 to uint8
        output_path = "/tmp/out.png"
        if not cv2.imwrite(output_path, output):
            # cv2.imwrite returns False rather than raising when saving fails.
            raise RuntimeError(f"Could not write output image: {output_path}")

        return Path(output_path)

0 commit comments

Comments
 (0)