def main()

in student_specialization/recon_multilayer.py

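As read from the code below: main() sets the random seeds, builds the datasets, constructs a teacher network (fully connected or convolutional) or loads a pretrained, pruned one, constructs or loads a student network, runs args.num_trial training trials via optimize() while collecting teacher/student correlation and loss statistics, and saves the per-trial statistics to stats.pickle.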

def main(args):
    cmd_line = " ".join(sys.argv)
    log.info(f"{cmd_line}")
    log.info(f"Working dir: {os.getcwd()}")
    set_all_seeds(args.seed)

    ks = args.ks
    lrs = parse_lr(args.lr)

    if args.perturb is not None or args.same_dir or args.same_sign:
        args.node_multi = 1

    if args.load_student is not None:
        args.num_trial = 1

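    # Build the train/eval datasets; d holds the input shape (d[0] is the flat input
    # dimension for the fully connected model) and d_output the output dimension.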
    d, d_output, train_dataset, eval_dataset = init_dataset(args)

    if args.total_bp_iters > 0 and isinstance(train_dataset, RandomDataset):
        args.num_epoch = args.total_bp_iters / args.random_dataset_size
        if args.num_epoch != int(args.num_epoch):
            raise RuntimeError(f"random_dataset_size [{args.random_dataset_size}] cannot divide total_bp_iters [{args.total_bp_iters}]")

        args.num_epoch = int(args.num_epoch)
        log.info(f"#Epoch is now set to {args.num_epoch}")

    # ks = [5, 6, 7, 8]
    # ks = [10, 15, 20, 25]
    # ks = [50, 75, 100, 125]

    log.info(args.pretty())
    log.info(f"ks: {ks}")
    log.info(f"lr: {lrs}")

    if args.d_output > 0:
        d_output = args.d_output 

    log.info(f"d_output: {d_output}") 

    if not args.use_cnn:
        teacher = Model(d[0], ks, d_output,
                        has_bias=not args.no_bias,
                        has_bn=args.teacher_bn,
                        has_bn_affine=args.teacher_bn_affine,
                        bn_before_relu=args.bn_before_relu,
                        leaky_relu=args.leaky_relu).cuda()
    else:
        teacher = ModelConv(d, ks, d_output,
                            has_bn=args.teacher_bn,
                            bn_before_relu=args.bn_before_relu,
                            leaky_relu=args.leaky_relu).cuda()

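    # Either load a pretrained (possibly pruned) teacher and recover its active nodes,
    # or initialize a fresh teacher.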
    if args.load_teacher is not None:
        log.info("Loading teacher from: " + args.load_teacher)
        checkpoint = torch.load(args.load_teacher)
        teacher.load_state_dict(checkpoint['net'])

        if "inactive_nodes" in checkpoint: 
            inactive_nodes = checkpoint["inactive_nodes"]
            masks = checkpoint["masks"]
            ratios = checkpoint["ratios"]
            inactive_nodes2, masks2 = prune(teacher, ratios)

            for m, m2 in zip(masks, masks2):
                if (m - m2).norm() > 1e-3:
                    print(m)
                    print(m2)
                    raise RuntimeError("New mask is not the same as old mask")

            for inactive, inactive2 in zip(inactive_nodes, inactive_nodes2):
                if set(inactive) != set(inactive2):
                    raise RuntimeError("New inactive set is not the same as old inactive set")

            # Make sure the last layer is normalized. 
            # teacher.normalize_last()
            # teacher.final_w.weight.data /= 3
            # teacher.final_w.bias.data /= 3
            active_nodes = [ [ kk for kk in range(k) if kk not in a ] for a, k in zip(inactive_nodes, ks) ]
            active_ks = [ len(a) for a in active_nodes ]
        else:
            active_nodes = None
            active_ks = ks
        
    else:
        log.info("Init teacher..")
        teacher.init_w(use_sep = not args.no_sep, weight_choices=list(args.weight_choices))
        if args.teacher_strength_decay > 0: 
            # Prioritize teacher node.
            teacher.prioritize(args.teacher_strength_decay)
        
        teacher.normalize()
        log.info("Teacher weights initiailzed randomly...")
        active_nodes = None
        active_ks = ks

    log.info(f"Active ks: {active_ks}")

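    # Build the student (unless a saved student is loaded inside the trial loop below).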
    if args.load_student is None:
        if not args.use_cnn:
            student = Model(d[0], active_ks, d_output,
                            multi=args.node_multi,
                            has_bias=not args.no_bias,
                            has_bn=args.bn,
                            has_bn_affine=args.bn_affine,
                            bn_before_relu=args.bn_before_relu).cuda()
        else:
            student = ModelConv(d, active_ks, d_output,
                                multi=args.node_multi,
                                has_bn=args.bn,
                                bn_before_relu=args.bn_before_relu).cuda()

        # The student can start with a smaller norm.
        student.scale(args.student_scale_down)

    # Specify some teacher structure.
    '''
    teacher.w0.weight.data.zero_()
    span = d // ks[0]
    for i in range(ks[0]):
        teacher.w0.weight.data[i, span*i:span*i+span] = 1
    '''

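    # Data loaders for training and evaluation.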
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, num_workers=4)
    eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.eval_batchsize, shuffle=True, num_workers=4)

    if args.teacher_bias_tune:
        teacher_tune.tune_teacher(eval_loader, teacher)
    if args.teacher_bias_last_layer_tune:
        teacher_tune.tune_teacher_last_layer(eval_loader, teacher)

    # teacher.w0.bias.data.uniform_(-1, 0)
    # teacher.init_orth()

    # init_w(teacher.w0)
    # init_w(teacher.w1)
    # init_w(teacher.w2)

    # init_w2(teacher.w0, multiplier=args.init_multi)
    # init_w2(teacher.w1, multiplier=args.init_multi)
    # init_w2(teacher.w2, multiplier=args.init_multi)

    all_all_corrs = []

    log.info("=== Start ===")
    std = args.data_std

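    # Statistics collectors comparing teacher and student (activation correlations, losses, gradients).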
    stats_op = stats_operator.StatsCollector(teacher, student)

    # Compute correlation between teacher and student activations.
    stats_op.add_stat(stats_operator.StatsCorr, active_nodes=active_nodes, cnt_thres=0.9)

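    # Training loss: cross-entropy (targets arrive one-hot) or mean squared error.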
    if args.cross_entropy:
        stats_op.add_stat(stats_operator.StatsCELoss)

        loss = nn.CrossEntropyLoss().cuda()
        def loss_func(predicted, target):
            _, target_y = target.max(1)
            return loss(predicted, target_y)

    else:
        stats_op.add_stat(stats_operator.StatsL2Loss)
        loss_func = nn.MSELoss().cuda()

    # Duplicate the stats collector: one copy for training, one for evaluation.
    eval_stats_op = deepcopy(stats_op)
    stats_op.label = "train"
    eval_stats_op.label = "eval"

    stats_op.add_stat(stats_operator.StatsGrad)
    stats_op.add_stat(stats_operator.StatsMemory)

    if args.stats_H:
        eval_stats_op.add_stat(stats_operator.StatsHs)

    # pickle.dump(model2numpy(teacher), open("weights_gt.pickle", "wb"), protocol=2)

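    # Run args.num_trial independent trials, re-initializing (or re-loading) the student each time.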
    all_stats = []
    for i in range(args.num_trial):
        if args.load_student is None:
            log.info("=== Trial %d, std = %f ===" % (i, std))
            student.reset_parameters()
            # student = copy.deepcopy(student_clone)
            # student.set_teacher_sign(teacher, scale=1)
            if args.perturb is not None:
                student.set_teacher(teacher, args.perturb)
            if args.same_dir:
                student.set_teacher_dir(teacher)
            if args.same_sign:
                student.set_teacher_sign(teacher)

        else:
            log.info(f"Loading student {args.load_student}")
            student = torch.load(args.load_student)

        # init_corrs[-1] = predict_last_order(student, teacher, args)
        # alter_last_layer = predict_last_order(student, teacher, args)

        # import pdb
        # pdb.set_trace()

        stats = optimize(train_loader, eval_loader, teacher, student, loss_func, stats_op, eval_stats_op, args, lrs)
        all_stats.append(stats)

    torch.save(all_stats, "stats.pickle")

    # log.info("Student network")
    # log.info(student.w1.weight)
    # log.info("Teacher network")
    # log.info(teacher.w1.weight)
    log.info(f"Working dir: {os.getcwd()}")