Skip to content

Commit 0fe95cd

Browse files
authored
Add files via upload
still progressing
1 parent 10d35cb commit 0fe95cd

File tree

3 files changed

+146
-0
lines changed

3 files changed

+146
-0
lines changed

Image_preprocess.py

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Image Classification
2+
import torch
3+
from torchvision.transforms import v2
4+
# Visualization
5+
import matplotlib.pyplot as plt
6+
7+
def image_transform():
8+
9+
H, W = 224, 224
10+
img = torch.randint(0, 256, size=(3, H, W), dtype=torch.uint8)
11+
12+
transforms = v2.Compose([
13+
v2.RandomResizedCrop(size=(224, 224), antialias=True),
14+
v2.RandomPhotometricDistort(p=1),
15+
v2.RandomChannelPermutation() ,# 채널 무작위 변경
16+
v2.RandomHorizontalFlip(p=0.2),
17+
v2.ToDtype(torch.float32, scale=True),
18+
v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
19+
])
20+
21+
22+
return transforms
23+
24+
25+
def visualize_data(train_dataloader):
26+
# 이미지와 정답(label)을 표시합니다.
27+
train_features, train_labels = next(iter(train_dataloader))
28+
print(f"Feature batch shape: {train_features.size()}")
29+
print(f"Labels batch shape: {train_labels.size()}")
30+
img = train_features[0].squeeze()
31+
label = train_labels[0]
32+
33+
if img.dim() == 3 and img.size(0) == 3:
34+
img = img.permute(1, 2, 0)
35+
36+
plt.imshow(img, cmap="gray")
37+
plt.show()
38+
print(f"Label: {label}")

custom_image_dataset.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os
2+
import pandas as pd
3+
from torch.utils.data import Dataset
4+
from torchvision.io import read_image
5+
6+
7+
class CustomImageDataset():
8+
def __init__(self, csv_file, img_dir, transform=None, target_transform=None):
9+
self.img_file = pd.read_csv(os.path.join(img_dir, csv_file))
10+
self.img_dir = img_dir
11+
self.transform = transform
12+
self.target_transform = target_transform
13+
14+
def __len__(self):
15+
return len(self.img_file)
16+
17+
def __getitem__(self, index):
18+
# 레이블을 기반으로 서브폴더 경로 결정
19+
label = self.img_file.iloc[index, 2]
20+
subfolder = 'Sleep' if label == 'Sleep' else 'Fall'
21+
img_path = os.path.join(self.img_dir, 'train', subfolder, self.img_file.iloc[index, 0])
22+
if(label == "Sleep"):
23+
label = 0
24+
else: label = 1
25+
image = read_image(img_path)
26+
if self.transform:
27+
image = self.transform(image)
28+
if self.target_transform:
29+
label = self.target_transform(label)
30+
31+
return image, label

vit_model.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Pytorch
2+
import sys
3+
import torch
4+
5+
print(torch.__version__)
6+
from torch import nn
7+
import os
8+
from torch.utils.data import DataLoader
9+
from torchvision import datasets
10+
from custom_image_dataset import CustomImageDataset
11+
import argparse
12+
from Image_preprocess import *
13+
14+
device = (
15+
"cuda"
16+
if torch.cuda.is_available()
17+
else "mps" if torch.backends.mps.is_available() else "cpu"
18+
)
19+
print(f"Using {device} device")
20+
21+
22+
class NerualNetwork(nn.Module):
23+
def __init__(self):
24+
super().__init__()
25+
self.flatten = nn.Flatten()
26+
self.linear_relu_stack = nn.Sequential(
27+
nn.Linear(224 * 224, 512),
28+
nn.RuLU(),
29+
nn.Linear(512, 512),
30+
nn.RuLU(),
31+
nn.Linear(512, 2),
32+
)
33+
34+
def forward(self, x):
35+
x = self.flatten(x)
36+
logits = self.linear_relu_stack(x)
37+
return logits
38+
39+
40+
def main(argv):
41+
42+
parser = argparse.ArgumentParser(
43+
description="Fall detection with Vision Transformer(ViT) Model",
44+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
45+
)
46+
parser.add_argument(
47+
"-l",
48+
"--load",
49+
nargs="+",
50+
type=str,
51+
help="python vit_model.py --load train_captions.csv",
52+
)
53+
args = parser.parse_args()
54+
55+
## data load & argumentation
56+
if args.load is not None:
57+
file_path = "C:/Users/Jaeho/OneDrive/바탕 화면/fall detection/dataset/"
58+
data_csv = args.load[0]
59+
print("Loading file...")
60+
transform = image_transform()
61+
train_dataset = CustomImageDataset(
62+
csv_file=data_csv, img_dir=file_path, transform=transform
63+
)
64+
# test_dataset = CustomImageDataset(data_csv, file_path, transform)
65+
train_dataloader = torch.utils.data.DataLoader(
66+
train_dataset, batch_size=1024, shuffle=True, num_workers=4
67+
)
68+
# test_dataloader = torch.utils.data.DataLoader(test_dataset,
69+
# batch_size=1024,
70+
# shuffle=True,
71+
# num_workers=4)
72+
73+
visualize_data(train_dataloader)
74+
75+
76+
if __name__ == "__main__":
77+
main(sys.argv[1:])

0 commit comments

Comments
 (0)