# scripts/train_memory.py
def train(model_id, rank, dtype, monitor_tensors, max_seq_length, batch_size, max_steps, path_config):
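"""Run a short fine-tuning loop and report CUDA memory usage (average and peak),
per-step timings, checkpoint file size and, when monitor_tensors is set, a
breakdown of the tensors saved for the backward pass."""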
init_cuda()
cuda_memory_init = torch.cuda.max_memory_allocated()
cuda_memory_log = []
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.model_max_length = max_seq_length
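# some causal LM tokenizers ship without a pad token; fall back to EOS so tokenizer.pad works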
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token
data = get_data(tokenizer)
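# load the base model in the requested precision; int4/int8 go through bitsandbytes
# quantization and need prepare_model_for_kbit_training. `device` is assumed to be a
# module-level constant (e.g. "cuda") defined elsewhere in this script.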
if dtype == "int4":
quant_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, quantization_config=quant_config)
model = prepare_model_for_kbit_training(model)
elif dtype == "int8":
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, quantization_config=quant_config)
model = prepare_model_for_kbit_training(model)
elif dtype == "bfloat16":
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.bfloat16)
elif dtype == "float16":
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device, torch_dtype=torch.float16)
elif dtype == "float32":
model = AutoModelForCausalLM.from_pretrained(model_id, device_map=device)
else:
raise ValueError(f"Invalid dtype: {dtype}")
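# rank > 0 enables LoRA: load the adapter config from path_config and wrap the model with PEFT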
if rank > 0:
if path_config is None:
raise RuntimeError("LoRA rank > 0 requires a path to a LoRA config")
if path_config.endswith(CONFIG_NAME):
path_config = path_config.removesuffix(CONFIG_NAME)
config = LoraConfig.from_pretrained(path_config)
model = get_peft_model(model, config)
model.print_trainable_parameters()
else:
print("Not using LoRA")
model.config.use_cache = False
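# when monitor_tensors is set, intercept every tensor that autograd saves for the backward
# pass via saved_tensors_hooks, so the storage list can be analyzed after one step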
storage = []
def pack(x):
storage.append(x)
return len(storage) - 1
def unpack(x):
return storage[x]
train_ctx = partial(torch.autograd.graph.saved_tensors_hooks, pack, unpack) if monitor_tensors else nullcontext
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
losses = []
sample = 0
tic_total = time.perf_counter()
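# training loop; with monitor_tensors a single forward/backward step is enough to capture the saved tensors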
for i in range(max_steps):
storage.clear()
tic = time.perf_counter()
try:
batch = tokenizer.pad(data["train"][sample : sample + batch_size], return_tensors="pt").to(model.device)
sample += batch_size
# add targets: for causal LM training, the labels are just the input ids (the model shifts them internally)
batch["labels"] = batch["input_ids"].clone()
optimizer.zero_grad()
with train_ctx():
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
losses.append(loss.item())
cuda_memory_log.append(torch.cuda.memory_allocated() - cuda_memory_init)
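# release cached blocks and collect garbage between steps so leftovers do not skew the next measurement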
torch.cuda.empty_cache()
gc.collect()
toc = time.perf_counter()
print(f"step {i:3d} loss {loss.item():.6f} time {toc - tic:.2f}s", file=sys.stderr)
except KeyboardInterrupt:
print("canceled training")
break
if monitor_tensors:
break
toc_total = time.perf_counter()
cuda_memory_final = torch.cuda.max_memory_allocated()
cuda_memory_avg = int(sum(cuda_memory_log) / max(len(cuda_memory_log), 1))
print(f"cuda memory avg: {cuda_memory_avg // 2**20}MB")
print(f"cuda memory max: {(cuda_memory_final - cuda_memory_init) // 2**20}MB")
print(f"total time: {toc_total - tic_total:.2f}s")
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
stat = os.stat(os.path.join(tmp_dir, SAFETENSORS_WEIGHTS_NAME))
file_size = stat.st_size
print(f"file size: {file_size / 2**20:.1f}MB")
if monitor_tensors:
dtype_counts = Counter(t.dtype for t in storage)
shape_counts = Counter(t.shape for t in storage)
param_shape_counts = Counter(p.shape for p in model.parameters())
param_shape_counts_copy = dict(param_shape_counts)
# shape_counts includes the parameters, so we need to subtract them; note that they can be
# transposed, so this is only an approximation
diff_shape_counts = {}
for shape, count in shape_counts.items():
if shape in param_shape_counts_copy:
diff_count = count - param_shape_counts[shape]
if diff_count > 0:
diff_shape_counts[shape] = diff_count
param_shape_counts_copy[shape] = max(0, param_shape_counts_copy[shape] - diff_count)
elif shape[::-1] in param_shape_counts:
diff_count = count - param_shape_counts[shape[::-1]]
if diff_count > 0:
diff_shape_counts[shape] = diff_count
param_shape_counts_copy[shape[::-1]] = max(0, param_shape_counts_copy[shape[::-1]] - diff_count)
else:
diff_shape_counts[shape] = count
total_size = sum(t.numel() * t.element_size() for t in storage)
total_size_mb = f"{total_size // 2**20}MB"
diff_size = 0
for shape, count in diff_shape_counts.items():
diff_size += count * shape.numel() * dtype_to_bytes_linear[dtype]
param_size = total_size - diff_size
diff_size_mb = f"{diff_size // 2**20}MB"
param_size_mb = f"{param_size // 2**20}MB"
print(f"Dtype counts: {dtype_counts.most_common()}")
print(f"Total size of tensors: {total_size_mb: >12}")
print(f"Total size of activations: {diff_size_mb: >12}")
print(f"Total size of parameters: {param_size_mb: >12}")