in train_net.py [0:0]
def main(args):
    """Run evaluation or training for the trainer named in the config.

    The config is produced by ``setup(args)``; ``cfg.SEMISUPNET.Trainer``
    picks the trainer class (``"ubteacher"`` or ``"baseline"``).

    Args:
        args: Parsed command-line arguments. This function reads
            ``args.eval_only`` and ``args.resume`` and forwards the whole
            object to ``setup``.

    Returns:
        When ``args.eval_only`` is truthy, whatever the trainer class's
        ``test`` returns; otherwise whatever ``trainer.train()`` returns.

    Raises:
        ValueError: If ``cfg.SEMISUPNET.Trainer`` is neither ``"ubteacher"``
            nor ``"baseline"``.
    """
    cfg = setup(args)
    trainer_name = cfg.SEMISUPNET.Trainer

    if trainer_name == "ubteacher":
        trainer_cls = UBTeacherTrainer
    elif trainer_name == "baseline":
        trainer_cls = BaselineTrainer
    else:
        raise ValueError("Trainer Name is not found.")

    # Training path: build the trainer, restore state, and run.
    if not args.eval_only:
        runner = trainer_cls(cfg)
        runner.resume_or_load(resume=args.resume)
        return runner.train()

    # Evaluation path: decide what object the checkpoint is loaded into.
    uses_teacher = trainer_name == "ubteacher"
    if uses_teacher:
        # Student is built first, then the teacher; the checkpoint is loaded
        # into the wrapper that holds both.
        student = trainer_cls.build_model(cfg)
        teacher = trainer_cls.build_model(cfg)
        checkpoint_target = EnsembleTSModel(teacher, student)
    else:
        checkpoint_target = trainer_cls.build_model(cfg)

    checkpointer = DetectionCheckpointer(checkpoint_target, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)

    # For "ubteacher" only the teacher side of the ensemble is evaluated.
    if uses_teacher:
        eval_model = checkpoint_target.modelTeacher
    else:
        eval_model = checkpoint_target
    return trainer_cls.test(cfg, eval_model)