model.py
# Copyright 2019 Stanislav Pidhorskyi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
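"""Model wrapper combining the mapping network and generator from net.py with
sampling-time style mixing and latent truncation."""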
import random

import numpy as np
import torch
from torch import nn

from net import Generator, Mapping


class DLatent(nn.Module):
    """Running average of the disentangled latent w, stored as a persistent buffer
    (one vector per generator layer)."""
    def __init__(self, dlatent_size, layer_count):
        super(DLatent, self).__init__()
        buffer = torch.zeros(layer_count, dlatent_size, dtype=torch.float32)
        self.register_buffer('buff', buffer)


class Model(nn.Module):
    """Mapping network (z -> w) plus synthesis generator, with a running average of w
    used for the truncation trick."""

    def __init__(self, startf=32, maxf=256, layer_count=3, latent_size=128, mapping_layers=5,
                 dlatent_avg_beta=None, truncation_psi=None, truncation_cutoff=None,
                 style_mixing_prob=None, channels=3):
        super(Model, self).__init__()

        # Mapping network: z -> w, producing one style vector per generator layer
        # (two layers per resolution level).
        self.mapping = Mapping(
            num_layers=2 * layer_count,
            latent_size=latent_size,
            dlatent_size=latent_size,
            mapping_fmaps=latent_size,
            mapping_layers=mapping_layers)

        # Synthesis network: w -> image.
        self.generator = Generator(
            startf=startf,
            layer_count=layer_count,
            maxf=maxf,
            latent_size=latent_size,
            channels=channels)

        # Running average of w, updated in generate() and used for truncation.
        self.dlatent_avg = DLatent(latent_size, self.mapping.num_layers)
        self.latent_size = latent_size
        self.dlatent_avg_beta = dlatent_avg_beta
        self.truncation_psi = truncation_psi
        self.style_mixing_prob = style_mixing_prob
        self.truncation_cutoff = truncation_cutoff

    def generate(self, lod, remove_blob=True, z=None, count=32):
        if z is None:
            z = torch.randn(count, self.latent_size)
        # Map z to the disentangled latent w, one style vector per generator layer.
        styles = self.mapping(z)

        # Update the running average of w (used by the truncation trick below).
        if self.dlatent_avg_beta is not None:
            with torch.no_grad():
                batch_avg = styles.mean(dim=0)
                self.dlatent_avg.buff.data.lerp_(batch_avg.data, 1.0 - self.dlatent_avg_beta)

        # Style mixing: with probability style_mixing_prob, replace the styles of the layers
        # at and above a random cutoff with styles mapped from a second latent.
        if self.style_mixing_prob is not None:
            if random.random() < self.style_mixing_prob:
                z2 = torch.randn(count, self.latent_size)
                styles2 = self.mapping(z2)
                layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis]
                cur_layers = (lod + 1) * 2
                mixing_cutoff = random.randint(1, cur_layers)
                styles = torch.where(layer_idx < mixing_cutoff, styles, styles2)

        # Truncation trick: pull the styles of the first truncation_cutoff layers towards the
        # average w by a factor of truncation_psi.
        if self.truncation_psi is not None:
            layer_idx = torch.arange(self.mapping.num_layers)[np.newaxis, :, np.newaxis]
            ones = torch.ones(layer_idx.shape, dtype=torch.float32)
            coefs = torch.where(layer_idx < self.truncation_cutoff, self.truncation_psi * ones, ones)
            styles = torch.lerp(self.dlatent_avg.buff.data, styles, coefs)

        rec = self.generator.forward(styles, lod, remove_blob)
        return rec

    def forward(self, x, lod, blend_factor, d_train):
        # x is treated as the latent z; blend_factor and d_train are not used by generate().
        return self.generate(lod, z=x)
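

# A minimal usage sketch (commented out): the constructor arguments below are illustrative
# assumptions, and the lod value must correspond to a resolution level supported by
# net.Generator. Output is expected to be a (count, channels, H, W) image tensor.
#
#   model = Model(startf=32, maxf=256, layer_count=6, latent_size=256, mapping_layers=5,
#                 truncation_psi=0.7, truncation_cutoff=8)
#   model.eval()
#   with torch.no_grad():
#       images = model.generate(lod=5, remove_blob=True, count=4)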