in trainers/train_simplexes.py [0:0]
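import random

import numpy as np
import torch.nn as nn

# NOTE: `args` (parsed training flags such as device, n, beta, layerwise),
# `get_weight` (accessor for a module's i-th endpoint weights), and `utils`
# (checkpoint helpers) are assumed to be provided elsewhere in this repo.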
def train(models, writer, data_loader, optimizers, criterion, epoch):
    model = models[0]
    optimizer = optimizers[0]
    model.zero_grad()
    model.train()
    avg_loss = 0.0
    train_loader = data_loader.train_loader
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(args.device), target.to(args.device)
        # To sample from a simplex, sample from an exponential distribution,
        # then renormalize so the coordinates sum to 1.
        if args.layerwise:
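            # Layerwise: draw an independent simplex sample for each layer.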
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    Z = np.random.exponential(scale=1.0, size=args.n)
                    Z = Z / Z.sum()
                    for i in range(1, args.n):
                        setattr(m, f"t{i}", Z[i])
        else:
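            # Shared: draw one simplex sample and use it for every layer.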
            Z = np.random.exponential(scale=1.0, size=args.n)
            Z = Z / Z.sum()
            for m in model.modules():
                if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                    for i in range(1, args.n):
                        setattr(m, f"t{i}", Z[i])
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        if args.beta > 0:
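            # Diversity regularizer: penalize the squared cosine similarity
            # between the weights of two randomly chosen simplex endpoints.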
            i, j = random.sample(range(args.n), 2)
            num = 0.0
            normi = 0.0
            normj = 0.0
            for m in model.modules():
                if isinstance(m, nn.Conv2d):
                    vi = get_weight(m, i)
                    vj = get_weight(m, j)
                    num += (vi * vj).sum()
                    normi += vi.pow(2).sum()
                    normj += vj.pow(2).sum()
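            # cos^2(v_i, v_j) = <v_i, v_j>^2 / (|v_i|^2 * |v_j|^2)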
            loss += args.beta * (num.pow(2) / (normi * normj))
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
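        # Global iteration index, used for logging and checkpointing.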
        it = len(train_loader) * epoch + batch_idx
        if batch_idx % args.log_interval == 0:
            num_samples = batch_idx * len(data)
            dataset_size = len(train_loader.dataset)
            percent_complete = 100.0 * batch_idx / len(train_loader)
            print(
                f"Train Epoch: {epoch} [{num_samples}/{dataset_size} ({percent_complete:.0f}%)]\t"
                f"Loss: {loss.item():.6f}"
            )
            if args.save:
                writer.add_scalar("train/loss", loss.item(), it)
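        # Save a checkpoint whenever the current iteration is in args.save_iters.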
        if args.save and it in args.save_iters:
            utils.save_cpt(epoch, it, models, optimizers, -1, -1)
    avg_loss = avg_loss / len(train_loader)
    return avg_loss, optimizers
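# For reference, a minimal sketch of the `get_weight` helper this loop assumes;
# the real accessor lives elsewhere in the repo, and the attribute naming here
# (`weight`, `weight1`, ...) is an illustrative assumption, not the actual API:
#
#     def get_weight(m, i):
#         # Weight tensor of the i-th simplex endpoint stored on module m.
#         return m.weight if i == 0 else getattr(m, f"weight{i}")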