Skip to content

Commit 655e5f1

Browse files
committed
1.2.1.190125
1 parent 645457e commit 655e5f1

19 files changed

+84
-56
lines changed
File renamed without changes.

Generate_image/main.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
1-
# coding:utf8
21
import os
3-
import ipdb
42
import torch as t
53
import torchvision as tv
64
import tqdm
7-
from model import NetG, NetD
5+
from models import NetG, NetD
86
from torchnet.meter import AverageValueMeter
97
from config import opt
108

119

1210
def train(**kwargs):
1311
opt._parse(kwargs)
1412
if opt.vis:
15-
from visualize import Visualizer
13+
from utils.visualize import Visualizer
1614
vis = Visualizer(opt.env)
1715

1816
# 数据

Generate_image/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import NetG, NetD

Generate_image/model.py renamed to Generate_image/models/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding:utf8
21
from torch import nn
32

43

File renamed without changes.

Image_recognition/Authority/.gitkeep

Whitespace-only changes.

Image_recognition/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,22 @@ class DefaultConfig(object):
88
env = 'opalus_recognltion' # visdom 环境
99
vis_port = 8097 # visdom 端口
1010
image_size = 224
11-
model = 'ResNet152' # 使用的模型,名字必须与models/__init__.py中的名字一致
11+
model = 'AlexNet1' # 使用的模型,名字必须与models/__init__.py中的名字一致
1212

1313
data_root = "/home/tian/Desktop/spiders/design/design/spiders/image" # 数据集存放路径
1414
# load_model_path = None # 加载训练的模型的路径,为None代表不加载
15-
load_model_path = 'checkpoint/ResNet152_0124_11-57-28.pth.tar'
15+
load_model_path = 'checkpoint/AlexNet1_0125_18-08-46.pth.tar'
1616

1717
batch_size = 16 # 每批训练数据的个数,显存不足,适当减少
1818
use_gpu = True # use GPU or not
1919
num_workers = 4 # how many workers for loading data
2020
print_freq = 2 # print info every N batch
21-
vis = False # 是否使用visdom可视化
21+
vis = True # 是否使用visdom可视化
2222

2323
cate_classes = get_classes(data_root)['class2num'] # 分类列表
2424
num_classes = len(cate_classes) # 分类个数
25-
pretrained = False # 不加载预训练
26-
# pretrained = True # 加载预训练模型
25+
# pretrained = False # 不加载预训练
26+
pretrained = True # 加载预训练模型
2727

2828
max_epoch = 10 # 学习次数
2929
lr = 0.001 # initial learning rate

Image_recognition/data/dataset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding:utf8
21
from PIL import Image
32
from torch.utils import data
43
from torchvision import transforms as T
@@ -63,9 +62,9 @@ def __getitem__(self, index):
6362
# img = cv2.imread(img_path)
6463
# img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
6564
data = Image.open(img_path)
66-
data = data.convert("RGB")
65+
data = data.convert("RGB") # some images have 4 channels; convert them to 3-channel RGB
6766
data = self.transforms(data)
68-
return data, label
67+
return data, label # return the data and its label
6968

7069
def __len__(self):
7170
return len(self.imgs)

Image_recognition/main.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test(**kwargs):
2020
model = getattr(models, opt.model)()
2121
if opt.load_model_path:
2222
checkpoint = t.load(opt.load_model_path)
23-
model.load_state_dict(checkpoint["state_dict"])
23+
model.load_state_dict(checkpoint["state_dict"]) # 加载模型
2424
model.to(opt.device)
2525
model.eval() # 把module设成测试模式,对Dropout和BatchNorm有影响
2626
# data
@@ -46,13 +46,13 @@ def test(**kwargs):
4646

4747

4848
def recognition(**kwargs):
49-
with t.no_grad():
50-
opt._parse(kwargs) # 用来标志计算要被计算图隔离出去
49+
with t.no_grad(): # 用来标志计算要被计算图隔离出去
50+
opt._parse(kwargs)
5151
image = image_loader(opt.url)
5252
model = getattr(models, opt.model)()
5353
if opt.load_model_path:
5454
checkpoint = t.load(opt.load_model_path)
55-
model.load_state_dict(checkpoint["state_dict"]) # 预加载模型
55+
model.load_state_dict(checkpoint["state_dict"]) # 加载模型
5656
model.to(opt.device)
5757
model.eval()
5858
image = image.view(1, 3, opt.image_size, opt.image_size).to(opt.device) # 转换image
@@ -68,7 +68,7 @@ def train(**kwargs):
6868
if opt.vis:
6969
vis = Visualizer(opt.env, port=opt.vis_port) # 开启visdom 可视化
7070
previous_loss = 1e10 # 上次学习的loss
71-
best_precision = 0
71+
best_precision = 0 # 最好的精确度
7272
start_epoch = 0
7373
lr = opt.lr
7474
# step1: criterion and optimizer
@@ -119,14 +119,14 @@ def train(**kwargs):
119119
target = label.to(opt.device)
120120

121121
score = model(input)
122-
# loss = criterion(score, target) # 计算损失
123-
loss = criterion(score[0], target) # 计算损失 Inception3网络
122+
loss = criterion(score, target) # 计算损失
123+
# loss = criterion(score[0], target) # 计算损失 Inception3网络
124124
optimizer.zero_grad() # 参数梯度设成0
125125
loss.backward() # 反向传播
126126
optimizer.step() # 更新参数
127127
# meters update and visualize
128-
# precision1_train, precision2_train = accuracy(score, target, topk=(1, 2))
129-
precision1_train, precision2_train = accuracy(score[0], target, topk=(1, 2)) # Inception3网络
128+
precision1_train, precision2_train = accuracy(score, target, topk=(1, 2))
129+
# precision1_train, precision2_train = accuracy(score[0], target, topk=(1, 2)) # Inception3网络
130130
train_losses.update(loss.item(), input.size(0))
131131
train_top1.update(precision1_train[0].item(), input.size(0))
132132
train_progressor.current_loss = train_losses.avg
@@ -137,11 +137,11 @@ def train(**kwargs):
137137
else:
138138
print('loss', train_losses.val)
139139
train_progressor()
140-
# train_progressor.done() #
140+
# train_progressor.done() # 保存训练结果为txt
141141
# validate and visualize
142142
valid_loss = val(model, epoch, criterion, val_dataloader) # 校验模型
143143
best_precision = valid_loss[1]
144-
# is_best = valid_loss[1] > best_precision # 准确率比较,如果此次比上次大  保存模型
144+
# is_best = valid_loss[1] > best_precision # 精确度比较,如果此次比上次大  保存模型
145145
# best_precision = max(valid_loss[1], best_precision)
146146
# if is_best:
147147
model.save({
@@ -189,7 +189,7 @@ def val(model, epoch, criterion, dataloader):
189189
val_progressor.current_top1 = top1.avg
190190
val_progressor()
191191

192-
# val_progressor.done()
192+
# val_progressor.done() # 保存校验结果为txt
193193
return [losses.avg, top1.avg]
194194

195195

Image_recognition/models/alexnet.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,20 @@
22
from config import opt
33
from torchvision.models import AlexNet
44
import torch
5+
import torch.nn as nn
56

67

7-
def alexnet(pretrained=False, **kwargs): # 224*224
8+
def alexnet(pretrained=False, **kwargs): # 224*224
89
if pretrained:
910
model = AlexNet(**kwargs)
10-
model.load_state_dict(torch.load('./checkpoint/inception_v3_google-1a9a5a14.pth'))
11+
pretrained_state_dict = torch.load(
12+
'./Authority/alexnet-owt-4df8aa71.pth')
13+
now_state_dict = model.state_dict() # 返回model模块的字典
14+
pretrained_state_dict.pop('classifier.6.weight')
15+
pretrained_state_dict.pop('classifier.6.bias')
16+
now_state_dict.update(pretrained_state_dict)
17+
model.load_state_dict(
18+
now_state_dict)
1119
return model
1220
return AlexNet(**kwargs)
1321

@@ -25,9 +33,10 @@ def get_optimizer(self, lr, weight_decay):
2533
if not opt.pretrained:
2634
return super(AlexNet1, self).get_optimizer(lr, weight_decay)
2735
else:
28-
return torch.optim.Adam(self.model.fc.parameters(), lr=lr, weight_decay=weight_decay)
36+
return torch.optim.Adam(self.model.classifier[6].parameters(), lr=lr, weight_decay=weight_decay)
2937

3038

3139
if __name__ == '__main__':
3240
a = AlexNet1()
33-
print(a)
41+
for i in a.model.classifier[6]:
42+
print(i)

Image_recognition/models/denseNet201.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,23 @@ def densenet201(pretrained=False, **kwargs):
1111
**kwargs)
1212
pattern = re.compile(
1313
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
14-
state_dict = torch.load('./checkpoint/inception_v3_google-1a9a5a14.pth')
15-
for key in list(state_dict.keys()):
14+
pretrained_state_dict = torch.load(
15+
'./Authority/densenet201-c1103571.pth') # load the pretrained weights from a local checkpoint file
16+
for key in list(pretrained_state_dict.keys()):
1617
res = pattern.match(key)
1718
if res:
1819
new_key = res.group(1) + res.group(2)
19-
state_dict[new_key] = state_dict[key]
20-
del state_dict[key]
21-
model.load_state_dict(state_dict)
20+
pretrained_state_dict[new_key] = pretrained_state_dict[key]
21+
del pretrained_state_dict[key]
22+
now_state_dict = model.state_dict() # 返回model模块的字典
23+
pretrained_state_dict.pop('classifier.weight')
24+
pretrained_state_dict.pop('classifier.bias')
25+
now_state_dict.update(pretrained_state_dict)
26+
model.load_state_dict(
27+
now_state_dict)
28+
# 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,
29+
# 这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,
30+
# 该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
2231
return model
2332
return DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
2433
**kwargs)
@@ -37,7 +46,7 @@ def get_optimizer(self, lr, weight_decay):
3746
if not opt.pretrained:
3847
return super(DenseNet201, self).get_optimizer(lr, weight_decay)
3948
else:
40-
return torch.optim.Adam(self.model.fc.parameters(), lr=lr, weight_decay=weight_decay)
49+
return torch.optim.Adam(self.model.classifier.parameters(), lr=lr, weight_decay=weight_decay)
4150

4251

4352
if __name__ == '__main__':

Image_recognition/models/densenet161.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,23 @@ def densenet161(pretrained=False, **kwargs):
1111
**kwargs)
1212
pattern = re.compile(
1313
r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
14-
state_dict = torch.load('./checkpoint/inception_v3_google-1a9a5a14.pth')
15-
for key in list(state_dict.keys()):
14+
pretrained_state_dict = torch.load(
15+
'./Authority/densenet161-8d451a50.pth') # load the pretrained weights from a local checkpoint file
16+
for key in list(pretrained_state_dict.keys()):
1617
res = pattern.match(key)
1718
if res:
1819
new_key = res.group(1) + res.group(2)
19-
state_dict[new_key] = state_dict[key]
20-
del state_dict[key]
21-
model.load_state_dict(state_dict)
20+
pretrained_state_dict[new_key] = pretrained_state_dict[key]
21+
del pretrained_state_dict[key]
22+
now_state_dict = model.state_dict() # 返回model模块的字典
23+
pretrained_state_dict.pop('classifier.weight')
24+
pretrained_state_dict.pop('classifier.bias')
25+
now_state_dict.update(pretrained_state_dict)
26+
model.load_state_dict(
27+
now_state_dict)
28+
# 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,
29+
# 这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,
30+
# 该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
2231
return model
2332
return DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
2433
**kwargs)

Image_recognition/models/inceptionv3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ def inception_v3(pretrained=False, **kwargs): # 299*299
1818
pretrained_state_dict.pop('fc.bias')
1919
now_state_dict.update(pretrained_state_dict)
2020
model.load_state_dict(
21-
now_state_dict) # 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
21+
now_state_dict)
22+
# 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,
23+
# 这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,
24+
# 该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
2225
return model
2326
return Inception3(**kwargs)
2427

Image_recognition/models/resnet152.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ def resnet152(pretrained=False, **kwargs):
5151
pretrained_state_dict = torch.load(
5252
'./Authority/resnet152-b121ed2d.pth') # load the pretrained weights from a local checkpoint file
5353
now_state_dict = model.state_dict() # 返回model模块的字典
54-
pretrained_state_dict.pop('fc.weight')
54+
pretrained_state_dict.pop('fc.weight') # 排除全连接层的参数(全连接层返回分类个数)
5555
pretrained_state_dict.pop('fc.bias')
5656
now_state_dict.update(pretrained_state_dict)
5757
model.load_state_dict(
58-
now_state_dict) # 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
58+
now_state_dict)
59+
# 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,
60+
# 这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,
61+
# 该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
5962
return model
6063
return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
6164

Image_recognition/utils/get_classes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
34
# 获取目录下所有分类和图片数据
45
def get_classes(path):
56
class2num = {}

Image_recognition/utils/imagefolder_splitter.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,21 @@ class ImageFolderSplitter:
66
def __init__(self, path, train_size=0.8):
77
self.path = path
88
self.train_size = train_size
9-
self.x_train = []
10-
self.x_valid = []
11-
self.y_train = []
12-
self.y_valid = []
9+
self.x_train = [] # 训练图片
10+
self.x_valid = [] # validation image paths
11+
self.y_train = [] # training labels
12+
self.y_valid = [] # 测试标签
1313
self.data_x_path = get_classes(path)['data_x_path']
1414
self.data_y_label = get_classes(path)['data_y_label']
15-
# 80%的训练集,20%的测试机集
15+
# 随机80%的训练集和20%的测试集
1616
self.x_train, self.x_valid, self.y_train, self.y_valid = train_test_split(self.data_x_path, self.data_y_label,
1717
shuffle=True,
1818
train_size=self.train_size)
1919

20-
def getTrainingDataset(self):
20+
def getTrainingDataset(self): # return the training set
2121
return self.x_train, self.y_train
2222

23-
def getValidationDataset(self):
23+
def getValidationDataset(self): # 返回测试集
2424
return self.x_valid, self.y_valid
2525

2626

Image_recognition/utils/progress_bar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22

3-
3+
# 进度条
44
class ProgressBar(object):
55
DEFAULT = "Progress: %(bar)s %(percent)3d%%"
66

Image_recognition/utils/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
import shutil
21
import time
32

43
import torch
54
from config import opt
6-
import os
75

86

9-
# 仪表
7+
# 仪表盘
108
class AverageMeter(object):
119
"""Computes and stores the average and current value"""
1210

@@ -28,7 +26,6 @@ def update(self, val, n=1):
2826

2927
# 准确率
3028
def accuracy(output, target, topk=(1,)):
31-
"""Computes the accuracy over the k top predictions for the specified values of k"""
3229
with torch.no_grad():
3330
maxk = max(topk)
3431
batch_size = target.size(0)

Image_recognition/utils/visualize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import numpy as np
55

6-
6+
# visdom 可视化工具
77
class Visualizer(object):
88
"""
99
封装了visdom的基本操作,但是你仍然可以通过`self.vis.function`

0 commit comments

Comments
 (0)