Skip to content

Commit a1fc68b

Browse files
committed
add gru
1 parent 3734f8a commit a1fc68b

File tree

3 files changed

+112
-64
lines changed

3 files changed

+112
-64
lines changed

CSL_Skeleton_LSTM.py renamed to CSL_Skeleton_RNN.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch.optim as optim
99
from torch.utils.data import DataLoader, random_split
1010
from torch.utils.tensorboard import SummaryWriter
11-
from models.LSTM import LSTM
11+
from models.RNN import LSTM, GRU
1212
from dataset import CSL_Skeleton
1313
from train import train_epoch
1414
from validation import val_epoch
@@ -39,13 +39,13 @@
3939
num_classes = 100
4040
sample_duration = 16
4141
selected_joints = ['HANDLEFT', 'HANDRIGHT', 'ELBOWLEFT', 'ELBOWRIGHT']
42-
lstm_input_size = len(selected_joints)*2
43-
lstm_hidden_size = 512
44-
lstm_num_layers = 1
42+
input_size = len(selected_joints)*2
43+
hidden_size = 512
44+
num_layers = 1
4545
hidden1 = 512
4646
drop_p = 0.0
4747

48-
# Train with Skeleton+LSTM
48+
# Train with Skeleton+RNN
4949
if __name__ == '__main__':
5050
# Load data
5151
transform = None # TODO
@@ -57,7 +57,9 @@
5757
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
5858
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
5959
# Create model
60-
model = LSTM(lstm_input_size=lstm_input_size, lstm_hidden_size=lstm_hidden_size, lstm_num_layers=lstm_num_layers,
60+
# model = LSTM(lstm_input_size=input_size, lstm_hidden_size=hidden_size, lstm_num_layers=num_layers,
61+
# num_classes=num_classes, hidden1=hidden1, drop_p=drop_p).to(device)
62+
model = GRU(gru_input_size=input_size, gru_hidden_size=hidden_size, gru_num_layers=num_layers,
6163
num_classes=num_classes, hidden1=hidden1, drop_p=drop_p).to(device)
6264
# Run the model in parallel across available GPUs
6365
if torch.cuda.device_count() > 1:

models/LSTM.py

Lines changed: 0 additions & 58 deletions
This file was deleted.

models/RNN.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
"""
6+
Implementation of LSTM
7+
Reference: SIGN LANGUAGE RECOGNITION WITH LONG SHORT-TERM MEMORY
8+
"""
9+
class LSTM(nn.Module):
10+
def __init__(self, lstm_input_size=512, lstm_hidden_size=512, lstm_num_layers=3,
11+
num_classes=100, hidden1=256, drop_p=0.0):
12+
super(LSTM, self).__init__()
13+
# network params
14+
self.lstm_input_size = lstm_input_size
15+
self.lstm_hidden_size = lstm_hidden_size
16+
self.lstm_num_layers = lstm_num_layers
17+
self.num_classes = num_classes
18+
self.hidden1 = hidden1
19+
self.drop_p = drop_p
20+
21+
# network architecture
22+
self.lstm = nn.LSTM(
23+
input_size=self.lstm_input_size,
24+
hidden_size=self.lstm_hidden_size,
25+
num_layers=self.lstm_num_layers,
26+
batch_first=True,
27+
)
28+
self.drop = nn.Dropout2d(p=self.drop_p)
29+
self.fc1 = nn.Linear(self.lstm_hidden_size, self.hidden1)
30+
self.fc2 = nn.Linear(self.hidden1, self.num_classes)
31+
32+
def forward(self, x):
33+
# LSTM
34+
# use faster code paths
35+
self.lstm.flatten_parameters()
36+
# print(x.shape)
37+
# batch first: (batch, seq, feature)
38+
out, (h_n, c_n) = self.lstm(x, None)
39+
# MLP
40+
# out: (batch, seq, feature), choose the last time step
41+
out = F.relu(self.fc1(out[:, -1, :]))
42+
out = F.dropout(out, p=self.drop_p, training=self.training)
43+
out = self.fc2(out)
44+
45+
return out
46+
47+
48+
"""
49+
Implementation of GRU
50+
"""
51+
class GRU(nn.Module):
52+
def __init__(self, gru_input_size=512, gru_hidden_size=512, gru_num_layers=3,
53+
num_classes=100, hidden1=256, drop_p=0.0):
54+
super(GRU, self).__init__()
55+
# network params
56+
self.gru_input_size = gru_input_size
57+
self.gru_hidden_size = gru_hidden_size
58+
self.gru_num_layers = gru_num_layers
59+
self.num_classes = num_classes
60+
self.hidden1 = hidden1
61+
self.drop_p = drop_p
62+
63+
# network architecture
64+
self.gru = nn.GRU(
65+
input_size=self.gru_input_size,
66+
hidden_size=self.gru_hidden_size,
67+
num_layers=self.gru_num_layers,
68+
batch_first=True,
69+
)
70+
self.drop = nn.Dropout2d(p=self.drop_p)
71+
self.fc1 = nn.Linear(self.gru_hidden_size, self.hidden1)
72+
self.fc2 = nn.Linear(self.hidden1, self.num_classes)
73+
74+
def forward(self, x):
75+
# GRU
76+
# use faster code paths
77+
self.gru.flatten_parameters()
78+
# print(x.shape)
79+
# batch first: (batch, seq, feature)
80+
out, hidden = self.gru(x, None)
81+
# MLP
82+
# out: (batch, seq, feature), choose the last time step
83+
out = F.relu(self.fc1(out[:, -1, :]))
84+
out = F.dropout(out, p=self.drop_p, training=self.training)
85+
out = self.fc2(out)
86+
87+
return out
88+
89+
# Smoke test: run both RNN variants on a single skeleton sample.
if __name__ == '__main__':
    import sys
    sys.path.append("..")
    from dataset import CSL_Skeleton

    joints = ['HANDLEFT', 'HANDRIGHT', 'ELBOWLEFT', 'ELBOWRIGHT']
    skeletons = CSL_Skeleton(data_path="/home/haodong/Data/CSL_Isolated/xf500_body_depth_txt",
                             label_path="/home/haodong/Data/CSL_Isolated/dictionary.txt", selected_joints=joints)
    # each selected joint contributes an (x, y) pair of coordinates
    feature_dim = len(joints) * 2

    # test LSTM
    print(LSTM(lstm_input_size=feature_dim)(skeletons[0]['data'].unsqueeze(0)))

    # test GRU
    print(GRU(gru_input_size=feature_dim)(skeletons[0]['data'].unsqueeze(0)))

0 commit comments

Comments
 (0)