in src/train.py [0:0]
def train_config(parser: BvhConverter, c: Train_Config, update: bool) -> None:
    """
    Train ``c.model`` for ``c.epochs`` epochs on the data provided by ``parser``.

    For every chunk yielded by ``create_training_data_config`` the data is
    scaled with ``c.x_scaler`` / ``c.y_scaler`` (both are created via
    ``c.scale`` from the first chunk seen if ``c.y_scaler`` is still None),
    optionally perturbed with uniform noise, split 80/20 into a training and a
    validation part and trained on in mini-batches. Validation loss and pose
    accuracy are reported once per chunk, after its last mini-batch.
    ``c.model.save()`` is called when all epochs are done.

    :param parser: a parser that implements "load_numpy"
    :param c: Definition of model, skeleton, training params and feature sets
    :param update: renew training features - if False, last features will be loaded from disk.
                   Only passed on as given during the first epoch; it is reset to False afterwards.
    :return: None
    """
    criterion = c.get_loss_fun()
    par_indices = c.skeleton.parent_idx_vector()
    bone_lengths_ = torch.from_numpy(get_bone_length_dataset(par_indices, parser)).to(c.device)
    bone_offsets = get_bone_offsets_dataset_normalized(parser).to(c.device)

    # Developer switch: set to True to visualise the prediction of every
    # mini-batch with vispy. Disabled for normal training runs.
    combine_out_finger_in_body = False

    for epoch in range(c.epochs):
        print("")
        for _X, _y in create_training_data_config(parser, c, update, save_all_to_mem=True, shuffle_all=True,
                                                   dataset_name=parser.name(), test_set=False):
            X = _X.to(c.device)
            y = _y.to(c.device)

            # Lazily create the scalers from the first chunk. Recurrent models
            # get sequences, so only the last frame of each sequence is used.
            if c.y_scaler is None:
                X_init = X[:, -1] if c.model.recurrent else X
                c.x_scaler = c.scale(X_init, group_size=3, idx_end=c.in_features.features_wo_occlusions())
                c.y_scaler = c.scale(y, group_size=3, idx_end=c.out_features.features_wo_occlusions())

            y = c.y_scaler.transform(y)
            X = c.x_scaler.transform_recurrent(X) if c.model.recurrent else c.x_scaler.transform(X)

            if c.random_noise > 0:
                # Uniform noise in [-random_noise / 2, random_noise / 2).
                # Added out-of-place so that tensors which may still be
                # referenced elsewhere (scaler output, in-memory data set)
                # are never modified.
                X = X + (torch.rand_like(X) * c.random_noise - c.random_noise * 0.5)
                y = y + (torch.rand_like(y) * c.random_noise - c.random_noise * 0.5)

            # 80 % training / 20 % validation split along the first dimension.
            X_train = X[:X.shape[0] * 8 // 10]
            y_train = y[:y.shape[0] * 8 // 10]
            X_val = X[X_train.shape[0]:]
            y_val = y[y_train.shape[0]:]

            n_train = X_train.shape[0]
            if n_train == 0:
                # Nothing to optimise on; avoids a division by zero / an empty
                # forward pass further down.
                print(f"\nskipping chunk with {X.shape[0]} sample(s): no training data after split")
                continue

            # Never ask for more rows per batch than there are training rows,
            # otherwise end_idx below becomes 0 and every batch is empty.
            batch_size = n_train if c.batch_size is None else min(c.batch_size, n_train)
            X_train, y_train = shuffle_X_y(X_train, y_train)
            batches = max(n_train // batch_size, 1)
            # Drop the remainder so that all batches have exactly batch_size rows.
            end_idx = n_train - (n_train % batch_size)

            for batch in range(batches):
                c.model.train()
                c.optimizer.zero_grad()

                # Strided selection: batch i takes rows i, i + batches, i + 2 * batches, ...
                X_train_batch = X_train[batch:end_idx:batches]
                y_train_batch = y_train[batch:end_idx:batches]

                # Optional augmentations, each enabled through its config value.
                if c.rotate_random > 0.0001:
                    X_train_batch, y_train_batch = rotate_features_random(X_train_batch, c, y_train_batch)
                if c.translate_random > 0.0001:
                    X_train_batch, y_train_batch = translate_features_random(X_train_batch, c, y_train_batch)
                if c.mask_noise is not None:
                    X_train_batch, y_train_batch = add_mask_noise(X_train_batch, c, y_train_batch)

                if c.model.recurrent:
                    c.model.init_hidden(batch_size)

                out = c.model(X_train_batch)

                # For recurrent models the masks are taken from the last frame of each sequence.
                X_current_frame = X_train_batch[:, -1] if c.model.recurrent else X_train_batch
                joint_mask = get_joint_mask(X_current_frame, c)
                float_mask = get_float_mask(X_current_frame, c)
                # Repeat the per-joint mask so that it spans all features in front of
                # the occlusion mask at the end of the feature vector.
                # NOTE(review): the repeat factor is derived from c.in_features although the
                # mask is applied to the model output - confirm this is intended.
                loss_mask = torch_tile(joint_mask, 1, c.in_features.features_wo_occlusions() // (joint_mask.shape[-1]))

                loss = criterion(input=out, target=y_train_batch, mask=loss_mask, config=c,
                                 bone_lengths=bone_lengths_, body=X_current_frame,
                                 bone_offsets=bone_offsets, float_mask=float_mask)
                loss.backward()
                c.optimizer.step()

                out = c.model.post_loss_activation(out)
                print(f"\rloss={loss.item()}", end="")

                if combine_out_finger_in_body:
                    pred_pos = c.out_features.features_to_pose().solve_batch(out.detach(), body=X_current_frame,
                                                                             mask=joint_mask,
                                                                             ref_positions=y_train_batch,
                                                                             bone_offsets=bone_offsets)
                    pred_pos = pred_pos.reshape((len(out), -1, 3))
                    # Imported lazily: only needed for this debug visualisation.
                    from plots import Vispy3DScatter
                    vispy = Vispy3DScatter()
                    vispy.plot_skeleton_with_bones(pred_pos.cpu().numpy()[:, :, [0, 2, 1]], c.skeleton,
                                                   fps=parser.dataset_fps / 30, speed=0.2)

            # Validate once per chunk, after the last mini-batch; `loss` is the
            # training loss of that last mini-batch.
            c.model.eval()
            print_val_loss(epoch, c, criterion, X_val[:c.max_batch_size], y_val[:c.max_batch_size], loss,
                           bone_lengths_, bone_offsets)
            eval_pose_accuracy(X_val, c, y_val, bone_lengths_, bone_offsets)

        # The flag is only passed on as given for the first epoch.
        update = False

    c.model.save()