import numpy as np
import math
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# vocabulary size seen by the embedding: tokens + special tokens + one slot per
# context position (n_vocab, n_special and n_ctx are defined by the caller)
# vocab = n_vocab + n_special + n_ctx

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

def swish(x):
    return x * torch.sigmoid(x)

ACT_FNS = {
    'relu': F.relu,
    'swish': swish,
    'gelu': gelu
}

def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, n_state, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.g = nn.Parameter(torch.ones(n_state))
        self.b = nn.Parameter(torch.zeros(n_state))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        # One difference with the TF version here: we add epsilon outside of sqrt
        return self.g * (x - mean) / (std + self.eps) + self.b


class Conv1D(nn.Module):
    def __init__(self, nf, rf, nx):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.nf = nf
        if rf == 1:  # faster 1x1 conv
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)  # random normal init (std 0.02 is a conventional choice)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(nf))
        else:  # was used to train LM
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x


class Attention(nn.Module):
    def __init__(self, nx, n_state, n_head, attn_pdrop, resid_pdrop, scale=False):
        super(Attention, self).__init__()
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        self.c_proj = Conv1D(n_state, 1, nx)
        self.scale = scale
        self.n_head = n_head
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    @staticmethod
    def mask_attn_weights(w):
        # causal mask: each position may only attend to itself and earlier positions
        n = w.size(-1)
        b = torch.tril(torch.ones(n, n, dtype=w.dtype, device=w.device)).view(1, 1, n, n)
        return w * b + -1e9 * (1 - b)

    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        w = self.mask_attn_weights(w)
        w = F.softmax(w, dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow version: merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow version: split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):
        x = self.c_attn(x)
        query, key, value = x.split(x.size(2) // 3, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a


class MLP(nn.Module):
    def __init__(self, nx, n_state, afn, resid_pdrop):
        super(MLP, self).__init__()
        self.c_fc = Conv1D(n_state, 1, nx)
        self.c_proj = Conv1D(nx, 1, n_state)  # projects back from n_state to nx features
        self.act = ACT_FNS[afn]
        self.dropout = nn.Dropout(resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h = self.c_proj(h)
        return self.dropout(h)


class Block(nn.Module):
    def __init__(self, nx, n_head, attn_pdrop, resid_pdrop, afn, scale=False):
        super(Block, self).__init__()
        self.attn = Attention(nx, nx, n_head, attn_pdrop, resid_pdrop, scale)
        self.ln_1 = LayerNorm(nx)
        self.mlp = MLP(nx, nx * 4, afn, resid_pdrop)
        self.ln_2 = LayerNorm(nx)

    def forward(self, x):
        # residual connection + layer norm around attention, then around the MLP
        a = self.attn(x)
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
        return h


class Model(nn.Module):
    """ Transformer model """
    def __init__(self, vocab, n_embd, pdrop, n_layers,
                 nx, n_head, attn_pdrop, resid_pdrop, afn):
        super(Model, self).__init__()
        self.embed = nn.Embedding(vocab, n_embd)
        self.drop = nn.Dropout(pdrop)
        self.blocks = clones(Block(nx, n_head, attn_pdrop,
                                   resid_pdrop, afn, scale=True), n_layers)
        self.decoder = nn.Linear(n_embd, vocab, bias=False)
        self.decoder.weight = self.embed.weight  # tied with the embedding weights

    def forward(self, x, m):
        # x: [..., n_ctx, 2] holding (token id, position id) pairs
        # m: [..., n_ctx] mask over real tokens (reshaped but not used inside this module)
        x = x.view(-1, x.size(2), x.size(3))
        m = m.view(-1, m.size(2))
        e = self.drop(self.embed(x))
        # sum the token and position embeddings for each position
        h = e.sum(dim=2)
        for block in self.blocks:
            h = block(h)
        return h
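

if __name__ == '__main__':
    # Quick shape check. The hyperparameters below are illustrative placeholders
    # (vocabulary split, sizes, dropout rates) chosen only to exercise the model;
    # the real values come from the training configuration.
    n_vocab, n_special, n_ctx = 1000, 3, 77
    vocab = n_vocab + n_special + n_ctx
    model = Model(vocab, n_embd=768, pdrop=0.1, n_layers=12,
                  nx=768, n_head=12, attn_pdrop=0.1, resid_pdrop=0.1, afn='gelu')
    batch, n_choices = 2, 1
    # each position carries a (token id, position id) pair
    x = torch.zeros(batch, n_choices, n_ctx, 2, dtype=torch.long)
    x[..., 0] = torch.randint(0, n_vocab, (batch, n_choices, n_ctx))
    x[..., 1] = n_vocab + n_special + torch.arange(n_ctx)
    m = torch.ones(batch, n_choices, n_ctx)
    h = model(x, m)
    print(h.size())  # expected: (batch * n_choices, n_ctx, n_embd)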