@@ -59,13 +59,13 @@ def dict_to_xy(d):
        y.append(v)
    return x, y

-def parse_ordo_file(filename):
+def parse_ordo_file(filename, label):
    p = re.compile('.*nn-epoch(\\d*)\\.nnue')
    with open(filename, 'r') as ordo_file:
        rows = []
        lines = ordo_file.readlines()
        for line in lines:
-            if 'nn-epoch' in line:
+            if 'nn-epoch' in line and label in line:
                fields = line.split()
                net = fields[1]
                epoch = int(p.match(net)[1])
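For reference, a minimal sketch (not part of the patch) of what the new label argument does while a line is parsed. The sample ordo row is hypothetical; only its shape (rank first, net name second) is assumed from the fields[1] access above:

    import re

    p = re.compile('.*nn-epoch(\\d*)\\.nnue')
    label = 'run_a'  # with --split this is a sub-directory path, otherwise "nnue"

    line = '   4   run_a/nn-epoch255.nnue   :   12.3   5.1'  # hypothetical ordo row
    if 'nn-epoch' in line and label in line:
        net = line.split()[1]         # 'run_a/nn-epoch255.nnue'
        epoch = int(p.match(net)[1])  # -> 255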
@@ -78,7 +78,7 @@ def parse_ordo_file(filename):
def transpose_list_of_tuples(l):
    return list(map(list, zip(*l)))

-def do_plots(out_filename, root_dirs, elo_range, loss_range):
+def do_plots(out_filename, root_dirs, elo_range, loss_range, split):
    '''
    1. Find tfevents files for each root directory
    2. Look for metrics
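The hunk below keeps calling aggregate_dict and dict_to_xy, neither of which appears in full in this diff. As a hedged stand-in, assuming aggregate_dict collapses the list of losses recorded per step with the named mode ('min' in the calls below) and that dict_to_xy sorts by step:

    def aggregate_dict(d, mode='min'):
        # assumption: reduce the list of values gathered per step to one number
        reduce_fn = {'min': min, 'max': max, 'avg': lambda v: sum(v) / len(v)}[mode]
        return {k: reduce_fn(v) for k, v in d.items()}

    def dict_to_xy(d):
        # the sorted() call is assumed; only the append/return tail of
        # dict_to_xy is visible in the first hunk
        x, y = [], []
        for k, v in sorted(d.items()):
            x.append(k)
            y.append(v)
        return x, y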
@@ -107,72 +107,92 @@ def do_plots(out_filename, root_dirs, elo_range, loss_range):
    ax_train_loss.set_xlabel('step')
    ax_train_loss.set_ylabel('train_loss')

-    for root_dir in root_dirs:
-        print('Processing root_dir {}'.format(root_dir))
-        tfevents_files = find_event_files(root_dir)
-        print('Found {} tfevents files.'.format(len(tfevents_files)))
-
-        val_losses = collections.defaultdict(lambda: [])
-        train_losses = collections.defaultdict(lambda: [])
-        for i, tfevents_file in enumerate(tfevents_files):
-            print('Processing tfevents file {}/{}: {}'.format(i + 1, len(tfevents_files), tfevents_file))
-            events_acc = EventAccumulator(tfevents_file, tf_size_guidance)
-            events_acc.Reload()
-
-            vv = events_acc.Scalars('val_loss')
-            print('Found {} val_loss entries.'.format(len(vv)))
-            minloss = min([v[2] for v in vv])
-            for v in vv:
-                if v[2] < minloss + loss_range:
-                    step = v[1]
-                    val_losses[step].append(v[2])
-
-            vv = events_acc.Scalars('train_loss')
-            minloss = min([v[2] for v in vv])
-            print('Found {} train_loss entries.'.format(len(vv)))
-            for v in vv:
-                if v[2] < minloss + loss_range:
-                    step = v[1]
-                    train_losses[step].append(v[2])
-
-        print('Aggregating data...')
-
-        val_loss = aggregate_dict(val_losses, 'min')
-        x, y = dict_to_xy(val_loss)
-        ax_val_loss.plot(x, y, label=root_dir)
-
-        train_loss = aggregate_dict(train_losses, 'min')
-        x, y = dict_to_xy(train_loss)
-        ax_train_loss.plot(x, y, label=root_dir)
-
-        print('Finished aggregating data.')
-
-        ordo_file = find_ordo_file(root_dir)
+
+    for user_root_dir in root_dirs:
+
+        # If asked to split, we split each user root dir into a number of root dirs,
+        # i.e. all of its direct subdirectories that contain tfevents files.
+        # We use the ordo file found in the user root dir, but split its content per subdirectory.
+        split_root_dirs = [user_root_dir]
+        if split:
+            split_root_dirs = []
+            for item in os.listdir(user_root_dir):
+                if os.path.isdir(os.path.join(user_root_dir, item)):
+                    root_dir = os.path.join(user_root_dir, item)
+                    if len(find_event_files(root_dir)) > 0:
+                        split_root_dirs.append(root_dir)
+            split_root_dirs.sort()
+
+        for root_dir in split_root_dirs:
+            print('Processing root_dir {}'.format(root_dir))
+            tfevents_files = find_event_files(root_dir)
+            print('Found {} tfevents files.'.format(len(tfevents_files)))
+
+            val_losses = collections.defaultdict(lambda: [])
+            train_losses = collections.defaultdict(lambda: [])
+            for i, tfevents_file in enumerate(tfevents_files):
+                print('Processing tfevents file {}/{}: {}'.format(i + 1, len(tfevents_files), tfevents_file))
+                events_acc = EventAccumulator(tfevents_file, tf_size_guidance)
+                events_acc.Reload()
+
+                vv = events_acc.Scalars('val_loss')
+                print('Found {} val_loss entries.'.format(len(vv)))
+                minloss = min([v[2] for v in vv])
+                for v in vv:
+                    if v[2] < minloss + loss_range:
+                        step = v[1]
+                        val_losses[step].append(v[2])
+
+                vv = events_acc.Scalars('train_loss')
+                minloss = min([v[2] for v in vv])
+                print('Found {} train_loss entries.'.format(len(vv)))
+                for v in vv:
+                    if v[2] < minloss + loss_range:
+                        step = v[1]
+                        train_losses[step].append(v[2])
+
+            print('Aggregating data...')
+
+            val_loss = aggregate_dict(val_losses, 'min')
+            x, y = dict_to_xy(val_loss)
+            ax_val_loss.plot(x, y, label=root_dir)
+
+            train_loss = aggregate_dict(train_losses, 'min')
+            x, y = dict_to_xy(train_loss)
+            ax_train_loss.plot(x, y, label=root_dir)
+
+            print('Finished aggregating data.')
+
+        ordo_file = find_ordo_file(user_root_dir)
        if ordo_file:
            print('Found ordo file {}'.format(ordo_file))
            if ax_elo is None:
                ax_elo = fig.add_subplot(313)
                ax_elo.set_xlabel('epoch')
-                ax_elo.set_ylabel('elo')
-            rows = parse_ordo_file(ordo_file)
-            rows = sorted(rows, key=lambda x: x[1])
-            epochs = []
-            elos = []
-            errors = []
-            maxelo = max([row[2] for row in rows])
-            for row in rows:
-                epoch = row[1]
-                elo = row[2]
-                error = row[3]
-                if not epoch in epochs:
-                    if elo > maxelo - elo_range:
-                        epochs.append(epoch)
-                        elos.append(elo)
-                        errors.append(error)
-
-            print('Found ordo data for {} epochs'.format(len(epochs)))
-
-            ax_elo.errorbar(epochs, elos, yerr=errors, label=root_dir)
+                ax_elo.set_ylabel('Elo')
+
+            for root_dir in split_root_dirs:
+                rows = parse_ordo_file(ordo_file, root_dir if split else "nnue")
+                if len(rows) == 0:
+                    continue
+                rows = sorted(rows, key=lambda x: x[1])
+                epochs = []
+                elos = []
+                errors = []
+                maxelo = max([row[2] for row in rows])
+                for row in rows:
+                    epoch = row[1]
+                    elo = row[2]
+                    error = row[3]
+                    if not epoch in epochs:
+                        if elo > maxelo - elo_range:
+                            epochs.append(epoch)
+                            elos.append(elo)
+                            errors.append(error)
+
+                print('Found ordo data for {} epochs'.format(len(epochs)))
+
+                ax_elo.errorbar(epochs, elos, yerr=errors, label=root_dir)

        else:
            print('Did not find ordo file. Skipping.')
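To make the --split behaviour above concrete, here is a hedged sketch of the sub-directory discovery in isolation. The directory names are made up, and this find_event_files is only a stand-in (a recursive glob) for the helper the script actually uses:

    import glob
    import os

    def find_event_files(root_dir):
        # stand-in: assume tfevents files can sit anywhere below root_dir
        return glob.glob(os.path.join(root_dir, '**', '*tfevents*'), recursive=True)

    def discover_split_root_dirs(user_root_dir):
        # mirrors the loop above: keep direct subdirectories holding tfevents files
        split_root_dirs = []
        for item in os.listdir(user_root_dir):
            candidate = os.path.join(user_root_dir, item)
            if os.path.isdir(candidate) and len(find_event_files(candidate)) > 0:
                split_root_dirs.append(candidate)
        split_root_dirs.sort()
        return split_root_dirs

    # e.g. runs/run_a and runs/run_b each containing tfevents files would yield
    # ['runs/run_a', 'runs/run_b']; these paths become both the plot labels and
    # the label filter handed to parse_ordo_file.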
@@ -219,10 +239,14 @@ def main():
        default=0.004,
        help="Limit loss data shown to the best result + loss_range",
    )
+    parser.add_argument("--split",
+        action='store_true',
+        help="Split each provided root dir into its direct subdirectories; assumes the ordo file is still at the root and that net names in it match root_dir/sub_dir/",
+    )
    args = parser.parse_args()

    print(args.root_dirs)
-    do_plots(args.output, args.root_dirs, elo_range=args.elo_range, loss_range=args.loss_range)
+    do_plots(args.output, args.root_dirs, elo_range=args.elo_range, loss_range=args.loss_range, split=args.split)

if __name__ == '__main__':
    main()
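Finally, a minimal sketch of how the new flag reaches do_plots. Only --split and its store_true action are taken from the patch; the spelling of the other arguments (positional root_dirs, --output) is an assumption for illustration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("root_dirs", nargs="+")           # assumed: root dirs are positional
    parser.add_argument("--output", default="plots.png")  # assumed: output image path
    parser.add_argument("--split", action='store_true')

    args = parser.parse_args(["runs", "--split"])
    assert args.split is True
    # do_plots(args.output, args.root_dirs, elo_range=..., loss_range=..., split=args.split)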