in main.py [0:0]
# Maps each CLI command to (mode, solver module path, solver class name).
# The solver module is imported lazily by ``_load_solver`` so that only the
# selected solver's module (and its dependencies) is ever imported.
_SOLVER_REGISTRY = {
    'train-oracle': ('train', 'solver.oracle', 'OracleSolver'),
    'test-oracle': ('test', 'solver.oracle', 'OracleSolver'),
    'train-oracle-rcnn': ('train', 'solver.oracle_rcnn', 'OracleSolver'),
    'test-oracle-rcnn': ('test', 'solver.oracle_rcnn', 'OracleSolver'),
    'train-oracle-vilbert': ('train', 'solver.oracle_vilbert', 'OracleSolver'),
    'test-oracle-vilbert': ('test', 'solver.oracle_vilbert', 'OracleSolver'),
    'train-guesser': ('train', 'solver.guesser', 'GuesserSolver'),
    'test-guesser': ('test', 'solver.guesser', 'GuesserSolver'),
    'train-guesser-vilbert': ('train', 'solver.guesser_vilbert', 'GuesserSolver'),
    'test-guesser-vilbert': ('test', 'solver.guesser_vilbert', 'GuesserSolver'),
    'train-qgen': ('train', 'solver.qgen', 'QGenSolver'),
    'train-qgen-vdst': ('train', 'solver.qgen_vdst', 'QGenSolver'),
    'train-qgen-vilbert': ('train', 'solver.qgen_vilbert', 'QGenSolver'),
    'test-self-play': ('test', 'solver.self_play', 'SelfPlaySolver'),
    'test-self-play-qgen-vdst': (
        'test', 'solver.self_play_qgen_vdst', 'SelfPlaySolver'),
    'test-self-play-qgen-vilbert': (
        'test', 'solver.self_play_qgen_vilbert', 'SelfPlaySolver'),
    'test-self-play-qgen-vdst-oracle-vilbert': (
        'test', 'solver.self_play_qgen_vdst_oracle_vilbert', 'SelfPlaySolver'),
    'test-self-play-qgen-vdst-guesser-vilbert': (
        'test', 'solver.self_play_qgen_vdst_guesser_vilbert', 'SelfPlaySolver'),
    'test-self-play-qgen-vdst-oracle-vilbert-guesser-vilbert': (
        'test',
        'solver.self_play_qgen_vdst_oracle_vilbert_guesser_vilbert',
        'SelfPlaySolver'),
    'test-self-play-all-vilbert': (
        'test', 'solver.self_play_all_vilbert', 'SelfPlaySolver'),
}


def _load_solver(command):
    """Return ``(mode, SolverClass)`` for *command*, importing its module lazily.

    Args:
        command: A CLI command name, e.g. ``'train-oracle'``.

    Returns:
        A tuple ``(mode, SolverClass)`` where ``mode`` is ``'train'`` or
        ``'test'`` and ``SolverClass`` is the solver class to instantiate.

    Raises:
        NotImplementedError: If no solver is registered for *command*.
    """
    import importlib

    try:
        mode, module_name, class_name = _SOLVER_REGISTRY[command]
    except KeyError:
        raise NotImplementedError(
            'no solver registered for command %r' % command) from None
    module = importlib.import_module(module_name)
    return mode, getattr(module, class_name)


def run(args):
    """Build and execute the solver selected by ``args.command``.

    Seeds the RNGs, loads the YAML config, instantiates the solver registered
    for the command, and then calls its ``load_data``, ``set_model`` and
    ``exec`` methods in that order.

    Args:
        args: Parsed CLI arguments. This function reads ``args.command``,
            ``args.seed`` and ``args.config`` (a path to a YAML file), and
            passes the whole object on to the solver constructor.

    Raises:
        AssertionError: If ``args.command`` is not in ``VALID_COMMAND``
            (only while assertions are enabled, i.e. not under ``python -O``).
        NotImplementedError: If the command has no registered solver.
    """
    command = args.command
    assert command in VALID_COMMAND, \
        "%s is not a valid command." % command
    _set_seed(args.seed)
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)
    mode, Solver = _load_solver(command)
    solver = Solver(config, args, mode)
    solver.load_data()
    solver.set_model()
    solver.exec()