def train()

in scripts/train_memory.py


def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
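    """Run a short training loop and report CUDA memory usage, step times, and checkpoint size.

    With monitor_tensors=True, only a single step is run and the tensors saved for the
    backward pass are broken down into parameters and activations.
    """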
    init_cuda()
    cuda_memory_init = torch.cuda.max_memory_allocated()
    cuda_memory_log = []

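    # tokenizer and training data; the pad token falls back to EOS if the model defines none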
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.model_max_length = max_seq_length
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token
    data = get_data(tokenizer)

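    # load the base model in the requested precision; int4/int8 use bitsandbytes quantization and are
    # prepared for k-bit training; `device` is assumed to be a module-level constant (e.g. "cuda")
    # defined elsewhere in the script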
    if dtype == "int4":
        quant_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, quantization_config=quant_config)
        model = prepare_model_for_kbit_training(model)
    elif dtype == "int8":
        quant_config = BitsAndBytesConfig(load_in_8bit=True)
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, quantization_config=quant_config)
        model = prepare_model_for_kbit_training(model)
    elif dtype == "bfloat16":
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)
    elif dtype == "float16":
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.float16)
    elif dtype == "float32":
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
    else:
        raise ValueError(f"Invalid dtype: {dtype}")

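    # optionally wrap the model with a LoRA adapter; the LoraConfig is loaded from path_config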
    if rank > 0:
        if path_config is None:
            raise RuntimeError("LoRA rank > 0 requires a path to a LoRA config")
        if path_config.endswith(CONFIG_NAME):
            path_config = path_config.removesuffix(CONFIG_NAME)
        config = LoraConfig.from_pretrained(path_config)
        model = get_peft_model(model, config)
        model.print_trainable_parameters()
    else:
        print("Not using LoRA")

    model.config.use_cache = False
    storage = []

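    # when monitor_tensors is set, saved_tensors_hooks routes every tensor saved for the backward
    # pass through pack/unpack so it can be inspected after the step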
    def pack(x):
        storage.append(x)
        return len(storage) - 1

    def unpack(x):
        return storage[x]

    train_ctx = partial(torch.autograd.graph.saved_tensors_hooks, pack, unpack) if monitor_tensors else nullcontext

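    # plain AdamW training loop; each step is timed and its memory delta recorded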
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    losses = []
    sample = 0
    tic_total = time.perf_counter()
    for i in range(0, max_steps):
        storage.clear()
        tic = time.perf_counter()
        try:
            batch = tokenizer.pad(data["train"][sample : sample + batch_size], return_tensors="pt").to(model.device)
            sample += batch_size

            # add targets
            batch["labels"] = batch["input_ids"].clone()
            optimizer.zero_grad()

            with train_ctx():
                outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
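            # allocated memory relative to the baseline captured right after init_cuda()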
            cuda_memory_log.append(torch.cuda.memory_allocated() - cuda_memory_init)
            torch.cuda.empty_cache()
            gc.collect()
            toc = time.perf_counter()
            print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
        except KeyboardInterrupt:
            print("canceled training")
            break

        if monitor_tensors:
            break

    toc_total = time.perf_counter()

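    # report average/peak CUDA memory (relative to the post-init baseline) and total wall-clock time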
    cuda_memory_final = torch.cuda.max_memory_allocated()
    cuda_memory_avg = int(sum(cuda_memory_log) / max(len(cuda_memory_log), 1))  # guard against an empty log if no step completed
    print(f"cuda memory avg: {cuda_memory_avg // 2**20}MB")
    print(f"cuda memory max: {(cuda_memory_final - cuda_memory_init) // 2**20}MB")
    print(f"total time: {toc_total - tic_total:.2f}s")

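    # save the model to a temporary directory just to measure the checkpoint size on disk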
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        stat = os.stat(os.path.join(tmp_dir, SAFETENSORS_WEIGHTS_NAME))
    file_size = stat.st_size
    print(f"file size: {file_size / 2**20:.1f}MB")

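    # break the tensors captured by the pack hook down by dtype and shape; shapes that match a
    # parameter shape (possibly transposed) are attributed to parameters, the rest to activations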
    if monitor_tensors:
        dtype_counts = Counter(t.dtype for t in storage)
        shape_counts = Counter(t.shape for t in storage)
        param_shape_counts = Counter(p.shape for p in model.parameters())
        param_shape_counts_copy = dict(param_shape_counts)

        # shape counts includes the params, so we need to subtract them; note that they can be transposed
        # this is an approximation
        diff_shape_counts = {}
        for shape, count in shape_counts.items():
            if shape in param_shape_counts_copy:
                diff_count = count - param_shape_counts[shape]
                if diff_count > 0:
                    diff_shape_counts[shape] = diff_count
                    param_shape_counts_copy[shape] = max(0, param_shape_counts_copy[shape] - diff_count)
            elif shape[::-1] in param_shape_counts:
                diff_count = count - param_shape_counts[shape[::-1]]
                if diff_count > 0:
                    diff_shape_counts[shape] = diff_count
                    param_shape_counts_copy[shape[::-1]] = max(0, param_shape_counts_copy[shape[::-1]] - diff_count)
            else:
                diff_shape_counts[shape] = count

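        # dtype_to_bytes_linear is a module-level mapping (not shown in this excerpt); presumably it
        # gives the bytes per element of the linear-layer weights for each dtype string, e.g. roughly
        # 0.5 for int4, 1 for int8, 2 for bfloat16/float16, 4 for float32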
        total_size = sum(t.numel() * t.element_size() for t in storage)
        total_size_mb = f"{total_size // 2**20}MB"
        diff_size = 0
        for shape, count in diff_shape_counts.items():
            diff_size += count * shape.numel() * dtype_to_bytes_linear[dtype]
        param_size = total_size - diff_size

        diff_size_mb = f"{diff_size // 2**20}MB"
        param_size_mb = f"{param_size // 2**20}MB"

        print(f"Dtype counts: {dtype_counts.most_common()}")
        print(f"Total size of tensors:     {total_size_mb: >12}")
        print(f"Total size of activations: {diff_size_mb: >12}")
        print(f"Total size of parameters:  {param_size_mb: >12}")