Skip to content

Commit 7d84e6d

Browse files
committed
Fix metadata.min.json
1 parent c4b4d32 commit 7d84e6d

File tree

3 files changed

+22
-2
lines changed

3 files changed

+22
-2
lines changed

evaluate.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import argparse
2+
import os
23

34
import toml
45
import torch
@@ -8,7 +9,7 @@
89
from metrics import AP, AR
910
from model import Batfd, BatfdPlus
1011
from post_process import post_process
11-
from utils import read_json
12+
from utils import generate_metadata_min, read_json
1213

1314
parser = argparse.ArgumentParser(description="BATFD evaluation")
1415
parser.add_argument("--config", type=str)
@@ -125,6 +126,10 @@ def evaluate_lavdf(config, args):
125126

126127
if __name__ == '__main__':
127128
args = parser.parse_args()
129+
130+
if os.path.exists(os.path.join(args.data_root, "metadata.min.json")):
131+
generate_metadata_min(args.data_root)
132+
128133
config = toml.load(args.config)
129134
torch.backends.cudnn.benchmark = True
130135
if config["dataset"] == "lavdf":

train.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import argparse
2+
import os
23

34
import toml
45
from pytorch_lightning import Trainer
56
from pytorch_lightning.callbacks import ModelCheckpoint
67

78
from dataset.lavdf import LavdfDataModule
89
from model import Batfd, BatfdPlus
9-
from utils import LrLogger, EarlyStoppingLR
10+
from utils import LrLogger, EarlyStoppingLR, generate_metadata_min
1011

1112
parser = argparse.ArgumentParser(description="BATFD training")
1213
parser.add_argument("--config", type=str)
@@ -24,6 +25,9 @@
2425
args = parser.parse_args()
2526
config = toml.load(args.config)
2627

28+
if os.path.exists(os.path.join(args.data_root, "metadata.min.json")):
29+
generate_metadata_min(args.data_root)
30+
2731
learning_rate = config["optimizer"]["learning_rate"]
2832
gpus = args.gpus
2933
total_batch_size = args.batch_size * gpus

utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
from importlib import metadata
12
import json
3+
import os
24
import re
35
from abc import ABC
46
from typing import List, Tuple, Optional
@@ -219,3 +221,12 @@ def _run_early_stop_checking(self, trainer: Trainer) -> None:
219221
elif self.mode == "any":
220222
if any(lr <= self.lr_threshold for lr in all_lr):
221223
trainer.should_stop = True
224+
225+
226+
def generate_metadata_min(data_root: str):
227+
metadata_full = read_json(os.path.join(data_root, "metadata.json"))
228+
metadata_min = []
229+
for meta in metadata_full:
230+
del meta["timestamps"]
231+
with open(os.path.join(data_root, "metadata.min.json"), "w") as f:
232+
json.dump(metadata_min, f)

0 commit comments

Comments
 (0)