|
| 1 | +import argparse |
| 2 | +import cv2 |
| 3 | +import numpy as np |
| 4 | +import torch |
| 5 | +from cog import BasePredictor, Input, Path |
| 6 | + |
| 7 | +from main_test_swin2sr import define_model, test |
| 8 | + |
| 9 | + |
| 10 | +class Predictor(BasePredictor): |
| 11 | + def setup(self): |
| 12 | + """Load the model into memory to make running multiple predictions efficient""" |
| 13 | + print("Loading pipeline...") |
| 14 | + |
| 15 | + self.device = "cuda:0" |
| 16 | + |
| 17 | + args = argparse.Namespace() |
| 18 | + args.scale = 4 |
| 19 | + args.large_model = False |
| 20 | + |
| 21 | + tasks = ["classical_sr", "compressed_sr", "real_sr"] |
| 22 | + paths = [ |
| 23 | + "weights/Swin2SR_ClassicalSR_X4_64.pth", |
| 24 | + "weights/Swin2SR_CompressedSR_X4_48.pth", |
| 25 | + "weights/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth", |
| 26 | + ] |
| 27 | + sizes = [64, 48, 128] |
| 28 | + |
| 29 | + self.models = {} |
| 30 | + for task, path, size in zip(tasks, paths, sizes): |
| 31 | + args.training_patch_size = size |
| 32 | + args.task, args.model_path = task, path |
| 33 | + self.models[task] = define_model(args) |
| 34 | + self.models[task].eval() |
| 35 | + self.models[task] = self.models[task].to(self.device) |
| 36 | + |
| 37 | + def predict( |
| 38 | + self, |
| 39 | + image: Path = Input(description="Input image"), |
| 40 | + task: str = Input( |
| 41 | + description="Choose a task", |
| 42 | + choices=["classical_sr", "real_sr", "compressed_sr"], |
| 43 | + default="real_sr", |
| 44 | + ), |
| 45 | + ) -> Path: |
| 46 | + """Run a single prediction on the model""" |
| 47 | + |
| 48 | + model = self.models[task] |
| 49 | + |
| 50 | + window_size = 8 |
| 51 | + scale = 4 |
| 52 | + |
| 53 | + img_lq = cv2.imread(str(image), cv2.IMREAD_COLOR).astype(np.float32) / 255.0 |
| 54 | + img_lq = np.transpose( |
| 55 | + img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1) |
| 56 | + ) # HCW-BGR to CHW-RGB |
| 57 | + img_lq = ( |
| 58 | + torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device) |
| 59 | + ) # CHW-RGB to NCHW-RGB |
| 60 | + |
| 61 | + # inference |
| 62 | + with torch.no_grad(): |
| 63 | + # pad input image to be a multiple of window_size |
| 64 | + _, _, h_old, w_old = img_lq.size() |
| 65 | + h_pad = (h_old // window_size + 1) * window_size - h_old |
| 66 | + w_pad = (w_old // window_size + 1) * window_size - w_old |
| 67 | + img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[ |
| 68 | + :, :, : h_old + h_pad, : |
| 69 | + ] |
| 70 | + img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[ |
| 71 | + :, :, :, : w_old + w_pad |
| 72 | + ] |
| 73 | + |
| 74 | + output = model(img_lq) |
| 75 | + |
| 76 | + if task == "compressed_sr": |
| 77 | + output = output[0][..., : h_old * scale, : w_old * scale] |
| 78 | + else: |
| 79 | + output = output[..., : h_old * scale, : w_old * scale] |
| 80 | + |
| 81 | + # save image |
| 82 | + output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() |
| 83 | + if output.ndim == 3: |
| 84 | + output = np.transpose( |
| 85 | + output[[2, 1, 0], :, :], (1, 2, 0) |
| 86 | + ) # CHW-RGB to HCW-BGR |
| 87 | + output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 |
| 88 | + output_path = "/tmp/out.png" |
| 89 | + cv2.imwrite(output_path, output) |
| 90 | + |
| 91 | + return Path(output_path) |
0 commit comments