def main()

in ssl/hltm/simCLR_hltm.py [0:0]
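
The listing below is not self-contained: it relies on module-level imports and helpers defined elsewhere in the repository. A minimal sketch of what main() assumes is in scope (the standard-library and PyTorch imports follow directly from usage; the repo-local names common_utils, Latent, LatentIterator, SimpleDataset, Model and batch2dict are inferred from how they are called, not copied from the actual import block):

import os
import json
import logging

import torch
import torch.nn as nn
import torch.optim as optim
import tqdm

# Repo-local helpers assumed by main(); their real module paths may differ:
# common_utils, Latent, LatentIterator, SimpleDataset, Model, batch2dict

log = logging.getLogger(__file__)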


def main(args):
    # torch.backends.cudnn.benchmark = True
    log.info(f"Working dir: {os.getcwd()}")
    log.info("\n" + common_utils.get_git_hash())
    log.info("\n" + common_utils.get_git_diffs())

    common_utils.set_all_seeds(args.seed)
    log.info(args.pretty())

    # Set up the latent variable hierarchy and the dataset generated from it.
    root = Latent("s", args)
    iterator = LatentIterator(root)
    dataset = SimpleDataset(iterator, args.N)

    bs = args.batchsize
    loader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=4)

    load_saved = False
    need_train = True

    if load_saved:
        model = torch.load(args.save_file)
    else:
        model = Model(iterator, args.hid, args.eps)
    model.cuda()

    if need_train:
        loss_func = nn.MSELoss().cuda()
        optimizer = optim.SGD(model.parameters(), lr=args.lr)

        stats = common_utils.MultiCounter("./")
        stats_corrs = { parent.name : common_utils.StatsCorr() for parent in iterator.top_down() }
        
        for i in range(args.num_epoch):
            for _, v in stats_corrs.items():
                v.reset()

            connections = dict()
            n_samples = [0, 0]

            # Diagnostic pass (no gradient updates): track the correlation between each
            # ground-truth latent and its hidden activations, and accumulate per-label
            # connection statistics built from the inputs and the outputs of model.computeJ.
            with torch.no_grad():
                for batch in tqdm.tqdm(loader, total=len(loader)):
                    d = batch2dict(batch)
                    label = d["x_label"]
                    f, d_hs, d_inputs = model(d["x"])
                    Js = model.computeJ(d_hs)

                    # Correlation between ground-truth latents and hidden activations.
                    # Ground truth is [batch_size, 1] after unsqueeze; activations are batch_size * K.
                    d_gt = dataset.split_generated(d["x_all"])
                    for v in iterator.top_down():
                        name = v.name
                        stats_corrs[name].add(d_gt[name].unsqueeze(1).float(), d_hs[name])

                        J = Js[name]
                        inputs = d_inputs[name].detach()
                        # Per-sample outer product of the input with J, flattened so that
                        # each sample contributes a [input_dim * J_rows, J_cols] matrix.
                        conn = torch.einsum("ia,ibc->iabc", inputs, J)
                        conn = conn.view(conn.size(0), conn.size(1) * conn.size(2), conn.size(3))
                        # Group by binary label and sum within each group.
                        conn0 = conn[label == 0, :, :].sum(dim=0)
                        conn1 = conn[label == 1, :, :].sum(dim=0)
                        conns = torch.stack([conn0, conn1])

                        # Accumulate connections across batches.
                        if name in connections:
                            connections[name] += conns   
                        else:
                            connections[name] = conns

                    for j in range(2):
                        n_samples[j] += (label == j).sum().item()
            
            json_result = dict(epoch=i)
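            # For each latent, summarize the accumulated connections with a weighted
            # between-class covariance: compare each label-conditioned mean connection
            # against the overall mean and report the norm of the resulting operator.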
            for name in connections.keys():
                conns = connections[name]
                n_total_sample = n_samples[0] + n_samples[1]
                avg_conn = conns.sum(dim=0) / n_total_sample

                cov_op = torch.zeros(avg_conn.size(0), avg_conn.size(0)).to(avg_conn.device)

                for j in range(2):
                    conns[j,:,:] /= n_samples[j]
                    diff = conns[j,:,:] - avg_conn
                    cov_op += diff @ diff.t() * n_samples[j] / n_total_sample

                dd = cov_op.size(0)
                json_result["conn_" + name] = dict(size=dd, norm=cov_op.norm().item())
                json_result["weight_norm_" + name] = model.nets[name].weight.norm().item()

            layer_avgs = [[0, 0] for _ in range(args.depth + 1)]
            for p in iterator.top_down():
                corr = stats_corrs[p.name].get()["corr"]
                # Note that we need to take absolute value (since -1/+1 are both good)
                res = common_utils.corr_summary(corr.abs())
                best = res["best_corr"].item()
                json_result["best_corr_" + p.name] = best 

                layer_avgs[p.depth][0] += best 
                layer_avgs[p.depth][1] += 1 

            # Check average correlation for each layer
            for depth, (sum_corr, n) in enumerate(layer_avgs):
                if n > 0:
                    log.info(f"[{depth}] Mean of the best corr: {sum_corr/n:.3f} [{n}]")

            log.info(f"json_str: {json.dumps(json_result)}")

            # Training pass: contrastive update from positive/negative pairs.
            stats.reset()
            for batch in tqdm.tqdm(loader, total=len(loader)):
                optimizer.zero_grad()

                d = batch2dict(batch)

                f, _, _ = model(d["x"])
                f_pos, _, _ = model(d["x_pos"])
                f_neg, _, _ = model(d["x_neg"])

                pos_loss = loss_func(f, f_pos)
                neg_loss = loss_func(f, f_neg)

                if args.loss == "nce":
                    # NCE-style objective with a single negative: minimize the negative
                    # softmax weight of the positive pair at temperature args.temp.
                    pos_score = (-pos_loss / args.temp).exp()
                    neg_score = (-neg_loss / args.temp).exp()
                    loss = -pos_score / (pos_score + neg_score)
                elif args.loss == "subtract":
                    loss = pos_loss - neg_loss
                else:
                    raise NotImplementedError(f"{args.loss} is unknown")

                stats["train_loss"].feed(loss.detach().item())

                loss.backward()
                optimizer.step()

            log.info("\n" + stats.summary(i))

    torch.save(model, args.save_file)
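
How main() is invoked is not shown here; args.pretty() and the module-level log suggest a Hydra/OmegaConf-style config object. A hypothetical entry point, sketched under that assumption (the config file name is a placeholder, not taken from the source):

# Hypothetical wiring; the actual script may set up Hydra differently.
import hydra

@hydra.main(config_path="config.yaml")
def hydra_main(args):
    main(args)

if __name__ == "__main__":
    hydra_main()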