in trainers/train_one_dim_subspaces.py [0:0]
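# Standard imports this excerpt relies on (a sketch of the module context; `args`,
# `utils`, and `get_weight` are project-specific names provided elsewhere in the repo).
import random

import numpy as np
import torch.nn as nn
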
def train(models, writer, data_loader, optimizers, criterion, epoch):
    # We consider only a single model here. Multiple models are for ensembles and SWA baselines.
    model = models[0]
    optimizer = optimizers[0]
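    # When more than one subspace sample is drawn per batch (the appendix-B branch below),
    # ask the model to also return its features so the feature cosine penalty can be computed.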
    if args.num_samples > 1:
        model.apply(lambda m: setattr(m, "return_feats", True))
    model.zero_grad()
    model.train()
    avg_loss = 0.0
    train_loader = data_loader.train_loader
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(args.device), target.to(args.device)
        # num_samples is the number of samples to draw from the subspace per batch.
        # In all experiments in the main paper it is 1.
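        # Each sampled alpha in [0, 1] selects a point in the learned subspace (for lines,
        # a convex combination of two endpoint weights); the subspace layers read the
        # `alpha` attribute set below when building their effective weight.
        # A minimal, hedged sketch of such a layer is given after this function.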
        if args.num_samples == 1:
            if args.layerwise:
                # Draw an independent alpha for every Conv/BatchNorm module.
                for m in model.modules():
                    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                        alpha = np.random.uniform(0, 1)
                        setattr(m, "alpha", alpha)
            else:
                # Draw a single alpha shared by all Conv/BatchNorm modules.
                alpha = np.random.uniform(0, 1)
                for m in model.modules():
                    if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                        setattr(m, "alpha", alpha)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
        else:
            # Feel free to ignore this part, as it is often not used.
            # It corresponds to section B of the appendix, where multiple samples from the
            # subspace are taken for each batch.
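            # The batch is split into num_samples equal chunks; each chunk is forwarded
            # through a different sampled point of the subspace and the chunk losses are averaged.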
            div = data.size(0) // args.num_samples
            feats = []
            ts = []
            optimizer.zero_grad()
            for sample in range(args.num_samples):
                if args.layerwise:
                    for m in model.modules():
                        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                            alpha = np.random.uniform(0, 1)
                            setattr(m, "alpha", alpha)
                else:
                    alpha = np.random.uniform(0, 1)
                    for m in model.modules():
                        if isinstance(m, (nn.Conv2d, nn.BatchNorm2d)):
                            setattr(m, "alpha", alpha)
                # return_feats is True in this branch, so the model also returns features.
                output, f = model(data[sample * div : (sample + 1) * div])
                feats.append(f)
                # Record this sample's position in the subspace (with layerwise sampling this
                # is only the last drawn alpha); ts is consumed by the cosine penalty below.
                ts.append(alpha)
                if sample == 0:
                    loss = (
                        criterion(output, target[sample * div : (sample + 1) * div])
                        / args.num_samples
                    )
                else:
                    loss += (
                        criterion(output, target[sample * div : (sample + 1) * div])
                        / args.num_samples
                    )
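            # Feature penalty (appendix): for two randomly chosen subspace samples, penalize
            # the squared cosine similarity of their features, scaled by how far apart the
            # samples are on the line (|t_i - t_j|).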
            if args.lamb > 0:
                i, j = random.sample(range(args.num_samples), 2)
                fi, fj = feats[i], feats[j]
                ti, tj = ts[i], ts[j]
                loss += (
                    args.fcos_weight
                    * abs(ti - tj)
                    * ((fi * fj).sum().pow(2) / (fi.pow(2).sum() * fj.pow(2).sum()))
                )
        # Application of the regularization term, equation 3.
        num_points = 2 if args.conv_type == "LinesConv" else 3
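        # The term below is beta * <v_i, v_j>^2 / (||v_i||^2 * ||v_j||^2): the squared cosine
        # similarity between two of the weight vectors that parameterize the subspace, with
        # all Conv2d weights treated as a single flattened vector.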
        if args.beta > 0:
            i, j = random.sample(range(num_points), 2)
            num = 0.0
            normi = 0.0
            normj = 0.0
            for m in model.modules():
                if isinstance(m, nn.Conv2d):
                    vi = get_weight(m, i)
                    vj = get_weight(m, j)
                    num += (vi * vj).sum()
                    normi += vi.pow(2).sum()
                    normj += vj.pow(2).sum()
            loss += args.beta * (num.pow(2) / (normi * normj))
        loss.backward()
        optimizer.step()
        avg_loss += loss.item()
        it = len(train_loader) * epoch + batch_idx
        if batch_idx % args.log_interval == 0:
            seen_samples = batch_idx * len(data)
            total_samples = len(train_loader.dataset)
            percent_complete = 100.0 * batch_idx / len(train_loader)
            print(
                f"Train Epoch: {epoch} [{seen_samples}/{total_samples} ({percent_complete:.0f}%)]\t"
                f"Loss: {loss.item():.6f}"
            )
            if args.save:
                writer.add_scalar("train/loss", loss.item(), it)
        if args.save and it in args.save_iters:
            utils.save_cpt(epoch, it, models, optimizers, -1, -1)
    model.apply(lambda m: setattr(m, "return_feats", False))
    avg_loss = avg_loss / len(train_loader)
    return avg_loss, optimizers
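

# A minimal sketch (not the repo's actual layer) of how a Lines-style Conv2d could consume
# the per-module `alpha` attribute set in train(): the effective weight is a convex
# combination of two learned endpoint weights, w(alpha) = (1 - alpha) * w0 + alpha * w1.
# All names below are illustrative assumptions, not the project's API.
import torch
import torch.nn.functional as F


class SketchLinesConv2d(nn.Conv2d):
    def __init__(self, *conv_args, **conv_kwargs):
        super().__init__(*conv_args, **conv_kwargs)
        # Second endpoint of the line; self.weight serves as the first endpoint.
        self.weight1 = nn.Parameter(self.weight.detach().clone())
        self.alpha = 0.5  # overwritten each batch by train()

    def forward(self, x):
        w = (1 - self.alpha) * self.weight + self.alpha * self.weight1
        return F.conv2d(
            x, w, self.bias, self.stride, self.padding, self.dilation, self.groups
        )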