Commit 2253d71: "first code update"
1 parent 5a631d7 commit 2253d71

11 files changed: +599 −10

.gitignore (+7)

```
# Custom

*.ipynb
!nilut-multiblend.ipynb

dataset/*.png

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
```

README.md (+48 −10)
# [NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement](https://arxiv.org/abs/2306.11920)

[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2306.11920)
[<a href=""><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="colab demo"></a>]()
[<a href="https://www.kaggle.com/code/jesucristo/super-resolution-demo-swin2sr-official/"><img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png?20140912155123" alt="kaggle demo" width=50></a>]()

[Marcos V. Conde](https://scholar.google.com/citations?user=NtB1kjYAAAAJ&hl=en), [Javier Vazquez-Corral](https://scholar.google.com/citations?user=gjnuPMoAAAAJ&hl=en), [Michael S. Brown](https://scholar.google.com/citations?hl=en&user=Gv1QGSMAAAAJ), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en)

**TL;DR** NILUT uses neural representations for controllable photorealistic image enhancement. 🚀 Demo tutorial and pre-trained models available.

<img src="media/nilut-intro.gif" alt="NILUT" width="800">

----

3D lookup tables (3D LUTs) are a key component of image enhancement. Modern image signal processors (ISPs) have dedicated support for them as part of the camera rendering pipeline. Cameras typically provide multiple options for picture styles, where each style is usually obtained by applying a unique handcrafted 3D LUT. Current approaches for learning and applying 3D LUTs are notably fast, yet not memory-efficient, since multiple 3D LUTs must be stored. For this reason, among other implementation limitations, their use on mobile devices is less common.

In this work, we propose a Neural Implicit LUT (NILUT), an implicitly defined continuous 3D color transformation parameterized by a neural network. We show that NILUTs can accurately emulate real 3D LUTs. Moreover, a NILUT can be extended to incorporate multiple styles into a single network, with the ability to blend styles implicitly. Our novel approach is memory-efficient, controllable, and can complement previous methods, including learned ISPs.
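
To make the idea concrete, here is a minimal sketch of such an implicit color transform: an MLP that maps an RGB coordinate plus a style condition vector to an enhanced RGB value. Layer sizes and activations are illustrative assumptions, not the paper's exact architecture; the input layout matches `dataloader.py`.

```python
import torch
from torch import nn

class TinyNILUT(nn.Module):
    """Sketch of a conditional implicit LUT: f(r, g, b, style) -> (r', g', b')."""
    def __init__(self, nluts=3, hidden=128, depth=3):
        super().__init__()
        layers = [nn.Linear(3 + nluts, hidden), nn.ReLU()]
        for _ in range(depth - 1):
            layers += [nn.Linear(hidden, hidden), nn.ReLU()]
        layers += [nn.Linear(hidden, 3)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        # x: [n, 3 + nluts] rows of (rgb, style vector), with values in [0, 1]
        return torch.clamp(self.mlp(x), 0., 1.)

# Enhance an image flattened to [h*w, 3], conditioned on the first style
rgb = torch.rand(512 * 512, 3)
style = torch.tensor([1., 0., 0.]).repeat(rgb.shape[0], 1)
out = TinyNILUT()(torch.cat([rgb, style], dim=-1))  # [h*w, 3]
```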

**Topics** Image Enhancement, Image Editing, Color Manipulation, Tone Mapping, Presets

***Website and repo in progress.*** **See also [AISP](https://github.com/mv-lab/AISP)** for image signal processing code and papers.

----

**Pre-trained** sample models are available in `models/`. We provide `nilutx3style.pt`, a NILUT that encodes three 3D LUT styles (1, 3, and 4) with high accuracy.
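
How to load and run the model is shown in the demo notebook; the following is only a rough sketch, assuming (hypothetically) that the checkpoint is a fully serialized PyTorch module rather than a bare `state_dict`:

```python
import torch

# Assumption: nilutx3style.pt stores a complete serialized module.
# If it is a bare state_dict instead, instantiate the repo's model class
# first and call model.load_state_dict(torch.load(...)).
model = torch.load("models/nilutx3style.pt", map_location="cpu")
model.eval()

# Condition on the first of the three encoded styles (see dataloader.py)
style = torch.tensor([1., 0., 0.])
rgb = torch.rand(8, 3)  # a few rgb coordinates in [0, 1]
out = model(torch.cat([rgb, style.repeat(8, 1)], dim=-1))
```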

**Demo Tutorial** In [nilut-multiblend.ipynb](nilut-multiblend.ipynb) we provide a simple tutorial on using NILUT for multi-style image enhancement and blending. The corresponding training code will be released soon.

**Dataset** The folder `dataset/` includes 100 images from the MIT-Adobe FiveK dataset. The images were processed with professional 3D LUTs in Adobe Lightroom. The structure of the dataset is:

```
dataset/
├── 001_blend.png
├── 001_LUT01.png
├── 001_LUT02.png
├── 001_LUT03.png
├── 001_LUT04.png
├── 001_LUT05.png
├── 001_LUT08.png
├── 001_LUT10.png
└── 001.png
...
```

where `001.png` is the unprocessed input image, `001_LUTXX.png` is the result of applying the corresponding LUT, and `001_blend.png` is the example target for evaluating style blending (here, a blend of styles 1, 3, and 4 with equal weights of 0.33). The complete dataset includes 100 images `aaa.png` and their enhanced variants for each 3D LUT.
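
For instance, a small sketch (paths follow the layout above; `load_img` and `np_psnr` come from `utils.py`) that pairs each input with its LUT-03 rendition and reports how much the style changes the image:

```python
from pathlib import Path

from utils import load_img, np_psnr

root = Path("dataset")
for inp in sorted(root.glob("[0-9][0-9][0-9].png")):
    target = root / f"{inp.stem}_LUT03.png"
    if not target.exists():
        continue
    x, y = load_img(str(inp)), load_img(str(target))
    # PSNR between input and LUT output: lower means a stronger style change
    print(f"{inp.name}: {np_psnr(x, y):.2f} dB")
```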

----

Hope you like it 🤗 If you find this work interesting, insightful, or inspirational, or if you use it, please acknowledge our work:

```
@article{conde2023nilut,
  title={NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement},
  author={Conde, Marcos V and Vazquez-Corral, Javier and Brown, Michael S and Timofte, Radu},
  journal={arXiv preprint arXiv:2306.11920},
  year={2023}
}
```

**Contact** marcos.conde[at]uni-wuerzburg.de

dataloader.py (+70)

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from utils import load_img, np_psnr


class EvalMultiLUTBlending(Dataset):
    """
    Dataset that loads the input image <inp_img> and the reference target images <list_out_img>.
    The targets must be ordered as: the ground-truth 3D LUT outputs (the first <nluts>
    elements of the list), followed by the ground-truth blending results.

    Each reference is returned together with the corresponding style vector, which is
    appended to the network input.

    Example:

        test_images = EvalMultiLUTBlending('./DatasetLUTs_100images/001.png',
                            ['./DatasetLUTs_100images/001_LUT01.png',
                             './DatasetLUTs_100images/001_LUT03.png',
                             './DatasetLUTs_100images/001_LUT04.png',
                             './DatasetLUTs_100images/001_blend.png'], nluts=3)

        test_dataloader = DataLoader(test_images, batch_size=1, pin_memory=True, num_workers=0)
    """

    def __init__(self, inp_img, list_out_img, nluts):
        super().__init__()

        self.inp_imgs = load_img(inp_img)
        self.out_imgs = []
        self.error = []
        self.shape = self.inp_imgs.shape
        self.nluts = nluts

        for fout in list_out_img:
            lut = load_img(fout)
            assert self.inp_imgs.shape == lut.shape
            assert (self.inp_imgs.max() <= 1) and (lut.max() <= 1)
            self.out_imgs.append(lut)
            # PSNR between input and reference, i.e. how strongly the LUT alters the image
            self.error.append(np_psnr(self.inp_imgs, lut))
            del lut

        self.references = len(list_out_img)

    def __len__(self):
        return self.references

    def __getitem__(self, idx):
        if idx >= self.references:
            raise IndexError

        style_vector = np.zeros(self.nluts, dtype=np.float32)

        if idx < self.nluts:
            # One-hot vector selecting a single 3D LUT style
            style_vector[idx] = 1.
        else:
            # Blending target: equal weights over the three encoded styles
            style_vector = np.array([0.33, 0.33, 0.33], dtype=np.float32)

        # Convert images to pytorch tensors
        img = torch.from_numpy(self.inp_imgs)
        lut = torch.from_numpy(self.out_imgs[idx])

        img = img.reshape((img.shape[0] * img.shape[1], 3))  # [hw, 3]
        lut = lut.reshape((lut.shape[0] * lut.shape[1], 3))  # [hw, 3]

        # Repeat the style vector for every pixel and append it to the rgb coordinates
        style_vector = torch.from_numpy(style_vector)
        style_vector_re = style_vector.repeat(img.shape[0]).view(img.shape[0], self.nluts)

        img = torch.cat([img, style_vector_re], dim=-1)  # [hw, 3 + nluts]

        return img, lut, style_vector
```
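
A hypothetical evaluation loop using this dataset class. Paths assume the `dataset/` layout from the README, and the model is loaded under the same (assumed) serialized-module checkpoint format noted there:

```python
import torch
from torch.utils.data import DataLoader

from dataloader import EvalMultiLUTBlending
from utils import pt_psnr

test_images = EvalMultiLUTBlending('./dataset/001.png',
                                   ['./dataset/001_LUT01.png',
                                    './dataset/001_LUT03.png',
                                    './dataset/001_LUT04.png',
                                    './dataset/001_blend.png'], nluts=3)
test_dataloader = DataLoader(test_images, batch_size=1, pin_memory=True, num_workers=0)

model = torch.load("models/nilutx3style.pt", map_location="cpu")  # assumption, see above
model.eval()
with torch.no_grad():
    for img, lut, style in test_dataloader:
        pred = model(img.squeeze(0))  # [hw, 3]
        print(style.squeeze(0).tolist(), f"{pt_psnr(lut.squeeze(0), pred).item():.2f} dB")
```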

dataset/.gitkeep

Whitespace-only changes.

media/cnilut.png (1.29 MB)

media/header.png (811 KB)

File renamed without changes.

models/nilutx3style.pt (1.55 MB)

Binary file not shown.

nilut-multiblend.ipynb (+369)

Large diffs are not rendered by default.

requirements.txt (+10)

```
imageio==2.30.0
matplotlib==3.7.1
numpy==1.24.3
opencv-python==4.7.0.72
Pillow==9.4.0
scikit-image==0.20.0
scipy==1.10.1
torch==2.0.1
torchaudio==2.0.2
torchvision==0.15.2
```

utils.py (+95)

```python
"""
NILUT: Conditional Neural Implicit 3D Lookup Tables for Image Enhancement

Utils for training and plotting
"""

import gc
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from skimage import color


# Timing utilities

def start_timer():
    global start_time
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.synchronize()
    start_time = time.time()


def end_timer_and_print(local_msg):
    torch.cuda.synchronize()
    end_time = time.time()
    print("\n" + local_msg)
    print("Total execution time = {:.3f} sec".format(end_time - start_time))
    print("Max memory used by tensors = {} bytes".format(torch.cuda.max_memory_allocated()))


def clean_mem():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()


# Model

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Load/save and plot images

def load_img(filename, norm=True):
    img = np.array(Image.open(filename))
    if norm:
        img = img / 255.
        img = img.astype(np.float32)
    return img


def save_rgb(img, filename):
    if np.max(img) <= 1:
        img = img * 255

    img = img.astype(np.uint8)  # cv2.imwrite expects 8-bit images
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, img)


def plot_all(images, figsize=(20, 10), axis='off'):
    plt.figure(figsize=figsize, dpi=80)
    nplots = len(images)
    for i in range(nplots):
        plt.subplot(1, nplots, i + 1)
        plt.axis(axis)
        plt.imshow(images[i])
    plt.show()


# Metrics

def np_psnr(y_true, y_pred):
    # PSNR for images in [0, 1] (peak value 1)
    mse = np.mean((y_true - y_pred) ** 2)
    if mse == 0:
        return np.inf
    return 20 * np.log10(1 / np.sqrt(mse))


def pt_psnr(y_true, y_pred):
    mse = torch.mean((y_true - y_pred) ** 2)
    return 20 * torch.log10(1 / torch.sqrt(mse))


def deltae_dist(y_true, y_pred):
    """
    Calculate the DeltaE distance in the LAB color space.
    Images must be numpy arrays in [0, 1].
    """
    gt_lab = color.rgb2lab((y_true * 255).astype('uint8'))
    out_lab = color.rgb2lab((y_pred * 255).astype('uint8'))
    # Mean over pixels of the per-pixel euclidean distance in LAB
    return np.sqrt(((gt_lab - out_lab) ** 2).sum(axis=-1)).mean()
```
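
A quick sanity check of the metric helpers (values are illustrative; for Gaussian noise with sigma = 0.01 on images in [0, 1], PSNR is about 20·log10(1/0.01) = 40 dB):

```python
import numpy as np

from utils import np_psnr, deltae_dist

a = np.random.rand(64, 64, 3).astype(np.float32)
b = np.clip(a + np.random.normal(0, 0.01, a.shape), 0, 1).astype(np.float32)
print(np_psnr(a, b))      # ~40 dB
print(deltae_dist(a, b))  # mean per-pixel DeltaE in LAB space
```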
