2
2
import json
3
3
import math
4
4
import re
5
+ import collections
5
6
6
7
import numpy as np
7
8
import torch
@@ -187,22 +188,23 @@ def forward(self, h):
187
188
return lm_logits
188
189
189
190
190
class MultipleChoiceHead(nn.Module):
    """ Classifier Head for the transformer """

    def __init__(self, clf_token, cfg):
        super(MultipleChoiceHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        # Dropout2d zeroes entire channels at once; this reproduces the
        # noise_shape parameter of the TF implementation.
        self.dropout = nn.Dropout2d(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, h, x):
        # Classification logits: pick, for every choice, the hidden state at
        # the position of the classification token and score it with a single
        # linear unit.
        clf_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
        # The double transposition replicates the dropout noise shape of the
        # original TF implementation: Dropout2d drops whole channels, so we
        # move the embedding axis into the channel position and back.
        clf_h = self.dropout(clf_h.transpose(1, 2)).transpose(1, 2)
        clf_h = clf_h.contiguous().view(-1, self.n_embd)
        clf_logits = self.linear(clf_h)

        return clf_logits.view(-1, x.size(1))
216
219
217
220
221
class ClfHead(nn.Module):
    """Classification Head for the transformer

    TODO: test this class."""

    def __init__(self, clf_token, cfg, n_class):
        super(ClfHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, n_class)

        # NOTE(review): normal_(bias, 0) sets mean=0 but leaves std at its
        # default of 1.0 — a zero init may have been intended; confirm
        # against the reference implementation.
        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, h, x):
        # Select the hidden state at every clf-token position and project it
        # to per-class logits of shape (n_selected, n_class).
        hidden = h.view(-1, self.n_embd)
        token_ids = x[..., 0].contiguous().view(-1)
        selected = hidden[token_ids == self.clf_token, :]
        return self.linear(self.dropout(selected))
243
+
244
class SimilarityHead(nn.Module):
    """ Similarity Head for the transformer

    Scores a pair of sequences (presented in both orders along dim 1 of x)
    with a single linear unit applied to the sum of the two clf-token
    representations.

    TODO: test this class."""

    def __init__(self, clf_token, cfg):
        super(SimilarityHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, 1)

        nn.init.normal_(self.linear.weight, std=0.02)
        nn.init.normal_(self.linear.bias, 0)

    def forward(self, h, x):
        sim_h = h.view(-1, self.n_embd)
        flat = x[..., 0].contiguous().view(-1)
        sim_h = sim_h[flat == self.clf_token, :]
        sim_h = self.dropout(sim_h)
        # BUG FIX: the previous `sim_h.sum(dim=1)` collapsed the embedding
        # dimension, producing a (n,) vector that nn.Linear(n_embd, 1) can
        # only consume when n == n_embd. Sum over the sequence-order
        # dimension instead, so the two orderings of the pair share one
        # (batch, n_embd) representation. (Assumes one clf token per
        # sequence — TODO confirm against the caller.)
        sim_h = sim_h.view(-1, x.size(1), self.n_embd).sum(dim=1)
        sim_logits = self.linear(sim_h)

        return sim_logits
267
+
218
268
class DoubleHeadModel(nn.Module):
    """ Transformer with language model and task specific heads """

    def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
        super(DoubleHeadModel, self).__init__()
        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
        self.lm_head = LMHead(self.transformer, cfg)

        # Resolve the task head from its description: either a string naming
        # a predefined task, or a ('classification', n_class) pair.
        task_head = None
        if isinstance(task_head_type, str):
            if task_head_type == 'multiple_choice':
                task_head = MultipleChoiceHead(clf_token, cfg)
            elif task_head_type == 'similarity':
                task_head = SimilarityHead(clf_token, cfg)
            elif task_head_type == 'inference':
                # the three classes correspond to entailment, contradiction and neutral.
                task_head = ClfHead(clf_token, cfg, 3)
        elif (isinstance(task_head_type, collections.abc.Sequence)
                and len(task_head_type) == 2
                and task_head_type[0] == 'classification'):
            task_head = ClfHead(clf_token, cfg, task_head_type[1])
        if task_head is None:
            raise ValueError("task_head_type is expected to be 'multiple_choice' "
                             "'similarity', 'inference' or ('classification', n_class) "
                             f"got {task_head_type}.")
        self.task_head = task_head

    def forward(self, x):
        h = self.transformer(x)
        lm_logits = self.lm_head(h)
        task_logits = self.task_head(h, x)

        return lm_logits, task_logits
231
301
232
302
233
303
def load_openai_pretrained_model (model , n_ctx = - 1 , n_special = - 1 , n_transfer = 12 , n_embd = 768 , path = './model/' ,
0 commit comments