fairmotion/tasks/motion_prediction/training.py

import logging
import os

import numpy as np
import torch
import torch.nn as nn

from fairmotion.tasks.motion_prediction import generate, test, utils
from fairmotion.utils import utils as fairmotion_utils

# Note: set_seeds() is assumed to be defined earlier in this module,
# alongside train(); it is not provided by the imports above.


def train(args):
    fairmotion_utils.create_dir_if_absent(args.save_model_path)
    logging.info(args._get_kwargs())
    utils.log_config(args.save_model_path, args)

    set_seeds()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = args.device if args.device else device
    logging.info(f"Using device: {device}")

    logging.info("Preparing dataset...")
    dataset, mean, std = utils.prepare_dataset(
        *[
            os.path.join(args.preprocessed_path, f"{split}.pkl")
            for split in ["train", "test", "validation"]
        ],
        batch_size=args.batch_size,
        device=device,
        shuffle=args.shuffle,
    )
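    # prepare_dataset returns an iterable of (src, tgt) batches per split,
    # plus the mean and std used to normalize the data; these are needed
    # later to de-normalize predictions for MAE evaluation.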
    # Loss per epoch is the average loss per sequence
    num_training_sequences = len(dataset["train"]) * args.batch_size
    # number of predictions per time step = num_joints * angle representation
    # shape is (batch_size, seq_len, num_predictions)
    _, tgt_len, num_predictions = next(iter(dataset["train"]))[1].shape

    model = utils.prepare_model(
        input_dim=num_predictions,
        hidden_dim=args.hidden_dim,
        device=device,
        num_layers=args.num_layers,
        architecture=args.architecture,
    )
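    # The model maps a source pose sequence to a predicted target sequence;
    # input and output size both equal the flattened per-frame pose vector
    # (num_joints * size of the angle representation).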
    criterion = nn.MSELoss()
    model.init_weights()
    training_losses, val_losses = [], []

    epoch_loss = 0
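    # Measure the untrained model's loss first, with full teacher forcing,
    # as a baseline against which training progress can be judged.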
    model.eval()
    with torch.no_grad():
        for iterations, (src_seqs, tgt_seqs) in enumerate(dataset["train"]):
            src_seqs, tgt_seqs = src_seqs.to(device), tgt_seqs.to(device)
            outputs = model(src_seqs, tgt_seqs, teacher_forcing_ratio=1)
            loss = criterion(outputs, tgt_seqs)
            epoch_loss += loss.item()
    epoch_loss = epoch_loss / num_training_sequences
    val_loss = generate.eval(
        model, criterion, dataset["validation"], args.batch_size, device,
    )
    logging.info(
        "Before training: "
        f"Training loss {epoch_loss} | "
        f"Validation loss {val_loss}"
    )
    logging.info("Training model...")
    torch.autograd.set_detect_anomaly(True)
    opt = utils.prepare_optimizer(model, args.optimizer, args.lr)
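    # Note: set_detect_anomaly(True) helps surface bad gradients while
    # debugging but slows training noticeably; disable it for long runs.
    # opt appears to be a thin wrapper around a torch optimizer, exposing
    # .optimizer, a .step() for gradient updates, and an .epoch_step() hook
    # (called below with the validation loss, presumably for LR scheduling).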
    for epoch in range(args.epochs):
        epoch_loss = 0
        model.train()
        teacher_forcing_ratio = np.clip(
            (1 - 2 * epoch / args.epochs), a_min=0, a_max=1,
        )
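        # The ratio decays linearly from 1 to 0 over the first half of
        # training and stays at 0 afterwards: early epochs feed ground-truth
        # frames to the decoder, later epochs run it fully autoregressively.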
        logging.info(
            f"Running epoch {epoch} | "
            f"teacher_forcing_ratio={teacher_forcing_ratio}"
        )
        for iterations, (src_seqs, tgt_seqs) in enumerate(dataset["train"]):
            opt.optimizer.zero_grad()
            src_seqs, tgt_seqs = src_seqs.to(device), tgt_seqs.to(device)
            outputs = model(
                src_seqs, tgt_seqs, teacher_forcing_ratio=teacher_forcing_ratio
            )
            outputs = outputs.double()
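            # Cast to float64 so the prediction matches the dtype of the
            # targets (the preprocessed data is presumably stored as
            # float64); MSELoss requires both operands to share a dtype.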
            loss = criterion(
                outputs,
                utils.prepare_tgt_seqs(args.architecture, src_seqs, tgt_seqs),
            )
            loss.backward()
            opt.step()
            epoch_loss += loss.item()
        epoch_loss = epoch_loss / num_training_sequences
        training_losses.append(epoch_loss)
        val_loss = generate.eval(
            model, criterion, dataset["validation"], args.batch_size, device,
        )
        val_losses.append(val_loss)
        opt.epoch_step(val_loss=val_loss)
        logging.info(
            f"Training loss {epoch_loss} | "
            f"Validation loss {val_loss} | "
            f"Iterations {iterations + 1}"
        )
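        # Every args.save_model_frequency epochs: report validation MAE in
        # the dataset's angle representation (inferred from the last
        # component of preprocessed_path) and checkpoint the weights.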
        if epoch % args.save_model_frequency == 0:
            _, rep = os.path.split(args.preprocessed_path.strip("/"))
            _, mae = test.test_model(
                model=model,
                dataset=dataset["validation"],
                rep=rep,
                device=device,
                mean=mean,
                std=std,
                max_len=tgt_len,
            )
            logging.info(f"Validation MAE: {mae}")
            torch.save(
                model.state_dict(), f"{args.save_model_path}/{epoch}.model"
            )
        # val_losses already contains the current val_loss, so `<= min` is
        # true exactly when this epoch is the best (or tied) so far. The
        # check runs every epoch, not only on checkpoint epochs, so the
        # best model is never skipped.
        if val_loss <= min(val_losses):
            torch.save(
                model.state_dict(), f"{args.save_model_path}/best.model"
            )
    return training_losses, val_losses
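
# A minimal invocation sketch (not part of the original file): the attribute
# names mirror exactly the args.* fields that train() reads, but the paths
# and hyperparameter values are illustrative assumptions; the real script
# would normally build this namespace with argparse.
if __name__ == "__main__":
    import argparse

    example_args = argparse.Namespace(
        save_model_path="models/seq2seq",  # checkpoints are written here
        preprocessed_path="data/aa",  # expects {train,test,validation}.pkl
        batch_size=64,
        device=None,  # falls back to cuda-if-available
        shuffle=True,
        hidden_dim=1024,
        num_layers=1,
        architecture="seq2seq",
        optimizer="sgd",
        lr=0.1,
        epochs=100,
        save_model_frequency=5,
    )
    logging.basicConfig(level=logging.INFO)
    train(example_args)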