in vihds/config.py [0:0]
def __init__(self, args):
    """Build the run configuration from parsed command-line arguments.

    ``args`` is first passed through ``_tidy_args``. If ``args.yaml`` is
    None the method returns immediately and no attributes are set on the
    instance. Otherwise the YAML file named by ``args.yaml`` is loaded with
    ``yaml.safe_load`` and converted with ``munchify``. The instance is then
    populated with:

    - ``data``: ``apply_defaults_data(config.data)``
    - ``params``: ``apply_defaults_params(config.params)``, with
      ``n_hidden_decoder_precisions`` overridden by
      ``args.precision_hidden_layers`` when that is not None
    - ``model``: ``config.model``
    - ``seed``: ``args.seed``
    - ``device``: ``cuda:<args.gpu>`` when ``args.gpu`` is set and CUDA is
      available, otherwise ``cpu``
    - ``trainer``: None

    Side effects: prints which computation mode was selected, and sets
    torch's global default tensor type to match ``data.dtype`` on the
    selected device.

    Args:
        args: parsed arguments. The attributes read are ``yaml``,
            ``precision_hidden_layers``, ``seed`` and ``gpu``.

    Raises:
        ValueError: if ``data.dtype`` is neither "float32" nor "float64".
    """
    args = _tidy_args(args)
    if args.yaml is None:
        # No config file given: leave the instance without any attributes.
        return
    # YAML streams are UTF-8 by default; don't depend on the platform locale.
    with open(args.yaml, "r", encoding="utf-8") as stream:
        config = munchify(yaml.safe_load(stream))
    self.data = apply_defaults_data(config.data)
    self.params = apply_defaults_params(config.params)
    if args.precision_hidden_layers is not None:
        # Command-line value takes precedence over the YAML params.
        self.params.n_hidden_decoder_precisions = args.precision_hidden_layers
    self.model = config.model
    self.seed = args.seed

    # Logical `and` short-circuits, so CUDA is only probed when a GPU index
    # was actually requested. A requested-but-unavailable GPU falls back to
    # CPU without a warning.
    use_gpu = args.gpu is not None and torch.cuda.is_available()
    if use_gpu:
        print("- GPU mode computation")
        self.device = torch.device("cuda:" + str(args.gpu))
    else:
        print("- CPU mode computation")
        self.device = torch.device("cpu")

    dtype = self.data.dtype
    if dtype == "float32":
        tensor_kind = "FloatTensor"
    elif dtype == "float64":
        tensor_kind = "DoubleTensor"
    else:
        raise ValueError("Unknown dtype %s" % (dtype,))
    # NOTE(review): torch.set_default_tensor_type is deprecated in newer
    # PyTorch releases in favour of set_default_dtype/set_default_device;
    # kept here to preserve existing behaviour.
    prefix = "torch.cuda." if use_gpu else "torch."
    torch.set_default_tensor_type(prefix + tensor_kind)

    # No trainer is constructed in this method.
    self.trainer = None