Skip to content

Commit 531215a

Browse files
committed
completing the dataloader
1 parent b33fc2d commit 531215a

File tree

3 files changed

+51
-1
lines changed

3 files changed

+51
-1
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
.idea
2-
data
2+
data
3+
__pycache__/

core/dataloader.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import warnings
2+
3+
import pandas as pd
4+
import numpy as np
5+
from torch.utils.data import Dataset, DataLoader
6+
from sklearn.model_selection import train_test_split
7+
8+
from settings import label_column_name
9+
10+
11+
class StockPriceDataset(Dataset):
    """Dataset built from a stock-price CSV file.

    The constructor loads the CSV, discards unusable rows and, for the
    training phase, computes normalisation statistics on the leading
    (unshuffled) part of the data.

    Attributes:
        mean: ``pandas.Series`` with the mean of each column listed in
            ``STAT_COLUMNS`` over the training split, or ``None`` when no
            statistics were computed.
        std: ``pandas.Series`` with the standard deviation of the same
            columns, or ``None`` when no statistics were computed.
    """

    # Columns for which normalisation statistics are computed.
    STAT_COLUMNS = ("Open", "High", "Low", "Close", "Volume")

    def __init__(self, filepath: str, time_step, test: bool = False, train_size: float = 0.7, test_size: float = 0.3,
                 phase: str = "train"):
        """Load ``filepath`` and prepare the data for the requested phase.

        Args:
            filepath: Path of the CSV file. It must contain the columns
                ``Adj Close`` and ``Volume``; the training phase also needs
                ``Open``, ``High``, ``Low`` and ``Close``.
            time_step: Accepted for interface compatibility; not used yet.
            test: When true, the training phase skips the statistics.
            train_size: Fraction of rows assigned to the training split.
            test_size: Fraction of rows assigned to the validation split.
            phase: Either ``"train"`` or ``"test"``.

        Raises:
            ValueError: If ``phase`` is neither ``"train"`` nor ``"test"``.
        """
        # Fail fast, before any I/O: an unknown phase used to print a message
        # and still hand back a half-initialised object.
        if phase not in ("train", "test"):
            raise ValueError("[Failed] You had selected wrong phase")

        super().__init__()
        self.mean = None
        self.std = None

        data_df = pd.read_csv(filepath)
        # Drop every row that has at least one missing value.
        data_df = data_df.dropna(axis=0, how="any")
        data_df = data_df.drop(labels="Adj Close", axis=1)
        # Rows with zero volume are removed; the value is compared both as a
        # string and as a number because the parsed dtype is not fixed here.
        data_df = data_df.loc[(data_df["Volume"] != "0") & (data_df["Volume"] != 0)]

        if phase == "train":
            # shuffle=False keeps the original row order, so the training
            # split is the leading part of the file.
            train_df, _valid_df = train_test_split(data_df, test_size=test_size, train_size=train_size, shuffle=False)
            if not test:
                # Restrict to the price/volume columns: since pandas 2.0,
                # DataFrame.std()/mean() raise TypeError when the frame holds
                # a non-numeric column (for example a date column, if the CSV
                # has one).
                stat_frame = train_df[list(self.STAT_COLUMNS)]
                self.std = stat_frame.std()
                self.mean = stat_frame.mean()
        # phase == "test": nothing beyond the cleaning above is done yet.
44+
45+
46+
if __name__ == '__main__':
    # Manual smoke test: build the dataset from a sample CSV on the
    # developer's machine.
    csv_path = "/home/lezarus/Documents/Project/cnn_lstm/data/dataset/000001.SS.csv"
    dataset = StockPriceDataset(
        filepath=csv_path,
        time_step=10,
    )

settings.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
has_cuda = torch.cuda.is_available()
99

1010
selected_column = []
11+
# Imported by core/dataloader.py; presumably the CSV column used as the
# prediction target — TODO confirm once the dataloader actually reads it.
label_column_name = "Close"
1112

1213
output_dir = project_dir.joinpath("result")
1314
output_dir.mkdir(parents=True, exist_ok=True)

0 commit comments

Comments
 (0)