-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
91 lines (76 loc) · 2.93 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from typing import Dict, Any
import torch
from torch.utils.data import Dataset
from rxnfp.transformer_fingerprints import (
RXNBERTFingerprintGenerator,
get_default_model_and_tokenizer
)
from utils.reactions import reaction_fps
# Module-level setup, executed once when this module is imported: fetch the
# default model/tokenizer pair from rxnfp and wrap them in one fingerprint
# generator. BERTFpsReactionSmilesDataset.__getitem__ reads the module-level
# `rxnfp_generator` defined here.
model, tokenizer = get_default_model_and_tokenizer()
rxnfp_generator = RXNBERTFingerprintGenerator(model, tokenizer)
class ReactionSmilesDataset(Dataset):
    """Reaction SMILES dataset whose fingerprints are computed on access.

    The input file holds one reaction per line, either ``<smiles>;<label>``
    or a bare ``<smiles>``. A line that does not split into exactly two
    ``;``-separated fields is treated as a bare SMILES with label ``0``.
    Blank lines are skipped.
    """

    def __init__(self,
                 filepath: str,
                 dev: Any,
                 fp_method: str,
                 params: Dict[str, Any]) -> None:
        """Read every reaction and label from ``filepath`` into memory.

        Args:
            filepath: Path to the ``;``-separated reaction file.
            dev: Device handed to ``Tensor.to`` in ``__getitem__``.
            fp_method: Forwarded to ``reaction_fps`` as ``fp_method``.
            params: Extra keyword arguments forwarded to ``reaction_fps``.
                Must contain ``"include_agents"``; when that value is falsy
                the agents section of each reaction is merged into the
                reactants.

        Raises:
            KeyError: If ``params`` has no ``"include_agents"`` entry.
            ValueError: If a label is not an integer, or if agents must be
                merged and a reaction does not have exactly three
                ``>``-separated sections.
        """
        self.filepath = filepath  # path to .csv file
        self.dev = dev
        self.smiles = []
        self.labels = []
        self.fp_method = fp_method
        self.params = params
        # Loop-invariant, so look it up once instead of once per line.
        include_agents = params["include_agents"]
        with open(self.filepath) as _file:
            for line_no, line in enumerate(_file, start=1):
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line carries no reaction.
                    continue
                try:
                    smi, label = line.split(";")
                except ValueError:
                    smi, label = line, 0
                # Strip in both branches so stored SMILES never carry the
                # line terminator or padding around the separator.
                smi = smi.strip()
                if not include_agents:
                    smi = self._merge_agents(smi, line_no)
                self.smiles.append(smi)
                self.labels.append(int(label))

    @staticmethod
    def _merge_agents(smi: str, line_no: int) -> str:
        """Rewrite ``reactants>agents>products`` as ``precursors>>products``."""
        try:
            reactants, agents, products = smi.split(">")
        except ValueError:
            raise ValueError(
                f"line {line_no}: expected 'reactants>agents>products', "
                f"got {smi!r}"
            ) from None
        # Join only the non-empty sections: an empty agents part must not
        # leave a dangling "." in front of ">>".
        precursors = ".".join(part for part in (reactants, agents) if part)
        return f"{precursors}>>{products}"

    def __len__(self):
        """Return the number of reactions loaded from the file."""
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return ``(fingerprint, label)`` for the reaction at ``idx``.

        The fingerprint is recomputed by ``reaction_fps`` on every access.
        Its result goes through ``torch.from_numpy``, then is cast to float
        and moved to ``self.dev``.
        """
        descriptors = reaction_fps(self.smiles[idx],
                                   fp_method=self.fp_method,
                                   **self.params)
        fingerprint = torch.from_numpy(descriptors).float().to(self.dev)
        return fingerprint, self.labels[idx]
class BERTFpsReactionSmilesDataset(Dataset):
    """Reaction SMILES dataset served as cached rxnfp fingerprints.

    The input file holds one reaction per line, either ``<smiles>;<label>``
    or a bare ``<smiles>``. A line that does not split into exactly two
    ``;``-separated fields is treated as a bare SMILES with label ``0``.
    Blank lines are skipped. Fingerprints are computed on first access and
    memoised per SMILES string in ``self.fps_dict``.
    """

    def __init__(self,
                 filepath: str,
                 no_agents: bool,
                 dev: Any) -> None:
        """Read every reaction and label from ``filepath`` into memory.

        Args:
            filepath: Path to the ``;``-separated reaction file.
            no_agents: When true, the agents section of each reaction is
                merged into the reactants (``precursors>>products``).
            dev: Device handed to ``Tensor.to`` in ``__getitem__``.

        Raises:
            ValueError: If a label is not an integer, or if ``no_agents`` is
                set and a reaction does not have exactly three
                ``>``-separated sections.
        """
        self.filepath = filepath  # path to .csv file
        self.dev = dev
        self.smiles = []
        self.labels = []
        self.fps_dict = {}  # SMILES string -> fingerprint computed so far
        with open(self.filepath) as _file:
            for line_no, line in enumerate(_file, start=1):
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line carries no reaction.
                    continue
                try:
                    smi, label = line.split(";")
                except ValueError:
                    smi, label = line, 0
                smi = smi.strip()
                if no_agents:
                    smi = self._merge_agents(smi, line_no)
                self.smiles.append(smi)
                self.labels.append(int(label))

    @staticmethod
    def _merge_agents(smi: str, line_no: int) -> str:
        """Rewrite ``reactants>agents>products`` as ``precursors>>products``."""
        try:
            reactants, agents, products = smi.split(">")
        except ValueError:
            raise ValueError(
                f"line {line_no}: expected 'reactants>agents>products', "
                f"got {smi!r}"
            ) from None
        # Join only the non-empty sections: an empty agents part must not
        # leave a dangling "." in front of ">>".
        precursors = ".".join(part for part in (reactants, agents) if part)
        return f"{precursors}>>{products}"

    def __len__(self):
        """Return the number of reactions loaded from the file."""
        return len(self.smiles)

    def __getitem__(self, idx):
        """Return ``(fingerprint, label)`` for the reaction at ``idx``.

        The fingerprint comes from ``self.fps_dict`` when this SMILES was
        requested before. Otherwise it is produced by the module-level
        ``rxnfp_generator`` and stored for reuse. It is returned as a float
        tensor on ``self.dev``.
        """
        smi = self.smiles[idx]
        if smi not in self.fps_dict:
            self.fps_dict[smi] = rxnfp_generator.convert(smi)
        bert_fingerprint = self.fps_dict[smi]
        return torch.tensor(bert_fingerprint).float().to(self.dev), self.labels[idx]