def __call__()

in vision_charts/recon.py [0:0]


	def __call__(self) -> None:
		"""Build the encoder and optimizer, then either evaluate once or train.

		Loads the initial mesh with ``utils.load_mesh_vision``:
		``../data/sphere.obj`` when ``self.args.GEOmetrics`` is truthy,
		otherwise ``../data/vision_sheets.obj``. Builds ``models.Encoder`` from
		it, moves the encoder to CUDA, and creates an Adam optimizer over the
		encoder parameters with ``lr=self.args.lr``. It also opens a TensorBoard
		``SummaryWriter`` under ``experiments/tensorboard/<args.exp_type>``.

		If ``self.args.eval`` is set, weights are restored before a single
		validation pass under ``torch.no_grad()``, after which the process
		exits. Weights come from ``load_pretrained()`` when
		``args.pretrained != 'no'``, and from ``load('')`` otherwise.
		Otherwise it runs up to 3000 epochs. Each epoch calls ``train``, then
		``validate`` under ``torch.no_grad()``, then ``check_values()``.

		The writer is always closed, including on exceptions or ``SystemExit``,
		so buffered TensorBoard events are flushed.

		Returns:
			None.

		Raises:
			SystemExit: after the evaluation pass when ``self.args.eval`` is set.
		"""
		import sys  # function-scope import: sys.exit() replaces the interactive-only exit()

		# initial data: pick the template mesh, then load it once
		if self.args.GEOmetrics:
			mesh_path = '../data/sphere.obj'
		else:
			mesh_path = '../data/vision_sheets.obj'
		self.adj_info, initial_positions = utils.load_mesh_vision(self.args, mesh_path)

		self.encoder = models.Encoder(self.adj_info, Variable(initial_positions.cuda()), self.args)
		self.encoder.cuda()
		params = list(self.encoder.parameters())
		self.optimizer = optim.Adam(params, lr=self.args.lr, weight_decay=0)

		writer = SummaryWriter(os.path.join('experiments/tensorboard/', self.args.exp_type))
		try:
			train_loader, valid_loaders = self.get_loaders()

			if self.args.eval:
				if self.args.pretrained != 'no':
					self.load_pretrained()
				else:
					self.load('')
				with torch.no_grad():
					self.validate(valid_loaders, writer)
				# sys.exit raises SystemExit; the finally clause still closes the writer
				sys.exit()

			# training loop
			for epoch in range(3000):
				self.epoch = epoch
				self.train(train_loader, writer)
				with torch.no_grad():
					self.validate(valid_loaders, writer)
				self.check_values()
		finally:
			# flush pending TensorBoard events however this method is left
			writer.close()