File tree 3 files changed +22
-2
lines changed
3 files changed +22
-2
lines changed Original file line number Diff line number Diff line change 1
1
import argparse
2
+ import os
2
3
3
4
import toml
4
5
import torch
8
9
from metrics import AP , AR
9
10
from model import Batfd , BatfdPlus
10
11
from post_process import post_process
11
- from utils import read_json
12
+ from utils import generate_metadata_min , read_json
12
13
13
14
parser = argparse .ArgumentParser (description = "BATFD evaluation" )
14
15
parser .add_argument ("--config" , type = str )
@@ -125,6 +126,10 @@ def evaluate_lavdf(config, args):
125
126
126
127
if __name__ == '__main__' :
127
128
args = parser .parse_args ()
129
+
130
+ if os .path .exists (os .path .join (args .data_root , "metadata.min.json" )):
131
+ generate_metadata_min (args .data_root )
132
+
128
133
config = toml .load (args .config )
129
134
torch .backends .cudnn .benchmark = True
130
135
if config ["dataset" ] == "lavdf" :
Original file line number Diff line number Diff line change 1
1
import argparse
2
+ import os
2
3
3
4
import toml
4
5
from pytorch_lightning import Trainer
5
6
from pytorch_lightning .callbacks import ModelCheckpoint
6
7
7
8
from dataset .lavdf import LavdfDataModule
8
9
from model import Batfd , BatfdPlus
9
- from utils import LrLogger , EarlyStoppingLR
10
+ from utils import LrLogger , EarlyStoppingLR , generate_metadata_min
10
11
11
12
parser = argparse .ArgumentParser (description = "BATFD training" )
12
13
parser .add_argument ("--config" , type = str )
24
25
args = parser .parse_args ()
25
26
config = toml .load (args .config )
26
27
28
+ if os .path .exists (os .path .join (args .data_root , "metadata.min.json" )):
29
+ generate_metadata_min (args .data_root )
30
+
27
31
learning_rate = config ["optimizer" ]["learning_rate" ]
28
32
gpus = args .gpus
29
33
total_batch_size = args .batch_size * gpus
Original file line number Diff line number Diff line change
1
+ from importlib import metadata
1
2
import json
3
+ import os
2
4
import re
3
5
from abc import ABC
4
6
from typing import List , Tuple , Optional
@@ -219,3 +221,12 @@ def _run_early_stop_checking(self, trainer: Trainer) -> None:
219
221
elif self .mode == "any" :
220
222
if any (lr <= self .lr_threshold for lr in all_lr ):
221
223
trainer .should_stop = True
224
+
225
+
226
+ def generate_metadata_min (data_root : str ):
227
+ metadata_full = read_json (os .path .join (data_root , "metadata.json" ))
228
+ metadata_min = []
229
+ for meta in metadata_full :
230
+ del meta ["timestamps" ]
231
+ with open (os .path .join (data_root , "metadata.min.json" ), "w" ) as f :
232
+ json .dump (metadata_min , f )
You can’t perform that action at this time.
0 commit comments