anli/src/nli/train_with_scramble.py [975:1036]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            )
            # print(debug_node_info(args), loss, logits, batch['uid'])
            # print(debug_node_info(args), loss, batch['uid'])

            # Gradient accumulation: scale the loss down so the summed gradients
            # average over the accumulated micro-batches.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # Backward pass; with fp16, scale the loss through apex amp so the
            # half-precision gradients do not underflow.
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Update parameters once every gradient_accumulation_steps forward steps
            if (forward_step + 1) % args.gradient_accumulation_steps == 0:
                # Gradient clipping: only applied when max_grad_norm > 0
                if args.max_grad_norm > 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), args.max_grad_norm
                        )

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
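                # Clear the accumulated gradients before the next accumulation window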
                model.zero_grad()

                global_step += 1

                if (
                    args.global_rank in [-1, 0]
                    and args.eval_frequency > 0
                    and global_step % args.eval_frequency == 0
                ):
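                    # Periodic evaluation runs only on the main process (global_rank -1 or 0)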
                    r_dict = dict()
                    # Eval loop:
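                    # eval_data_name, eval_data_list and eval_data_loaders are parallel
                    # lists, one entry per evaluation set.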
                    for i in range(len(eval_data_name)):
                        cur_eval_data_name = eval_data_name[i]
                        cur_eval_data_list = eval_data_list[i]
                        cur_eval_dataloader = eval_data_loaders[i]
                        # cur_eval_raw_data_list = eval_raw_data_list[i]

                        evaluation_dataset(
                            args,
                            cur_eval_dataloader,
                            cur_eval_data_list,
                            model,
                            r_dict,
                            eval_name=cur_eval_data_name,
                        )
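                        # evaluation_dataset stores its results under r_dict[cur_eval_data_name];
                        # the "acc" entry is read below when naming the checkpoint.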

                    # saving checkpoints
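                    # Checkpoint name pattern: e(<epoch>)|i(<global_step>)|<eval_name>#(<acc>)|...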
                    current_checkpoint_filename = f"e({epoch})|i({global_step})"

                    for i in range(len(eval_data_name)):
                        cur_eval_data_name = eval_data_name[i]
                        current_checkpoint_filename += f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



anli/src/nli/training.py [863:924]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            )
            # print(debug_node_info(args), loss, logits, batch['uid'])
            # print(debug_node_info(args), loss, batch['uid'])

            # Gradient accumulation: scale the loss down so the summed gradients
            # average over the accumulated micro-batches.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # Backward pass; with fp16, scale the loss through apex amp so the
            # half-precision gradients do not underflow.
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Update parameters once every gradient_accumulation_steps forward steps
            if (forward_step + 1) % args.gradient_accumulation_steps == 0:
                # Gradient clipping: only applied when max_grad_norm > 0
                if args.max_grad_norm > 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm
                        )
                    else:
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), args.max_grad_norm
                        )

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
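                # Clear the accumulated gradients before the next accumulation window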
                model.zero_grad()

                global_step += 1

                if (
                    args.global_rank in [-1, 0]
                    and args.eval_frequency > 0
                    and global_step % args.eval_frequency == 0
                ):
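                    # Periodic evaluation runs only on the main process (global_rank -1 or 0)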
                    r_dict = dict()
                    # Eval loop:
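                    # eval_data_name, eval_data_list and eval_data_loaders are parallel
                    # lists, one entry per evaluation set.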
                    for i in range(len(eval_data_name)):
                        cur_eval_data_name = eval_data_name[i]
                        cur_eval_data_list = eval_data_list[i]
                        cur_eval_dataloader = eval_data_loaders[i]
                        # cur_eval_raw_data_list = eval_raw_data_list[i]

                        evaluation_dataset(
                            args,
                            cur_eval_dataloader,
                            cur_eval_data_list,
                            model,
                            r_dict,
                            eval_name=cur_eval_data_name,
                        )
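                        # evaluation_dataset stores its results under r_dict[cur_eval_data_name];
                        # the "acc" entry is read below when naming the checkpoint.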

                    # saving checkpoints
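                    # Checkpoint name pattern: e(<epoch>)|i(<global_step>)|<eval_name>#(<acc>)|...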
                    current_checkpoint_filename = f"e({epoch})|i({global_step})"

                    for i in range(len(eval_data_name)):
                        cur_eval_data_name = eval_data_name[i]
                        current_checkpoint_filename += f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



