# ssl/hltm/simCLR_hltm.py
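# Imports below are reconstructed from the calls in main() and are not part of
# the original excerpt. The project-local names (common_utils, Latent,
# LatentIterator, SimpleDataset, Model, batch2dict) are assumed to be defined
# elsewhere in this repo; their import lines are omitted here.
import json
import logging
import os

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm

log = logging.getLogger(__file__)
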
def main(args):
    # torch.backends.cudnn.benchmark = True
    log.info(f"Working dir: {os.getcwd()}")
    log.info("\n" + common_utils.get_git_hash())
    log.info("\n" + common_utils.get_git_diffs())

    common_utils.set_all_seeds(args.seed)
    log.info(args.pretty())

    # Setup latent variables.
    root = Latent("s", args)
    iterator = LatentIterator(root)

    dataset = SimpleDataset(iterator, args.N)
    bs = args.batchsize
    loader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=4)

    load_saved = False
    need_train = True

    if load_saved:
        model = torch.load(args.save_file)
    else:
        model = Model(iterator, args.hid, args.eps)
    model.cuda()
    if need_train:
        loss_func = nn.MSELoss().cuda()
        optimizer = optim.SGD(model.parameters(), lr=args.lr)

        stats = common_utils.MultiCounter("./")
        stats_corrs = {parent.name: common_utils.StatsCorr() for parent in iterator.top_down()}

        for i in range(args.num_epoch):
            for _, v in stats_corrs.items():
                v.reset()

            connections = dict()
            n_samples = [0, 0]
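            # Evaluation pass (no gradients): measure how well each latent's
            # hidden activation correlates with the ground-truth latent, and
            # accumulate label-conditioned "connection" statistics (outer
            # product of each sample's input with its Jacobian) per latent.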
            with torch.no_grad():
                for batch in tqdm.tqdm(loader, total=int(len(loader))):
                    d = batch2dict(batch)
                    label = d["x_label"]
                    f, d_hs, d_inputs = model(d["x"])
                    Js = model.computeJ(d_hs)

                    # Correlation.
                    # batch_size * K
                    d_gt = dataset.split_generated(d["x_all"])
                    for v in iterator.top_down():
                        name = v.name
                        stats_corrs[name].add(d_gt[name].unsqueeze(1).float(), d_hs[name])

                        J = Js[name]
                        inputs = d_inputs[name].detach()
                        conn = torch.einsum("ia,ibc->iabc", inputs, J)
                        conn = conn.view(conn.size(0), conn.size(1) * conn.size(2), conn.size(3))

                        # group by labels.
                        conn0 = conn[label == 0, :, :].sum(dim=0)
                        conn1 = conn[label == 1, :, :].sum(dim=0)
                        conns = torch.stack([conn0, conn1])

                        # Accumulate connection.
                        if name in connections:
                            connections[name] += conns
                        else:
                            connections[name] = conns

                    for j in range(2):
                        n_samples[j] += (label == j).sum().item()
            json_result = dict(epoch=i)

            for name in connections.keys():
                conns = connections[name]
                n_total_sample = n_samples[0] + n_samples[1]
                avg_conn = conns.sum(dim=0) / n_total_sample

                cov_op = torch.zeros(avg_conn.size(0), avg_conn.size(0)).to(avg_conn.device)
                for j in range(2):
                    conns[j, :, :] /= n_samples[j]
                    diff = conns[j, :, :] - avg_conn
                    cov_op += diff @ diff.t() * n_samples[j] / n_total_sample

                dd = cov_op.size(0)
                json_result["conn_" + name] = dict(size=dd, norm=cov_op.norm().item())
                json_result["weight_norm_" + name] = model.nets[name].weight.norm().item()
            layer_avgs = [[0, 0] for _ in range(args.depth + 1)]
            for p in iterator.top_down():
                corr = stats_corrs[p.name].get()["corr"]
                # Note that we need to take absolute value (since -1/+1 are both good)
                res = common_utils.corr_summary(corr.abs())
                best = res["best_corr"].item()
                json_result["best_corr_" + p.name] = best
                layer_avgs[p.depth][0] += best
                layer_avgs[p.depth][1] += 1

            # Check average correlation for each layer
            # log.info("CovOp norm at every location:")
            for d, (sum_corr, n) in enumerate(layer_avgs):
                if n > 0:
                    log.info(f"[{d}] Mean of the best corr: {sum_corr / n:.3f} [{n}]")

            log.info(f"json_str: {json.dumps(json_result)}")
            # Training
            stats.reset()
            for batch in tqdm.tqdm(loader, total=int(len(loader))):
                optimizer.zero_grad()
                d = batch2dict(batch)

                f, _, _ = model(d["x"])
                f_pos, _, _ = model(d["x_pos"])
                f_neg, _, _ = model(d["x_neg"])

                pos_loss = loss_func(f, f_pos)
                neg_loss = loss_func(f, f_neg)

                # import pdb
                # pdb.set_trace()
                if args.loss == "nce":
                    loss = -(-pos_loss / args.temp).exp() / ((-pos_loss / args.temp).exp() + (-neg_loss / args.temp).exp())
                elif args.loss == "subtract":
                    loss = pos_loss - neg_loss
                else:
                    raise NotImplementedError(f"{args.loss} is unknown")
                # loss = pos_loss.exp() / ( pos_loss.exp() + neg_loss.exp())
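                # The "nce" branch above minimizes the negative softmax weight
                # (temperature args.temp) assigned to the positive pair relative
                # to the negative pair, using MSE as the distance; note it uses
                # -p rather than the usual -log p. The "subtract" branch simply
                # minimizes pos_loss - neg_loss.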
stats["train_loss"].feed(loss.detach().item())
loss.backward()
optimizer.step()
log.info("\n" + stats.summary(i))
            '''
            measures = generator.check(model.linear1.weight.detach())
            for k, v in measures.items():
                for vv in v:
                    stats["stats_" + k].feed(vv)
            '''

            # log.info(f"\n{best_corrs}\n")

    torch.save(model, args.save_file)
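
# --- Usage sketch (not part of the original excerpt) ---
# The call to args.pretty() above suggests the config object is an OmegaConf
# DictConfig, as produced by Hydra. A hypothetical entry point could look like
# the following; the config path/name below are assumptions for illustration.
#
# if __name__ == "__main__":
#     import hydra
#
#     @hydra.main(config_path="conf", config_name="simCLR_hltm")
#     def run(args):
#         main(args)
#
#     run()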