in student_specialization/recon_multilayer.py [0:0]
def main(args):
    cmd_line = " ".join(sys.argv)
    log.info(cmd_line)
log.info(f"Working dir: {os.getcwd()}")
set_all_seeds(args.seed)
ks = args.ks
lrs = parse_lr(args.lr)
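    # Teacher-aligned initializations (perturb / same_dir / same_sign) assume a one-to-one
    # correspondence between student and teacher nodes, so disable node multiplicity.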
if args.perturb is not None or args.same_dir or args.same_sign:
args.node_multi = 1
if args.load_student is not None:
args.num_trial = 1
d, d_output, train_dataset, eval_dataset = init_dataset(args)
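    # If a total back-propagation iteration budget is given, convert it into an epoch count
    # over the fixed-size random dataset.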
if args.total_bp_iters > 0 and isinstance(train_dataset, RandomDataset):
args.num_epoch = args.total_bp_iters / args.random_dataset_size
if args.num_epoch != int(args.num_epoch):
raise RuntimeError(f"random_dataset_size [{args.random_dataset_size}] cannot devide total_bp_iters [{args.total_bp_iters}]")
args.num_epoch = int(args.num_epoch)
log.info(f"#Epoch is now set to {args.num_epoch}")
# ks = [5, 6, 7, 8]
# ks = [10, 15, 20, 25]
# ks = [50, 75, 100, 125]
log.info(args.pretty())
log.info(f"ks: {ks}")
log.info(f"lr: {lrs}")
if args.d_output > 0:
d_output = args.d_output
log.info(f"d_output: {d_output}")
if not args.use_cnn:
teacher = Model(d[0], ks, d_output,
has_bias=not args.no_bias, has_bn=args.teacher_bn, has_bn_affine=args.teacher_bn_affine, bn_before_relu=args.bn_before_relu, leaky_relu=args.leaky_relu).cuda()
else:
teacher = ModelConv(d, ks, d_output, has_bn=args.teacher_bn, bn_before_relu=args.bn_before_relu, leaky_relu=args.leaky_relu).cuda()
if args.load_teacher is not None:
log.info("Loading teacher from: " + args.load_teacher)
checkpoint = torch.load(args.load_teacher)
teacher.load_state_dict(checkpoint['net'])
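        # If the teacher checkpoint was pruned, re-run pruning and verify that it reproduces
        # the saved masks and inactive node sets.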
if "inactive_nodes" in checkpoint:
inactive_nodes = checkpoint["inactive_nodes"]
masks = checkpoint["masks"]
ratios = checkpoint["ratios"]
inactive_nodes2, masks2 = prune(teacher, ratios)
for m, m2 in zip(masks, masks2):
if (m - m2).norm() > 1e-3:
print(m)
print(m2)
raise RuntimeError("New mask is not the same as old mask")
for inactive, inactive2 in zip(inactive_nodes, inactive_nodes2):
if set(inactive) != set(inactive2):
raise RuntimeError("New inactive set is not the same as old inactive set")
# Make sure the last layer is normalized.
# teacher.normalize_last()
# teacher.final_w.weight.data /= 3
# teacher.final_w.bias.data /= 3
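            # Keep only the unpruned (active) teacher nodes; the student is sized to match them.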
active_nodes = [ [ kk for kk in range(k) if kk not in a ] for a, k in zip(inactive_nodes, ks) ]
active_ks = [ len(a) for a in active_nodes ]
else:
active_nodes = None
active_ks = ks
else:
log.info("Init teacher..")
        teacher.init_w(use_sep=not args.no_sep, weight_choices=list(args.weight_choices))
if args.teacher_strength_decay > 0:
# Prioritize teacher node.
teacher.prioritize(args.teacher_strength_decay)
teacher.normalize()
log.info("Teacher weights initiailzed randomly...")
active_nodes = None
active_ks = ks
log.info(f"Active ks: {active_ks}")
if args.load_student is None:
if not args.use_cnn:
student = Model(d[0], active_ks, d_output,
multi=args.node_multi,
has_bias=not args.no_bias, has_bn=args.bn, has_bn_affine=args.bn_affine, bn_before_relu=args.bn_before_relu).cuda()
else:
student = ModelConv(d, active_ks, d_output, multi=args.node_multi, has_bn=args.bn, bn_before_relu=args.bn_before_relu).cuda()
# student can start with smaller norm.
student.scale(args.student_scale_down)
# Specify some teacher structure.
    # teacher.w0.weight.data.zero_()
    # span = d[0] // ks[0]
    # for i in range(ks[0]):
    #     teacher.w0.weight.data[i, span*i:span*i+span] = 1
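    # Data loaders for training and evaluation.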
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batchsize, shuffle=True, num_workers=4)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=args.eval_batchsize, shuffle=True, num_workers=4)
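    # Optionally tune the teacher's biases (or only its last layer) on the evaluation data.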
if args.teacher_bias_tune:
teacher_tune.tune_teacher(eval_loader, teacher)
if args.teacher_bias_last_layer_tune:
teacher_tune.tune_teacher_last_layer(eval_loader, teacher)
# teacher.w0.bias.data.uniform_(-1, 0)
# teacher.init_orth()
# init_w(teacher.w0)
# init_w(teacher.w1)
# init_w(teacher.w2)
# init_w2(teacher.w0, multiplier=args.init_multi)
# init_w2(teacher.w1, multiplier=args.init_multi)
# init_w2(teacher.w2, multiplier=args.init_multi)
all_all_corrs = []
log.info("=== Start ===")
std = args.data_std
stats_op = stats_operator.StatsCollector(teacher, student)
# Compute Correlation between teacher and student activations.
stats_op.add_stat(stats_operator.StatsCorr, active_nodes=active_nodes, cnt_thres=0.9)
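    # Choose the training loss: cross entropy against the argmax of the target, or plain MSE regression.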
if args.cross_entropy:
stats_op.add_stat(stats_operator.StatsCELoss)
loss = nn.CrossEntropyLoss().cuda()
def loss_func(predicted, target):
_, target_y = target.max(1)
return loss(predicted, target_y)
else:
stats_op.add_stat(stats_operator.StatsL2Loss)
loss_func = nn.MSELoss().cuda()
    # Keep separate stats collectors for the training and evaluation passes.
eval_stats_op = deepcopy(stats_op)
stats_op.label = "train"
eval_stats_op.label = "eval"
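    # Gradient and memory statistics are tracked only on the training side.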
stats_op.add_stat(stats_operator.StatsGrad)
stats_op.add_stat(stats_operator.StatsMemory)
if args.stats_H:
eval_stats_op.add_stat(stats_operator.StatsHs)
# pickle.dump(model2numpy(teacher), open("weights_gt.pickle", "wb"), protocol=2)
all_stats = []
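    # Run independent trials, re-initializing the student each time unless a trained student is loaded.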
for i in range(args.num_trial):
if args.load_student is None:
log.info("=== Trial %d, std = %f ===" % (i, std))
student.reset_parameters()
# student = copy.deepcopy(student_clone)
# student.set_teacher_sign(teacher, scale=1)
if args.perturb is not None:
student.set_teacher(teacher, args.perturb)
if args.same_dir:
student.set_teacher_dir(teacher)
if args.same_sign:
student.set_teacher_sign(teacher)
else:
log.info(f"Loading student {args.load_student}")
student = torch.load(args.load_student)
# init_corrs[-1] = predict_last_order(student, teacher, args)
# alter_last_layer = predict_last_order(student, teacher, args)
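        # Train the student against the teacher and collect per-trial statistics.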
stats = optimize(train_loader, eval_loader, teacher, student, loss_func, stats_op, eval_stats_op, args, lrs)
all_stats.append(stats)
torch.save(all_stats, "stats.pickle")
# log.info("Student network")
# log.info(student.w1.weight)
# log.info("Teacher network")
# log.info(teacher.w1.weight)
log.info(f"Working dir: {os.getcwd()}")