
Commit 41b1ad7

Add --split option to split root dirs over various sub dirs
1 parent d40a890 commit 41b1ad7

File tree

1 file changed: +89 −65 lines


do_plots.py

Lines changed: 89 additions & 65 deletions
@@ -59,13 +59,13 @@ def dict_to_xy(d):
         y.append(v)
     return x, y
 
-def parse_ordo_file(filename):
+def parse_ordo_file(filename, label):
     p = re.compile('.*nn-epoch(\\d*)\\.nnue')
     with open(filename, 'r') as ordo_file:
         rows = []
         lines = ordo_file.readlines()
         for line in lines:
-            if 'nn-epoch' in line:
+            if 'nn-epoch' in line and label in line:
                 fields = line.split()
                 net = fields[1]
                 epoch = int(p.match(net)[1])
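The effect of the new label parameter, sketched on a couple of hypothetical ordo-style rating lines (the exact ordo output format may differ; only the net name in the second field matters here):

    import re

    # Hypothetical ordo rating lines; field 1 holds the net name.
    lines = [
        '   1 run_a/nn-epoch399.nnue  :  123.4  10.2',
        '   2 run_b/nn-epoch250.nnue  :  101.0   9.8',
    ]
    p = re.compile('.*nn-epoch(\\d*)\\.nnue')
    label = 'run_a'  # a split root dir, or "nnue" when not splitting
    for line in lines:
        if 'nn-epoch' in line and label in line:
            net = line.split()[1]
            print(net, int(p.match(net)[1]))  # -> run_a/nn-epoch399.nnue 399

With label set to a subdirectory path, only that run's nets survive the filter; with the non-split fallback "nnue", every net line matches as before.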
@@ -78,7 +78,7 @@ def parse_ordo_file(filename):
 def transpose_list_of_tuples(l):
     return list(map(list, zip(*l)))
 
-def do_plots(out_filename, root_dirs, elo_range, loss_range):
+def do_plots(out_filename, root_dirs, elo_range, loss_range, split):
     '''
     1. Find tfevents files for each root directory
     2. Look for metrics
@@ -107,72 +107,92 @@ def do_plots(out_filename, root_dirs, elo_range, loss_range):
     ax_train_loss.set_xlabel('step')
     ax_train_loss.set_ylabel('train_loss')
 
-    for root_dir in root_dirs:
-        print('Processing root_dir {}'.format(root_dir))
-        tfevents_files = find_event_files(root_dir)
-        print('Found {} tfevents files.'.format(len(tfevents_files)))
-
-        val_losses = collections.defaultdict(lambda: [])
-        train_losses = collections.defaultdict(lambda: [])
-        for i, tfevents_file in enumerate(tfevents_files):
-            print('Processing tfevents file {}/{}: {}'.format(i+1, len(tfevents_files), tfevents_file))
-            events_acc = EventAccumulator(tfevents_file, tf_size_guidance)
-            events_acc.Reload()
-
-            vv = events_acc.Scalars('val_loss')
-            print('Found {} val_loss entries.'.format(len(vv)))
-            minloss = min([v[2] for v in vv])
-            for v in vv:
-                if v[2] < minloss + loss_range:
-                    step = v[1]
-                    val_losses[step].append(v[2])
-
-            vv = events_acc.Scalars('train_loss')
-            minloss = min([v[2] for v in vv])
-            print('Found {} train_loss entries.'.format(len(vv)))
-            for v in vv:
-                if v[2] < minloss + loss_range:
-                    step = v[1]
-                    train_losses[step].append(v[2])
-
-        print('Aggregating data...')
-
-        val_loss = aggregate_dict(val_losses, 'min')
-        x, y = dict_to_xy(val_loss)
-        ax_val_loss.plot(x, y, label=root_dir)
-
-        train_loss = aggregate_dict(train_losses, 'min')
-        x, y = dict_to_xy(train_loss)
-        ax_train_loss.plot(x, y, label=root_dir)
-
-        print('Finished aggregating data.')
-
-        ordo_file = find_ordo_file(root_dir)
+
+    for user_root_dir in root_dirs:
+
+        # If asked to split, we expand each user root dir into a number of
+        # split root dirs, i.e. all of its direct subdirectories containing
+        # tfevents files. We use the ordo file in the user root dir, but
+        # split its content per subdirectory.
+        split_root_dirs = [user_root_dir]
+        if split:
+            split_root_dirs = []
+            for item in os.listdir(user_root_dir):
+                if os.path.isdir(os.path.join(user_root_dir, item)):
+                    root_dir = os.path.join(user_root_dir, item)
+                    if len(find_event_files(root_dir)) > 0:
+                        split_root_dirs.append(root_dir)
+            split_root_dirs.sort()
+
+        for root_dir in split_root_dirs:
+            print('Processing root_dir {}'.format(root_dir))
+            tfevents_files = find_event_files(root_dir)
+            print('Found {} tfevents files.'.format(len(tfevents_files)))
+
+            val_losses = collections.defaultdict(lambda: [])
+            train_losses = collections.defaultdict(lambda: [])
+            for i, tfevents_file in enumerate(tfevents_files):
+                print('Processing tfevents file {}/{}: {}'.format(i+1, len(tfevents_files), tfevents_file))
+                events_acc = EventAccumulator(tfevents_file, tf_size_guidance)
+                events_acc.Reload()
+
+                vv = events_acc.Scalars('val_loss')
+                print('Found {} val_loss entries.'.format(len(vv)))
+                minloss = min([v[2] for v in vv])
+                for v in vv:
+                    if v[2] < minloss + loss_range:
+                        step = v[1]
+                        val_losses[step].append(v[2])
+
+                vv = events_acc.Scalars('train_loss')
+                minloss = min([v[2] for v in vv])
+                print('Found {} train_loss entries.'.format(len(vv)))
+                for v in vv:
+                    if v[2] < minloss + loss_range:
+                        step = v[1]
+                        train_losses[step].append(v[2])
+
+            print('Aggregating data...')
+
+            val_loss = aggregate_dict(val_losses, 'min')
+            x, y = dict_to_xy(val_loss)
+            ax_val_loss.plot(x, y, label=root_dir)
+
+            train_loss = aggregate_dict(train_losses, 'min')
+            x, y = dict_to_xy(train_loss)
+            ax_train_loss.plot(x, y, label=root_dir)
+
+            print('Finished aggregating data.')
+
+        ordo_file = find_ordo_file(user_root_dir)
         if ordo_file:
             print('Found ordo file {}'.format(ordo_file))
             if ax_elo is None:
                 ax_elo = fig.add_subplot(313)
                 ax_elo.set_xlabel('epoch')
-                ax_elo.set_ylabel('elo')
-            rows = parse_ordo_file(ordo_file)
-            rows = sorted(rows, key=lambda x:x[1])
-            epochs = []
-            elos = []
-            errors = []
-            maxelo = max([row[2] for row in rows])
-            for row in rows:
-                epoch = row[1]
-                elo = row[2]
-                error = row[3]
-                if not epoch in epochs:
-                    if elo > maxelo - elo_range:
-                        epochs.append(epoch)
-                        elos.append(elo)
-                        errors.append(error)
-
-            print('Found ordo data for {} epochs'.format(len(epochs)))
-
-            ax_elo.errorbar(epochs, elos, yerr=errors, label=root_dir)
+                ax_elo.set_ylabel('Elo')
+
+            for root_dir in split_root_dirs:
+                rows = parse_ordo_file(ordo_file, root_dir if split else "nnue")
+                if len(rows) == 0:
+                    continue
+                rows = sorted(rows, key=lambda x:x[1])
+                epochs = []
+                elos = []
+                errors = []
+                maxelo = max([row[2] for row in rows])
+                for row in rows:
+                    epoch = row[1]
+                    elo = row[2]
+                    error = row[3]
+                    if not epoch in epochs:
+                        if elo > maxelo - elo_range:
+                            epochs.append(epoch)
+                            elos.append(elo)
+                            errors.append(error)
+
+                print('Found ordo data for {} epochs'.format(len(epochs)))
+
+                ax_elo.errorbar(epochs, elos, yerr=errors, label=root_dir)
 
         else:
            print('Did not find ordo file. Skipping.')
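With --split, a root dir passed on the command line is expected to look roughly like this (file and directory names hypothetical); each subdirectory containing tfevents files becomes its own curve, while the single ordo file stays at the top level and is carved up by the label filter:

    runs/                        <- user root dir passed on the command line
        ordo.out                 <- single ordo file, located by find_ordo_file
        run_a/                   <- split root dir: gets its own plot line
            events.out.tfevents.123
        run_b/
            events.out.tfevents.456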
@@ -219,10 +239,14 @@ def main():
         default=0.004,
         help="Limit loss data shown to the best result + loss_range",
     )
+    parser.add_argument("--split",
+        action='store_true',
+        help="Split each root dir into its direct subdirectories; assumes the ordo file is still at the root and that net names in it match root_dir/sub_dir/",
+    )
     args = parser.parse_args()
 
     print(args.root_dirs)
-    do_plots(args.output, args.root_dirs, elo_range = args.elo_range, loss_range = args.loss_range)
+    do_plots(args.output, args.root_dirs, elo_range = args.elo_range, loss_range = args.loss_range, split = args.split)
 
 if __name__ == '__main__':
     main()
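With the new flag wired into main(), an invocation could look like the following (paths hypothetical; the spelling of the output flag is assumed from args.output, since it does not appear in this diff):

    # --split is the new flag; all other flag spellings are assumptions
    python do_plots.py --split --output plots.png runs/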
