
Commit 561d409

Merge pull request #25 from rodgzilla/multiple_choice_head
Simplifying the use of the model to perform different tasks
2 parents: d914228 + cbccdb0

3 files changed: +154 −49 lines

loss.py (new file: +69 lines)
import torch

class MultipleChoiceLossCompute:
    "A Loss compute and train function for multiple choice tasks."

    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
        self.lm_criterion = lm_criterion
        self.clf_criterion = clf_criterion
        self.lm_coef = lm_coef
        self.opt = opt

    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
        # Language modeling loss
        if lm_logits is not None:
            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
            M = M.view(-1, M.size(2))
            lm_losses = self.lm_criterion(lm_logits, x_shifted)
            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
            lm_losses = lm_losses * M[:, 1:]
            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
        # Classification loss
        clf_losses = self.clf_criterion(clf_logits, Y)
        if only_return_losses:
            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

        if self.lm_coef > 0 and lm_logits is not None:
            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
        else:
            train_loss = clf_losses.sum()
        train_loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.zero_grad()
        return train_loss.item()

class ClassificationLossCompute:
    "A Loss compute and train function for classification tasks."

    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
        self.lm_criterion = lm_criterion
        self.clf_criterion = clf_criterion
        self.lm_coef = lm_coef
        self.opt = opt

    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
        # Language modeling loss
        if lm_logits is not None:
            x_shifted = X[:, 1:, 0].contiguous().view(-1)
            M = M.view(-1, M.size(-1))
            lm_losses = self.lm_criterion(lm_logits, x_shifted)
            lm_losses = lm_losses.view(X.size(0), X.size(-2) - 1)
            lm_losses = lm_losses * M[:, 1:]
            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
        # Classification loss
        clf_losses = self.clf_criterion(clf_logits, Y)
        if only_return_losses:
            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses

        if self.lm_coef > 0 and lm_logits is not None:
            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
        else:
            train_loss = clf_losses.sum()
        train_loss.backward()
        if self.opt is not None:
            self.opt.step()
            self.opt.zero_grad()
        return train_loss.item()

# TODO Implement a LossCompute class for similarity tasks.
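
For context, a minimal sketch of how MultipleChoiceLossCompute might be driven. The tensor shapes, the dummy data, and the lm_coef value below are illustrative assumptions, not part of this commit (in the real pipeline X stacks token and position ids on its last axis and M masks padding); the only requirement taken from the diff is that the criteria be unreduced so the per-token losses survive:

# Hypothetical usage sketch -- shapes and values are assumptions.
import torch
import torch.nn as nn
from loss import MultipleChoiceLossCompute

criterion = nn.CrossEntropyLoss(reduce=False)  # unreduced, as in train.py
compute_loss_fct = MultipleChoiceLossCompute(criterion, criterion,
                                             lm_coef=0.5, opt=None)

n_batch, n_choice, seq_len, n_vocab = 4, 2, 32, 100
X = torch.randint(0, n_vocab, (n_batch, n_choice, seq_len, 2))  # token + position ids
Y = torch.randint(0, n_choice, (n_batch,))                      # correct choice index
M = torch.ones(n_batch, n_choice, seq_len)                      # loss mask

clf_logits = torch.randn(n_batch, n_choice, requires_grad=True)
lm_logits = torch.randn(n_batch * n_choice * (seq_len - 1), n_vocab,
                        requires_grad=True)

# Backpropagates the combined loss and returns its value as a float.
loss = compute_loss_fct(X, Y, M, clf_logits, lm_logits)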

model_pytorch.py (+79 −9 lines)
@@ -2,6 +2,7 @@
 import json
 import math
 import re
+import collections

 import numpy as np
 import torch

@@ -187,22 +188,23 @@ def forward(self, h):
         return lm_logits


-class ClfHead(nn.Module):
+class MultipleChoiceHead(nn.Module):
     """ Classifier Head for the transformer """

     def __init__(self, clf_token, cfg):
-        super(ClfHead, self).__init__()
+        super(MultipleChoiceHead, self).__init__()
         self.n_embd = cfg.n_embd
         self.clf_token = clf_token
         self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
         self.linear = nn.Linear(cfg.n_embd, 1)
-        nn.init.normal_(self.linear.weight, std=0.02)
+
+        nn.init.normal_(self.linear.weight, std = 0.02)
         nn.init.normal_(self.linear.bias, 0)

     def forward(self, h, x):
         # Classification logits
         clf_h = h.view(-1, self.n_embd)
-        flat = x[:, :, :, 0].contiguous().view(-1)
+        flat = x[..., 0].contiguous().view(-1)
         clf_h = clf_h[flat == self.clf_token, :]
         clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
         # This double transposition is there to replicate the behavior

@@ -212,22 +214,90 @@ def forward(self, h, x):
         clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
         clf_h = clf_h.contiguous().view(-1, self.n_embd)
         clf_logits = self.linear(clf_h)
+
         return clf_logits.view(-1, x.size(1))


+class ClfHead(nn.Module):
+    """Classification Head for the transformer
+
+    TODO: test this class."""
+    def __init__(self, clf_token, cfg, n_class):
+        super(ClfHead, self).__init__()
+        self.n_embd = cfg.n_embd
+        self.clf_token = clf_token
+        self.dropout = nn.Dropout(cfg.clf_pdrop)
+        self.linear = nn.Linear(cfg.n_embd, n_class)
+
+        nn.init.normal_(self.linear.weight, std = 0.02)
+        nn.init.normal_(self.linear.bias, 0)
+
+    def forward(self, h, x):
+        clf_h = h.view(-1, self.n_embd)
+        flat = x[..., 0].contiguous().view(-1)
+        clf_h = clf_h[flat == self.clf_token, :]
+        clf_h = self.dropout(clf_h)
+        clf_logits = self.linear(clf_h)
+
+        return clf_logits
+
+class SimilarityHead(nn.Module):
+    """ Similarity Head for the transformer
+
+    TODO: test this class."""
+    def __init__(self, clf_token, cfg):
+        super(SimilarityHead, self).__init__()
+        self.n_embd = cfg.n_embd
+        self.clf_token = clf_token
+        self.dropout = nn.Dropout(cfg.clf_pdrop)
+        self.linear = nn.Linear(cfg.n_embd, 1)
+
+        nn.init.normal_(self.linear.weight, std = 0.02)
+        nn.init.normal_(self.linear.bias, 0)
+
+    def forward(self, h, x):
+        sim_h = h.view(-1, self.n_embd)
+        flat = x[..., 0].contiguous().view(-1)
+        sim_h = sim_h[flat == self.clf_token, :]
+        sim_h = self.dropout(sim_h)
+        sim_h = sim_h.sum(dim = 1)
+        sim_logits = self.linear(sim_h)
+
+        return sim_logits
+
 class DoubleHeadModel(nn.Module):
-    """ Transformer with language model and classification heads """
-    def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512):
+    """ Transformer with language model and task specific heads """
+    def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
         super(DoubleHeadModel, self).__init__()
         self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
         self.lm_head = LMHead(self.transformer, cfg)
-        self.clf_head = ClfHead(clf_token, cfg)
+        if isinstance(task_head_type, str):
+            if task_head_type == 'multiple_choice':
+                self.task_head = MultipleChoiceHead(clf_token, cfg)
+            elif task_head_type == 'similarity':
+                self.task_head = SimilarityHead(clf_token, cfg)
+            elif task_head_type == 'inference':
+                # the three classes correspond to entailment, contradiction and neutral.
+                self.task_head = ClfHead(clf_token, cfg, 3)
+            else:
+                raise ValueError("task_head_type is expected to be 'multiple_choice' "
+                                 "'similarity', 'inference' or ('classification', n_class) "
+                                 f"got {task_head_type}.")
+        elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
+                task_head_type[0] == 'classification':
+            n_class = task_head_type[1]
+            self.task_head = ClfHead(clf_token, cfg, n_class)
+        else:
+            raise ValueError("task_head_type is expected to be 'multiple_choice' "
+                             "'similarity', 'inference' or ('classification', n_class) "
+                             f"got {task_head_type}.")

     def forward(self, x):
         h = self.transformer(x)
         lm_logits = self.lm_head(h)
-        clf_logits = self.clf_head(h, x)
-        return lm_logits, clf_logits
+        task_logits = self.task_head(h, x)
+
+        return lm_logits, task_logits


 def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
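
The practical effect of the DoubleHeadModel change is that the task head is now chosen by a constructor argument instead of being hard-wired to ClfHead. A hedged construction sketch follows; the cfg namespace and the clf_token id are stand-ins for the argparse namespace and encoder vocabulary that train.py actually supplies:

# Hypothetical construction sketch -- cfg fields and clf_token are assumed values.
from types import SimpleNamespace
from model_pytorch import DoubleHeadModel

cfg = SimpleNamespace(n_embd=768, n_head=12, n_layer=12, afn='gelu',
                      embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1,
                      clf_pdrop=0.1)
clf_token = 40478  # assumed id of the classifier token

# A string selects one of the predefined heads...
story_model = DoubleHeadModel(cfg, clf_token, 'multiple_choice')
sts_model = DoubleHeadModel(cfg, clf_token, 'similarity')
nli_model = DoubleHeadModel(cfg, clf_token, 'inference')  # fixed 3-way ClfHead

# ...while a ('classification', n_class) pair parameterizes ClfHead directly.
sentiment_model = DoubleHeadModel(cfg, clf_token, ('classification', 2))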

train.py (+6 −40 lines)
@@ -15,41 +15,7 @@
 from text_utils import TextEncoder
 from utils import (encode_dataset, iter_data,
                    ResultLogger, make_path)
-
-
-class LossCompute:
-    "A Loss compute and train function."
-
-    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
-        self.lm_criterion = lm_criterion
-        self.clf_criterion = clf_criterion
-        self.lm_coef = lm_coef
-        self.opt = opt
-
-    def __call__(self, X, Y, M, clf_logits, lm_logits=None, only_return_losses=False):
-        # Language modeling loss
-        if lm_logits is not None:
-            x_shifted = X[:, :, 1:, 0].contiguous().view(-1)  # Shape: 252
-            M = M.view(-1, M.size(2))
-            lm_losses = self.lm_criterion(lm_logits, x_shifted)
-            lm_losses = lm_losses.view(X.size(0) * X.size(1), X.size(2) - 1)
-            lm_losses = lm_losses * M[:, 1:]
-            lm_losses = lm_losses.sum(1) / torch.sum(M[:, 1:], 1)
-        # Classification loss
-        clf_losses = self.clf_criterion(clf_logits, Y)
-        if only_return_losses:
-            return (clf_losses, lm_losses) if lm_logits is not None else clf_losses
-
-        if self.lm_coef > 0 and lm_logits is not None:
-            train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
-        else:
-            train_loss = clf_losses.sum()
-        train_loss.backward()
-        if self.opt is not None:
-            self.opt.step()
-            self.opt.zero_grad()
-        return train_loss.item()
-
+from loss import MultipleChoiceLossCompute

 def transform_roc(X1, X2, X3):
     n_batch = len(X1)

@@ -263,7 +229,7 @@ def run_epoch():
     n_batch_train = args.n_batch * max(n_gpu, 1)
     n_updates_total = (n_train // n_batch_train) * args.n_iter

-    dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
+    dh_model = DoubleHeadModel(args, clf_token, 'multiple_choice', vocab, n_ctx)

     criterion = nn.CrossEntropyLoss(reduce=False)
     model_opt = OpenAIAdam(dh_model.parameters(),

@@ -277,10 +243,10 @@ def run_epoch():
                            l2=args.l2,
                            vector_l2=args.vector_l2,
                            max_grad_norm=args.max_grad_norm)
-    compute_loss_fct = LossCompute(criterion,
-                                   criterion,
-                                   args.lm_coef,
-                                   model_opt)
+    compute_loss_fct = MultipleChoiceLossCompute(criterion,
+                                                 criterion,
+                                                 args.lm_coef,
+                                                 model_opt)
     load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)

     dh_model.to(device)

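The rest of the training loop is untouched by this diff; for orientation, the call site inside run_epoch looks roughly like the sketch below. Variable names follow the surrounding script, but this reconstruction is an assumption, not part of the commit:

# Hypothetical sketch of the unchanged call site in run_epoch.
for xmb, mmb, ymb in iter_data(trX, trM, trY,
                               n_batch=n_batch_train, truncate=True):
    dh_model.train()
    XMB = torch.tensor(xmb, dtype=torch.long).to(device)
    YMB = torch.tensor(ymb, dtype=torch.long).to(device)
    MMB = torch.tensor(mmb).to(device)
    lm_logits, clf_logits = dh_model(XMB)
    # The loss object backpropagates and steps model_opt internally.
    compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)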