def training_function()

in archived/smddp_deepspeed_example/code/train.py [0:0]


def training_function(args):
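  # Module-level imports in train.py are assumed here and not shown in this
  # excerpt: math, torch, torch.distributed as dist, deepspeed, and the
  # transformers helpers (AutoModelForCausalLM, get_scheduler, set_seed),
  # plus the example's StubDataset and create_dataloaders utilities.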
  # SMDDP example tailored for p4d(e) instance types (8 GPUs per node),
  # so the local rank is the global rank modulo 8.
  local_rank = dist.get_rank() % 8
  seed = args.seed
  set_seed(seed)
  torch.cuda.set_device(local_rank)

  dataset = {
    'train': StubDataset(),
    'validation': StubDataset()
  }
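  # StubDataset (defined elsewhere in this example) stands in for a real
  # dataset and is assumed to yield synthetic tokenized batches.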
    
  dtype = torch.bfloat16

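  # Build the Llama model from its config (random weights) under
  # deepspeed.zero.Init so ZeRO-3 can partition parameters across ranks at
  # construction time instead of materializing a full copy on every GPU.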
  from transformers import AutoModelForCausalLM, LlamaConfig
  configuration = LlamaConfig(use_cache=False)  # KV caching is incompatible with gradient checkpointing
  with deepspeed.zero.Init(dtype=dtype, enabled=True):
    model = AutoModelForCausalLM.from_config(configuration)
  model.gradient_checkpointing_enable()

  train_dataset = dataset["train"]
  eval_dataset = dataset["validation"]
  train_dataloader, eval_dataloader = create_dataloaders(
    train_dataset, eval_dataset, dist.get_rank(), dist.get_world_size(), 
    seed, args.batch_size, args.batch_size)
 
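  # Exclude biases and LayerNorm weights from weight decay, the usual
  # convention for transformer training.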
  no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
  optimizer_grouped_parameters = [{
      "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
      "weight_decay": args.weight_decay,
    },{
      "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
      "weight_decay": 0.0,
    }] 

  optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

  # Scheduler and math around the number of training steps.
  overrode_max_train_steps = False
  num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
  if dist.get_rank()==0:
    print(f"Number of update steps per epoch {num_update_steps_per_epoch}")
  if args.max_train_steps is None:
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    overrode_max_train_steps = True

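  # The scheduler is stepped once per micro-batch, so warmup and total steps
  # are scaled by gradient_accumulation_steps to keep the schedule aligned
  # with optimizer updates.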
  lr_scheduler = get_scheduler(
    name=args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
  )

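  # Wrap model and optimizer in a DeepSpeed engine; ZeRO stage, bf16, and
  # gradient-accumulation settings are expected to come from the JSON config
  # referenced by args.deepspeed_config.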
  model, optimizer, _, _ = deepspeed.initialize(
    model=model,
    optimizer=optimizer,
    model_parameters=model.parameters(),
    config=args.deepspeed_config
  )
  device = torch.device(f"cuda:{local_rank}")
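  # Training loop: each rank accumulates [summed batch loss, sample count]
  # in ds_loss and all-reduces it at the end of the epoch to report global
  # loss and perplexity.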
  for epoch in range(args.num_train_epochs):
    model.train()
    total_steps=0
    ds_loss = torch.zeros(2).to(local_rank)

    for batch_idx, batch in enumerate(train_dataloader):
      batch = {k: v.to(device) for k, v in batch.items()}  
      output = model(**batch)
      if dist.get_rank() == 0:
        print(f"Processing training batch {batch_idx}")
      loss = output["loss"]
      # The DeepSpeed engine owns backward, gradient accumulation, the
      # optimizer step, and gradient zeroing; the LR scheduler is stepped
      # manually because it was not passed to deepspeed.initialize.
      model.backward(loss)
      model.step()
      ds_loss[0] += loss.item()
      ds_loss[1] += len(batch["input_ids"])
      lr_scheduler.step()
      total_steps += 1
      if args.max_steps is not None and total_steps > args.max_steps:
        break
    
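    # Aggregate the per-rank loss totals and sample counts before computing
    # epoch-level metrics.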
    torch.distributed.all_reduce(ds_loss, op=torch.distributed.ReduceOp.SUM)
    train_loss = ds_loss[0] / ds_loss[1]
    train_ppl = torch.exp(train_loss)

    if dist.get_rank()==0:
      print(f"******{epoch=}: {train_ppl=} {train_loss=}******")
    
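    # Validation: same loss bookkeeping as training, but with gradients
    # disabled and no optimizer updates.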
    model.eval()
    ds_eval_loss = torch.zeros(2).to(local_rank)
    for steps, batch in enumerate(eval_dataloader):
      batch = {k: v.to(device) for k, v in batch.items()}

      if dist.get_rank() == 0: print(f"Performing validation on training batch {batch_idx}")
      with torch.no_grad():
        outputs = model(**batch)
      loss = outputs["loss"]
      ds_eval_loss[0] += loss.item()
      ds_eval_loss[1] += len(batch["input_ids"])
      if args.max_steps is not None and steps > args.max_steps:
        break

    torch.distributed.all_reduce(ds_eval_loss, op=torch.distributed.ReduceOp.SUM)
    eval_loss = ds_eval_loss[0] / ds_eval_loss[1]
    eval_ppl = torch.exp(eval_loss)

    if dist.get_rank()==0:
      print(f"*******{epoch=}: {eval_ppl=} {eval_loss=}*******")
    
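    # Stop iterating over epochs once the inner loop has hit max_steps.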
    if args.max_steps is not None and total_steps > args.max_steps:
      break

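  # Rank 0 logs completion; the barrier keeps every rank alive until all
  # processes reach the end of training.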
  if dist.get_rank() == 0:
    print("Training done!")
  dist.barrier()