Skip to content

Commit 645457e

Browse files
committed
1.1.1.19124
1 parent ca99efd commit 645457e

File tree

10 files changed

+76
-26
lines changed

10 files changed

+76
-26
lines changed

Generate_image/augmentor.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import Augmentor # 图像增强 图像预处理 https://github.com/mdbloice/Augmentor https://augmentor.readthedocs.io/en/master/
2+
import torchvision
3+
4+
p = Augmentor.Pipeline("/home/tian/Desktop/spiders/design/design/spiders/image_test/dog")
5+
6+
p.rotate90(probability=0.5)
7+
p.rotate270(probability=0.5)
8+
p.flip_left_right(probability=0.8)
9+
p.flip_top_bottom(probability=0.3)
10+
p.crop_random(probability=1, percentage_area=0.5)
11+
p.resize(probability=1.0, width=224, height=224)
12+
13+
transforms = torchvision.transforms.Compose([
14+
p.torch_transform(),
15+
torchvision.transforms.ToTensor(),
16+
])

Image_recognition/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,32 +7,32 @@
77
class DefaultConfig(object):
88
env = 'opalus_recognltion' # visdom 环境
99
vis_port = 8097 # visdom 端口
10+
image_size = 224
1011
model = 'ResNet152' # 使用的模型,名字必须与models/__init__.py中的名字一致
1112

1213
data_root = "/home/tian/Desktop/spiders/design/design/spiders/image" # 数据集存放路径
13-
load_model_path = None # 加载训练的模型的路径,为None代表不加载
14+
# load_model_path = None # 加载训练的模型的路径,为None代表不加载
1415
load_model_path = 'checkpoint/ResNet152_0124_11-57-28.pth.tar'
1516

1617
batch_size = 16 # 每批训练数据的个数,显存不足,适当减少
1718
use_gpu = True # user GPU or not
1819
num_workers = 4 # how many workers for loading data
1920
print_freq = 2 # print info every N batch
20-
vis = True # 是否使用visdom可视化
21+
vis = False # 是否使用visdom可视化
2122

2223
cate_classes = get_classes(data_root)['class2num'] # 分类列表
2324
num_classes = len(cate_classes) # 分类个数
24-
# pretrained = False # 不加载预训练
25-
pretrained = True # 加载预训练模型
26-
result_file = 'result.csv'
25+
pretrained = False # 不加载预训练
26+
# pretrained = True # 加载预训练模型
2727

2828
max_epoch = 10 # 学习次数
2929
lr = 0.001 # initial learning rate
3030
lr_decay = 0.5 # when val_loss increase, lr = lr*lr_decay
3131
weight_decay = 0e-5 # 损失函数
3232
# url = 'https://ss3.bdstatic.com/70cFv8Sh_Q1YnxGkpoWK1HF6hhy/it/u=614134999,3540271868&fm=27&gp=0.jpg' # 识别图片地址
3333
# url = 'https://ss1.bdstatic.com/70cFuXSh_Q1YnxGkpoWK1HF6hhy/it/u=688429408,3192272581&fm=27&gp=0.jpg'
34-
# url = 'https://ss1.bdstatic.com/70cFvXSh_Q1YnxGkpoWK1HF6hhy/it/u=1515206672,3808938099&fm=27&gp=0.jpg'
35-
url = 'https://ss0.bdstatic.com/70cFuHSh_Q1YnxGkpoWK1HF6hhy/it/u=3211343338,3677737612&fm=27&gp=0.jpg'
34+
url = 'https://ss1.bdstatic.com/70cFvXSh_Q1YnxGkpoWK1HF6hhy/it/u=1515206672,3808938099&fm=27&gp=0.jpg'
35+
# url = 'https://ss0.bdstatic.com/70cFuHSh_Q1YnxGkpoWK1HF6hhy/it/u=3211343338,3677737612&fm=27&gp=0.jpg'
3636
# url = 'https://ss0.bdstatic.com/70cFuHSh_Q1YnxGkpoWK1HF6hhy/it/u=1173573129,2720567755&fm=27&gp=0.jpg'
3737

3838
def _parse(self, kwargs):

Image_recognition/data/dataset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch.utils import data
44
from torchvision import transforms as T
55
from utils.imagefolder_splitter import ImageFolderSplitter
6+
from config import opt
67
import cv2
78

89

@@ -38,16 +39,16 @@ def __init__(self, root, transforms=None, train=True, test=False):
3839
if self.test or not self.train:
3940
# 训练 测试
4041
self.transforms = T.Compose([
41-
T.Resize(224), # #缩放图片(Image),保持长宽比不变,最短边为224像素
42-
T.CenterCrop(224), # 在图片的中间区域进行裁剪
42+
T.Resize(opt.image_size), # #缩放图片(Image),保持长宽比不变,最短边为224像素
43+
T.CenterCrop(opt.image_size), # 在图片的中间区域进行裁剪
4344
T.ToTensor(), # 转tensor
4445
normalize # 归一化
4546
])
4647
else:
4748
# 验证
4849
self.transforms = T.Compose([
4950
T.Resize(256), # #缩放图片(Image),保持长宽比不变,最短边为224像素
50-
T.RandomResizedCrop(224), # 在一个随机的位置进行裁剪
51+
T.RandomResizedCrop(opt.image_size), # 在一个随机的位置进行裁剪
5152
T.RandomHorizontalFlip(),
5253
T.ToTensor(),
5354
normalize

Image_recognition/main.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def recognition(**kwargs):
5555
model.load_state_dict(checkpoint["state_dict"]) # 预加载模型
5656
model.to(opt.device)
5757
model.eval()
58-
image = image.view(1, 3, 224, 224).to(opt.device) # 转换image
58+
image = image.view(1, 3, opt.image_size, opt.image_size).to(opt.device) # 转换image
5959
outputs = model(image)
6060
result = {}
6161
for i in range(opt.num_classes): # 计算各分类比重
@@ -119,17 +119,15 @@ def train(**kwargs):
119119
target = label.to(opt.device)
120120

121121
score = model(input)
122-
loss = criterion(score, target) # 计算损失
123-
122+
# loss = criterion(score, target) # 计算损失
123+
loss = criterion(score[0], target) # 计算损失 Inception3网络
124124
optimizer.zero_grad() # 参数梯度设成0
125125
loss.backward() # 反向传播
126126
optimizer.step() # 更新参数
127127
# meters update and visualize
128-
precision1_train, precision2_train = accuracy(score, target, topk=(1, 2))
128+
# precision1_train, precision2_train = accuracy(score, target, topk=(1, 2))
129+
precision1_train, precision2_train = accuracy(score[0], target, topk=(1, 2)) # Inception3网络
129130
train_losses.update(loss.item(), input.size(0))
130-
a = precision1_train[0]
131-
b = input.size(0)
132-
c = precision1_train[0].item()
133131
train_top1.update(precision1_train[0].item(), input.size(0))
134132
train_progressor.current_loss = train_losses.avg
135133
train_progressor.current_top1 = train_top1.avg

Image_recognition/models/alexnet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55

66

7-
def alexnet(pretrained=False, **kwargs):
7+
def alexnet(pretrained=False, **kwargs): # 224*224
88
if pretrained:
99
model = AlexNet(**kwargs)
1010
model.load_state_dict(torch.load('./checkpoint/inception_v3_google-1a9a5a14.pth'))
@@ -21,6 +21,12 @@ def __init__(self):
2121
def forward(self, x):
2222
return self.model(x)
2323

24+
def get_optimizer(self, lr, weight_decay):
25+
if not opt.pretrained:
26+
return super(AlexNet1, self).get_optimizer(lr, weight_decay)
27+
else:
28+
return torch.optim.Adam(self.model.fc.parameters(), lr=lr, weight_decay=weight_decay)
29+
2430

2531
if __name__ == '__main__':
2632
a = AlexNet1()

Image_recognition/models/denseNet201.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def __init__(self):
3333
def forward(self, x):
3434
return self.model(x)
3535

36+
def get_optimizer(self, lr, weight_decay):
37+
if not opt.pretrained:
38+
return super(DenseNet201, self).get_optimizer(lr, weight_decay)
39+
else:
40+
return torch.optim.Adam(self.model.fc.parameters(), lr=lr, weight_decay=weight_decay)
41+
3642

3743
if __name__ == '__main__':
3844
a = DenseNet201()

Image_recognition/models/densenet161.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def __init__(self):
3333
def forward(self, x):
3434
return self.model(x)
3535

36+
def get_optimizer(self, lr, weight_decay):
37+
if not opt.pretrained:
38+
return super(DenseNet161, self).get_optimizer(lr, weight_decay)
39+
else:
40+
return torch.optim.Adam(self.model.fc.parameters(), lr=lr, weight_decay=weight_decay)
41+
3642

3743
if __name__ == '__main__':
3844
a = DenseNet161()

Image_recognition/models/inceptionv3.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44
import torch
55

66

7-
def inception_v3(pretrained=False, **kwargs):
7+
def inception_v3(pretrained=False, **kwargs): # 299*299
88
if pretrained:
99
if 'transform_input' not in kwargs:
1010
kwargs['transform_input'] = True
1111
model = Inception3(**kwargs)
12-
model.load_state_dict(torch.load('./checkpoint/inception_v3_google-1a9a5a14.pth'))
13-
12+
pretrained_state_dict = torch.load(
13+
'./Authority/inception_v3_google-1a9a5a14.pth') # load_url函数根据model_urls字典下载或导入相应的预训练模型
14+
now_state_dict = model.state_dict() # 返回model模块的字典
15+
pretrained_state_dict.pop('AuxLogits.fc.weight')
16+
pretrained_state_dict.pop('AuxLogits.fc.bias')
17+
pretrained_state_dict.pop('fc.weight')
18+
pretrained_state_dict.pop('fc.bias')
19+
now_state_dict.update(pretrained_state_dict)
20+
model.load_state_dict(
21+
now_state_dict) # 最后通过调用model的load_state_dict方法用预训练的模型参数来初始化你构建的网络结构,这个方法就是PyTorch中通用的用一个模型的参数初始化另一个模型的层的操作。load_state_dict方法还有一个重要的参数是strict,该参数默认是True,表示预训练模型的层和你的网络结构层严格对应相等(比如层名和维度)
22+
return model
1423
return Inception3(**kwargs)
1524

1625

@@ -23,6 +32,15 @@ def __init__(self):
2332
def forward(self, x):
2433
return self.model(x)
2534

35+
def get_optimizer(self, lr, weight_decay):
36+
if not opt.pretrained:
37+
return super(InceptionV3, self).get_optimizer(lr, weight_decay)
38+
else:
39+
return torch.optim.Adam([
40+
{'params': self.model.AuxLogits.fc.parameters()},
41+
{'params': self.model.fc.parameters()}
42+
], lr=lr, weight_decay=weight_decay)
43+
2644

2745
if __name__ == '__main__':
2846
a = InceptionV3()

Image_recognition/models/resnet152.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,8 @@ def forward(self, x):
4848
def resnet152(pretrained=False, **kwargs):
4949
if pretrained:
5050
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
51-
# model.load_state_dict(torch.load('./checkpoint/resnet152-b121ed2d.pth'))
52-
# 网络结构不对等
5351
pretrained_state_dict = torch.load(
54-
'./checkpoint/resnet152-b121ed2d.pth') # load_url函数根据model_urls字典下载或导入相应的预训练模型
52+
'./Authority/resnet152-b121ed2d.pth') # load_url函数根据model_urls字典下载或导入相应的预训练模型
5553
now_state_dict = model.state_dict() # 返回model模块的字典
5654
pretrained_state_dict.pop('fc.weight')
5755
pretrained_state_dict.pop('fc.bias')

Image_recognition/utils/image_loader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from torchvision import transforms as T
33
import requests
44
from PIL import Image
5+
from config import opt
56

67
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
78
std=[0.229, 0.224, 0.225])
89

910
transforms = T.Compose([
10-
T.Resize(224),
11-
T.CenterCrop(224),
11+
T.Resize(opt.image_size),
12+
T.CenterCrop(opt.image_size),
1213
T.ToTensor(),
1314
normalize
1415
])

0 commit comments

Comments
 (0)