Skip to content

Commit 8c64cb5

Browse files
committed
first commit from local
0 parents  commit 8c64cb5

14 files changed

+339
-0
lines changed

.idea/.gitignore

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/inspectionProfiles/Project_Default.xml

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/inspectionProfiles/profiles_settings.xml

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/modules.xml

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/pyqt_pytorch_image_classification_gui.iml

Lines changed: 8 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

__pycache__/imageView.cpython-311.pyc

3.42 KB
Binary file not shown.
3.26 KB
Binary file not shown.

__pycache__/script.cpython-311.pyc

5.76 KB
Binary file not shown.

imageView.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
3+
from PyQt5.QtCore import Qt, QRectF
4+
from PyQt5.QtGui import QPixmap, QImage
5+
from PyQt5.QtWidgets import QGraphicsScene, QGraphicsView
6+
7+
8+
class ImageView(QGraphicsView):
9+
def __init__(self):
10+
super().__init__()
11+
self.__aspectRatioMode = Qt.KeepAspectRatio
12+
self.__gradient_enabled = False
13+
self.__initVal()
14+
15+
def __initVal(self):
16+
self._scene = QGraphicsScene()
17+
self._p = QPixmap()
18+
self._item = ''
19+
20+
def displayPillowImage(self, image):
21+
img_array = np.array(image)
22+
23+
# Convert NumPy array to QImage
24+
if img_array.ndim == 3:
25+
h, w, ch = img_array.shape
26+
bytesPerLine = ch * w
27+
qim = QImage(img_array.data, w, h, bytesPerLine, QImage.Format_RGB888)
28+
else:
29+
raise ValueError("Unsupported image dimension: {}".format(img_array.ndim))
30+
31+
pixmap = QPixmap.fromImage(qim)
32+
33+
self._scene.clear()
34+
35+
self._scene.addPixmap(pixmap)
36+
37+
self._scene.setSceneRect(QRectF(0, 0, pixmap.width(), pixmap.height()))
38+
self.setScene(self._scene)
39+
self.fitInView(self.sceneRect(), self.__aspectRatioMode)
40+
41+
def setAspectRatioMode(self, mode):
42+
self.__aspectRatioMode = mode
43+
44+
def resizeEvent(self, e):
45+
self.fitInView(self.sceneRect(), self.__aspectRatioMode)
46+
return super().resizeEvent(e)

loadingLbl.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from PyQt5.QtWidgets import QLabel
2+
from PyQt5.QtCore import Qt, QTimer
3+
4+
5+
class LoadingLabel(QLabel):
6+
def __init__(self):
7+
super(LoadingLabel, self).__init__()
8+
self.__initVal()
9+
self.__initUi()
10+
11+
def __initVal(self):
12+
self.__default_text = 'Wait'
13+
14+
def __initUi(self):
15+
self.__timer = QTimer(self)
16+
self.setText(self.__default_text)
17+
self.setVisible(False)
18+
self.setAlignment(Qt.AlignVCenter | Qt.AlignCenter)
19+
20+
def __timerInit(self):
21+
self.__timer.timeout.connect(self.__ticking)
22+
self.__timer.singleShot(0, self.__ticking)
23+
self.__timer.start(500)
24+
25+
def __ticking(self):
26+
dot = '.'
27+
cur_text = self.text()
28+
cnt = cur_text.count(dot)
29+
if cnt % 3 == 0 and cnt != 0:
30+
self.setText(self.__default_text + dot)
31+
else:
32+
self.setText(cur_text + dot)
33+
34+
def start(self):
35+
self.setVisible(True)
36+
self.__timerInit()
37+
38+
def stop(self):
39+
self.setVisible(False)
40+
self.__timer.stop()

main.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os, sys
2+
3+
from imageView import ImageView
4+
from loadingLbl import LoadingLabel
5+
from script import ImagePredictor
6+
7+
# Get the absolute path of the current script file
8+
script_path = os.path.abspath(__file__)
9+
10+
# Get the root directory by going up one level from the script directory
11+
project_root = os.path.dirname(os.path.dirname(script_path))
12+
13+
sys.path.insert(0, project_root)
14+
sys.path.insert(0, os.getcwd()) # Add the current directory as well
15+
16+
from PyQt5.QtWidgets import QMainWindow, QPushButton, QApplication, QLineEdit, QVBoxLayout, QLabel, QWidget
17+
from PyQt5.QtCore import Qt, QCoreApplication, QThread, pyqtSignal
18+
from PyQt5.QtGui import QFont
19+
20+
QApplication.setAttribute(Qt.AA_EnableHighDpiScaling)
21+
QCoreApplication.setAttribute(Qt.AA_UseHighDpiPixmaps) # HighDPI support
22+
23+
QApplication.setFont(QFont('Arial', 12))
24+
25+
26+
class Thread(QThread):
27+
generateFinished = pyqtSignal(str)
28+
29+
def __init__(self, image, pred: ImagePredictor):
30+
super(Thread, self).__init__()
31+
self.__image = image
32+
self.__pred = pred
33+
34+
def run(self):
35+
try:
36+
self.generateFinished.emit(self.__pred.predict_image(self.__image))
37+
except Exception as e:
38+
raise Exception(e)
39+
40+
41+
class MainWindow(QMainWindow):
42+
def __init__(self):
43+
super(MainWindow, self).__init__()
44+
self.__initVal()
45+
self.__initUi()
46+
47+
def __initVal(self):
48+
model_path = 'result.pth'
49+
self.__pred = ImagePredictor(model_path)
50+
51+
def __initUi(self):
52+
self.setWindowTitle('PyTorch Image Classification')
53+
54+
self.__urlLineEdit = QLineEdit()
55+
self.__urlLineEdit.setPlaceholderText('Input the URL...')
56+
self.__urlLineEdit.textChanged.connect(self.__urlChanged)
57+
58+
self.__view = ImageView()
59+
60+
self.__runBtn = QPushButton('Run')
61+
self.__runBtn.setEnabled(False)
62+
self.__runBtn.clicked.connect(self.__run)
63+
64+
self.__waitLbl = LoadingLabel()
65+
self.__waitLbl.setVisible(False)
66+
67+
self.__resultLbl = QLabel()
68+
self.__resultLbl.setAlignment(Qt.AlignCenter)
69+
self.__resultLbl.setVisible(False)
70+
71+
lay = QVBoxLayout()
72+
lay.addWidget(self.__urlLineEdit)
73+
lay.addWidget(self.__view)
74+
lay.addWidget(self.__runBtn)
75+
lay.addWidget(self.__waitLbl)
76+
lay.addWidget(self.__resultLbl)
77+
78+
mainWidget = QWidget()
79+
mainWidget.setLayout(lay)
80+
81+
self.setCentralWidget(mainWidget)
82+
83+
def __urlChanged(self, url):
84+
self.__runBtn.setEnabled(url.strip() != '')
85+
86+
def __run(self):
87+
image_url = self.__urlLineEdit.text()
88+
image = self.__pred.get_image_from_url(image_url)
89+
self.__view.displayPillowImage(image)
90+
self.__t = Thread(image, self.__pred)
91+
self.__t.started.connect(self.__started)
92+
self.__t.generateFinished.connect(self.__generateFinished)
93+
self.__t.finished.connect(self.__finished)
94+
self.__t.start()
95+
96+
def __started(self):
97+
self.__waitLbl.setVisible(True)
98+
self.__runBtn.setEnabled(False)
99+
100+
def __generateFinished(self, result):
101+
self.__resultLbl.setText(result)
102+
103+
def __finished(self):
104+
self.__waitLbl.setVisible(False)
105+
self.__resultLbl.setVisible(True)
106+
self.__runBtn.setEnabled(True)
107+
108+
109+
if __name__ == "__main__":
110+
import sys
111+
112+
app = QApplication(sys.argv)
113+
w = MainWindow()
114+
w.show()
115+
sys.exit(app.exec())

result.pth

15.2 MB
Binary file not shown.

script.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import torchvision.transforms as transforms
5+
6+
classes = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
7+
8+
# Define the model
9+
class Net(nn.Module):
10+
def __init__(self, num_classes):
11+
super(Net, self).__init__()
12+
self.data_augmentation = nn.Sequential(
13+
transforms.RandomHorizontalFlip(),
14+
transforms.RandomRotation(10),
15+
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))
16+
)
17+
self.model = nn.Sequential(
18+
nn.Conv2d(3, 16, 3, padding=1),
19+
nn.ReLU(),
20+
nn.MaxPool2d(2),
21+
nn.Conv2d(16, 32, 3, padding=1),
22+
nn.ReLU(),
23+
nn.MaxPool2d(2),
24+
nn.Conv2d(32, 64, 3, padding=1),
25+
nn.ReLU(),
26+
nn.MaxPool2d(2),
27+
nn.Dropout(0.2),
28+
nn.Flatten(),
29+
nn.Linear(64 * 22 * 22, 128),
30+
nn.ReLU(),
31+
nn.Linear(128, num_classes)
32+
)
33+
34+
def forward(self, x):
35+
x = self.data_augmentation(x)
36+
return self.model(x)
37+
38+
39+
class ImagePredictor:
40+
def __init__(self, model_path):
41+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
42+
self.model = self.load_model(model_path)
43+
self.transform = self.load_transform()
44+
45+
def load_model(self, model_path):
46+
model = Net(len(classes)).to(self.device)
47+
model.load_state_dict(torch.load(model_path))
48+
return model
49+
50+
def load_transform(self):
51+
img_height = 180
52+
img_width = 180
53+
transform = transforms.Compose([
54+
transforms.Lambda(lambda img: img.convert('RGB')),
55+
transforms.Resize((img_height, img_width)), # 이미지 크기 조정
56+
transforms.ToTensor() # PIL 이미지를 텐서로 변환
57+
])
58+
return transform
59+
60+
def get_image_from_url(self, image_url):
61+
import requests
62+
from PIL import Image
63+
from io import BytesIO
64+
65+
response = requests.get(image_url)
66+
image = Image.open(BytesIO(response.content))
67+
return image
68+
69+
def predict_image(self, image):
70+
image = self.transform(image)
71+
image = image.unsqueeze(0) # 배치 차원 추가
72+
image = image.to(self.device)
73+
#
74+
output = self.model(image)
75+
_, predicted = torch.max(output, 1)
76+
77+
prob = F.softmax(output, dim=1)[0] * 100
78+
79+
prob_res = round(prob[predicted[0]].item(), 2)
80+
81+
return f'<span style="color: blue">Predicted: {classes[predicted[0]]}</span><br/>' \
82+
f'Percent: {prob_res}'
83+
84+
# pred = ImagePredictor('result.pth')
85+
# image = pred.get_image_from_url('https://www.health.com/thmb/AADrlQdpWITCjFjKnfBnqWy5A8w=/2153x0/filters:no_upscale():max_bytes(150000):strip_icc()/Dandelion-d5aed7a95a6f4b16a3e954aa78694626.jpg')
86+
# print(pred.predict_image(image))
87+
# image = pred.get_image_from_url('https://ucarecdn.com/8b756a96-8495-4d00-9201-601d6b49c700/')
88+
# print(pred.predict_image(image))
89+
# image = pred.get_image_from_url('https://www.bolster.eu/media/images/5460_dbweb.jpg?1549350221')
90+
# print(pred.predict_image(image))

0 commit comments

Comments
 (0)