Skip to content

Commit 5d554cf

Browse files
committed
commit
1 parent 7ad692b commit 5d554cf

File tree

2 files changed

+73
-22
lines changed

2 files changed

+73
-22
lines changed

LSTMNet.cpp

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@ LSTMNet::LSTMNet(int memCells, int inputVecSize) {
1414
noOfIns = 0;
1515
}
1616

17-
LSTMNet::LSTMNet(const LSTMNet& orig) {
18-
}
17+
LSTMNet::LSTMNet(const LSTMNet& orig) { }
1918

20-
LSTMNet::~LSTMNet() {
21-
}
19+
LSTMNet::~LSTMNet() { }
2220

2321
int LSTMNet::forward(std::vector<double> * input, int timeSteps) {
2422

@@ -150,7 +148,6 @@ int LSTMNet::train(std::vector<double> * input, std::vector<double> output, int
150148

151149
std::vector<double> *inVec;
152150

153-
154151
for (int i = 0; i < trainingIterations; i++){
155152

156153
inVec = input + (timeSteps*i);
@@ -159,8 +156,6 @@ int LSTMNet::train(std::vector<double> * input, std::vector<double> output, int
159156
std::vector<double>::const_iterator last = output.begin() + (timeSteps*i + timeSteps);
160157
std::vector<double> outVec(first, last);
161158

162-
// printVector(outVec);
163-
164159
forward(inVec,timeSteps);
165160
backward(outVec,timeSteps);
166161

@@ -185,23 +180,19 @@ int LSTMNet::train(std::vector<double> * input, std::vector<double> output, int
185180
oDeltaWeightVecArr[p].at(wPos) += *it * delta_o_t;
186181
wPos++;
187182
}
188-
189183
delta_bias_a_t += delta_a_t;
190184
delta_bias_i_t += delta_i_t;
191185
delta_bias_f_t += delta_f_t;
192186
delta_bias_o_t += delta_o_t;
193187
}
194-
195188
aBiasArr[p] -= (delta_bias_a_t * learningRate);
196189
iBiasArr[p] -= (delta_bias_i_t * learningRate);
197190
fBiasArr[p] -= (delta_bias_f_t * learningRate);
198191
oBiasArr[p] -= (delta_bias_o_t * learningRate);
199192

200193
}
201-
202-
194+
203195
index += timeSteps;
204-
205196
for(int j = 0; j < memCells; j++) {
206197

207198
std::transform(
@@ -248,10 +239,8 @@ int LSTMNet::train(std::vector<double> * input, std::vector<double> output, int
248239
oWeightVecArr[j].begin(), oWeightVecArr[j].end(),
249240
oDeltaWeightVecArr[j].begin(), oWeightVecArr[j].begin(),
250241
std::minus<double>()
251-
);
252-
253-
}
254-
242+
);
243+
}
255244
clearVectors();
256245
}
257246
return 0;
@@ -405,7 +394,6 @@ int LSTMNet::clearVectors() {
405394
memCellOutArr[i].clear();
406395
memCellOutArr[i].push_back(out);
407396
}
408-
409397
return 0;
410398
}
411399

@@ -424,7 +412,6 @@ double LSTMNet::predict(std::vector<double> * input) {
424412
result += *(memCellOutArr[i].end()-1);
425413
}
426414

427-
428415
// output2.push_back(result);
429416
// if (noOfIns == timeSteps) {
430417
// printVector(memCellOutArr[0]);
@@ -434,8 +421,7 @@ double LSTMNet::predict(std::vector<double> * input) {
434421
// noOfIns = 0;
435422
// train(input2, output, timeSteps, timeSteps, 0.095);
436423
// }
437-
438-
424+
439425
return result;
440426
}
441427

LSTMNet.h

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,29 +25,93 @@ class LSTMNet {
2525
LSTMNet(const LSTMNet& orig);
2626
virtual ~LSTMNet();
2727

28+
/**
29+
*
30+
* @param input: training data set
31+
* @param output: target values
32+
* @param trainDataSize: training data size
33+
* @param timeSteps: unfolding time steps
34+
* @param learningRate
35+
* @param iterations: training iterations
36+
* @return 0
37+
*/
2838
int train(std::vector<double> * input, std::vector<double> output, int trainDataSize, int timeSteps, float learningRate, int iterations);
39+
/**
40+
* predict a future point using a input vector of size n
41+
*
42+
* input: {t-n,...,t-2,t-1,t}
43+
* result: {t+1}
44+
*
45+
* @param input: input vector for he prediction
46+
* @return result: predicted value
47+
*/
2948
double predict(std::vector<double> * input);
3049

3150

3251
private:
52+
/**
53+
* Forward Propagation of the network
54+
*
55+
* @param input: input vector
56+
* @param timeSteps: unfolded time steps in the input vector
57+
* @return 0
58+
*/
3359
int forward(std::vector<double> * input, int timeSteps);
60+
/**
61+
* Backward Propagation of the network
62+
*
63+
* @param output: output from the forward propagation
64+
* @param timeSteps: unfolded time steps
65+
* @return 0
66+
*/
3467
int backward(std::vector<double> output, int timeSteps);
35-
double sigmoid(double x);
36-
std::vector<double> sigmoid(std::vector<double> x);
68+
/**
69+
* Initialize the weights and bias values for the gates
70+
* Random initialization
71+
*
72+
* @return 0
73+
*/
3774
int initWeights();
75+
/**
76+
* Clear Vectors
77+
*
78+
* @return 0
79+
*/
3880
int clearVectors();
81+
/**
82+
* print the given vector
83+
*
84+
* @param vec: vector
85+
* @return 0
86+
*/
3987
int printVector(std::vector<double> vec);
88+
/**
89+
* Sigmoid function
90+
*
91+
* @param x
92+
* @return value
93+
*/
94+
double sigmoid(double x);
95+
/**
96+
* Sigmoid function
97+
*
98+
* @param x: input vector
99+
* @return vector
100+
*/
101+
std::vector<double> sigmoid(std::vector<double> x);
40102

41103
private:
42104
int memCells;
43105
int inputVectDim;
44106
int timeSteps;
45107

108+
// weight vector arrays
46109
std::vector<double> * aWeightVecArr;
47110
std::vector<double> * iWeightVecArr;
48111
std::vector<double> * fWeightVecArr;
49112
std::vector<double> * oWeightVecArr;
50113

114+
// bias value arrays
51115
double * aBiasArr;
52116
double * iBiasArr;
53117
double * fBiasArr;
@@ -56,6 +120,7 @@ class LSTMNet {
56120
std::vector<double> * memCellOutArr;
57121
std::vector<double> * memCellStateArr;
58122

123+
// gate output value arrays
59124
std::vector<double> * aGateVecArr;
60125
std::vector<double> * iGateVecArr;
61126
std::vector<double> * fGateVecArr;

0 commit comments

Comments
 (0)