Commit 7309be4

Initial commit
0 parents, commit 7309be4

File tree

6 files changed: +947 additions, -0 deletions

.gitignore

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
# Created by .ignore support plugin (hsz.mobi)
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
*.iml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

README.md

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
Model for the RSNA Intracranial Hemorrhage Detection competition (https://www.kaggle.com/c/rsna-intracranial-hemorrhage-detection).

ResNeXt + PCA + BiLSTM, scoring 0.04989 on the private test set.

Sequence metadata required: https://www.kaggle.com/mihailburduja/rsna-intracranial-sequence-metadata

Slices are resized to 256x256; each slice's embedding vector is reduced to 120 dimensions.

`models.py` contains the CNN and LSTM models.

`datasets.py` contains the torch Datasets for the CNN and the LSTM model.

`train_cnn.py` trains the CNN and outputs PCA embeddings and predictions.

`train_lstm.py` trains the LSTM and outputs the submission file.

datasets.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
# from apex import amp
import numpy as np
import pandas as pd
import pydicom
import torch
from torch.utils.data import Dataset


def correct_dcm(dcm):
    # Fix DICOMs stored with a wrong RescaleIntercept: shift the pixel values by
    # 1000, wrap them at the 12-bit range and force the intercept back to -1000.
    x = dcm.pixel_array + 1000
    px_mode = 4096
    x[x >= px_mode] = x[x >= px_mode] - px_mode
    dcm.PixelData = x.tobytes()
    dcm.RescaleIntercept = -1000


def window_image(dcm, window_center, window_width):
    # Apply a Hounsfield window (center/width) after rescaling the raw pixel data.
    if (dcm.BitsStored == 12) and (dcm.PixelRepresentation == 0) and (int(dcm.RescaleIntercept) > -100):
        correct_dcm(dcm)

    img = dcm.pixel_array * dcm.RescaleSlope + dcm.RescaleIntercept
    img_min = window_center - window_width // 2
    img_max = window_center + window_width // 2
    img = np.clip(img, img_min, img_max)

    return img


def bsb_window(dcm):
    # Brain / subdural / soft-tissue windows, normalised to [0, 1] and stacked
    # as a 3-channel image of shape (H, W, 3).
    brain_img = window_image(dcm, 40, 80)
    subdural_img = window_image(dcm, 80, 200)
    soft_img = window_image(dcm, 40, 380)

    brain_img = (brain_img - 0) / 80
    subdural_img = (subdural_img - (-20)) / 200
    soft_img = (soft_img - (-150)) / 380
    bsb_img = np.array([brain_img, subdural_img, soft_img]).transpose(1, 2, 0)

    return bsb_img


class IntracranialDataset(Dataset):
    # Per-slice dataset for the CNN: reads a DICOM, applies the three-window
    # preprocessing and (optionally) returns the six hemorrhage labels.

    def __init__(self, csv_file, path, labels, transform=None):
        self.path = path
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.labels = labels  # bool: whether label columns are present in the CSV

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            dicom = pydicom.dcmread(self.path + self.data.loc[idx, 'Image'] + '.dcm')
            img = bsb_window(dicom)
        except Exception:
            # Fall back to a blank image if the DICOM is missing or unreadable.
            img = np.zeros((512, 512, 3))

        if self.transform:
            augmented = self.transform(image=img)
            img = augmented['image']

        if self.labels:
            labels = torch.tensor(
                self.data.loc[
                    idx, ['epidural', 'intraparenchymal', 'intraventricular', 'subarachnoid', 'subdural', 'any']])
            return {'image': img, 'labels': labels}
        else:
            return {'image': img}


class PredictionsDataset(Dataset):
    # Per-series dataset for the LSTM: groups the CNN predictions and PCA
    # embeddings by SeriesInstanceUID, ordered by slice position.

    def __init__(self, data, col_names, features=120, train=True, series=None):
        self.data = data
        self.train = train
        self.col_names = col_names
        self.embed_cols = [str(i) for i in range(features)]

        if series is None:
            self.series = self.data['SeriesInstanceUID'].unique()
        else:
            self.series = series

    def __len__(self):
        return len(self.series)

    def __getitem__(self, idx):
        series_id = self.series[idx]
        images = self.data[self.data['SeriesInstanceUID'] == series_id].sort_values(
            by=['ImagePositionSpan', 'ImageId'])

        # In training mode the merged dataframe holds the CNN predictions as *_x
        # columns and the ground-truth labels as *_y columns.
        cols = self.col_names
        if self.train:
            cols = [x + '_x' for x in self.col_names]

        image_preds = images[cols].to_numpy().astype(float)

        if self.train:
            image_truths = images[[x + '_y' for x in self.col_names]].to_numpy().astype(float)
            image_embeds = images[self.embed_cols].to_numpy().astype(float)

            return {
                'preds': torch.tensor(image_preds).to(torch.float),
                'labels': torch.tensor(image_truths).to(torch.float),
                'embeds': torch.tensor(image_embeds).to(torch.float)
            }
        else:
            image_embeds = images[self.embed_cols].to_numpy().astype(float)
            return {
                'preds': torch.tensor(image_preds).to(torch.float),
                'embeds': torch.tensor(image_embeds).to(torch.float)
            }
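
A hedged usage sketch for `IntracranialDataset` above; the CSV name, DICOM directory and albumentations transform are placeholders, not paths or settings from this repository:

import albumentations
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

# Hypothetical setup; the transform resizes to 256x256 as stated in the README.
transform = albumentations.Compose([
    albumentations.Resize(256, 256),
    ToTensorV2()
])

train_ds = IntracranialDataset(
    csv_file='stage_2_train.csv',      # placeholder CSV with Image + label columns
    path='stage_2_train_images/',      # placeholder DICOM directory
    labels=True,
    transform=transform
)
loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4)
batch = next(iter(loader))
print(batch['image'].shape, batch['labels'].shape)  # e.g. (32, 3, 256, 256) and (32, 6)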

models.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# from apex import amp
import torch


class ResNeXtModel(torch.nn.Module):
    # ResNeXt-101 32x8d (WSL weights) with the original classifier removed; the
    # 2048-d pooled features feed a 6-way head and are also returned for the
    # later PCA / LSTM stage.
    def __init__(self):
        super(ResNeXtModel, self).__init__()
        resnext = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
        self.base = torch.nn.Sequential(*list(resnext.children())[:-1])
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(2048, 6)
        )

    def forward(self, input):
        features = self.base(input).reshape(-1, 2048)
        out = self.fc(features)
        return out, features


class EmbeddingSmootherModel(torch.nn.Module):
    # Bidirectional LSTM over a series of slices: each step sees the 120-d PCA
    # embedding concatenated with the 6 CNN predictions, and the classifier
    # combines the LSTM output with the raw predictions.

    def __init__(self, features=120, hidden_size=256):
        super(EmbeddingSmootherModel, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(features + 6, self.hidden_size, num_layers=3, dropout=0.3, batch_first=True,
                                  bidirectional=True)
        # Defined but not used in forward().
        self.scan_rnn = torch.nn.GRU(6, 64, num_layers=1, batch_first=True, bidirectional=True)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size * 2 + 6, 6)
        )
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, seq, preds):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Zero-initialised (h, c): (num_layers * 2 directions, batch, hidden_size),
        # hard-coded for batch size 1.
        hidden = (
            torch.zeros(6, 1, self.hidden_size).to(device),
            torch.zeros(6, 1, self.hidden_size).to(device)
        )

        out, hidden = self.lstm(seq, hidden)
        combined_out = torch.cat((out, preds), 2)
        out = self.classifier(self.dropout(combined_out))

        return out
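
A quick shape-check sketch for `EmbeddingSmootherModel` above, using dummy tensors for a single series; the 40-slice count is an arbitrary assumption, and the batch size is 1 to match the hard-coded hidden state:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EmbeddingSmootherModel(features=120, hidden_size=256).to(device)

embeds = torch.randn(1, 40, 120, device=device)  # (batch, slices, PCA dims)
preds = torch.randn(1, 40, 6, device=device)     # (batch, slices, CNN predictions)
seq = torch.cat((embeds, preds), dim=2)          # LSTM input: 120 + 6 = 126 features

out = model(seq, preds)
print(out.shape)                                 # torch.Size([1, 40, 6])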
