in scripts/run_prompt_creation.py
def main():
# 1. Parse input arguments
parser = HfArgumentParser((ModelArguments, DataArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args = parser.parse_args_into_dataclasses()
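# Example invocations (illustrative only; the exact flag names are generated by HfArgumentParser
# from the ModelArguments/DataArguments dataclasses):
#   python run_prompt_creation.py path/to/config.json
#   python run_prompt_creation.py --model_name_or_path <model> --dataset_name <dataset> --output_dir <dir>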
# 2. Setup logging
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
)
if data_args.is_single_speaker and data_args.speaker_name is None:
raise ValueError("`is_single_speaker=True` but `speaker_name` is not specified. Specify it or remove `is_single_speaker`.")
if not data_args.is_single_speaker and data_args.speaker_name:
raise ValueError(f"`is_single_speaker=False` but `speaker_name={data_args.speaker_name}` is specified. Add `--is_single_speaker` or remove `speaker_name`.")
# Create a custom Accelerate configuration with an extended process-group timeout (3 hours) for long generation runs
process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=3600*3))
accelerator = Accelerator(kwargs_handlers=[process_group_kwargs])
if data_args.overwrite_output_dir and os.path.exists(data_args.output_dir) and os.path.isdir(data_args.output_dir):
logger.info("Cleaning output dir from previous run...")
shutil.rmtree(data_args.output_dir)
# 3. Load annotated dataset
logger.info("*** Load annotated dataset ***")
if data_args.dataset_split_name is not None:
raw_datasets = DatasetDict()
data_splits = data_args.dataset_split_name.split("+")
# load on a split-wise basis
for split in data_splits:
with accelerator.local_main_process_first():
raw_datasets[split] = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
split=split,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
else:
with accelerator.local_main_process_first():
# load all splits for annotation
raw_datasets = load_dataset(
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
token=model_args.token,
num_proc=data_args.preprocessing_num_workers,
)
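# Loading happens on the local main process first so the other ranks on the node reuse the cached
# dataset instead of downloading and preprocessing it again.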
raw_datasets_features = set(raw_datasets[next(iter(raw_datasets))].features.keys())
if data_args.max_eval_samples is not None:
for split in raw_datasets:
raw_datasets[split] = raw_datasets[split].select(range(data_args.max_eval_samples))
EXPECTED_COLUMNS = {"gender", "pitch", "noise", "reverberation", "speech_monotony", "speaking_rate"}
if data_args.is_single_speaker:
EXPECTED_COLUMNS = {"noise", "reverberation", "speech_monotony", "speaking_rate"}
if data_args.is_new_speaker_prompt:
EXPECTED_COLUMNS.remove("noise")
EXPECTED_COLUMNS.add("sdr_noise")
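# Single-speaker prompts drop the gender/pitch columns; the "new speaker" prompt variant expects
# an `sdr_noise` column in place of `noise`.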
speaker_ids_to_name = {}
speaker_id_column = data_args.speaker_id_column
if data_args.speaker_id_column and data_args.speaker_ids_to_name_json:
import json
if data_args.is_single_speaker:
raise ValueError(f"`is_single_speaker=True` but `speaker_ids_to_name_json={data_args.speaker_ids_to_name_json}` is specified. Specify one or the other.")
EXPECTED_COLUMNS.add(data_args.speaker_id_column)
with open(data_args.speaker_ids_to_name_json, "r") as read_file:
speaker_ids_to_name = json.load(read_file)
if not EXPECTED_COLUMNS.issubset(raw_datasets_features):
missing_columns = EXPECTED_COLUMNS - raw_datasets_features
raise ValueError(
f"Missing columns {missing_columns} from the dataset features. Got dataset features {raw_datasets_features}"
)
# 4. Load pre-trained model
logger.info("*** Load pretrained model ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
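# "auto" and None are forwarded to `from_pretrained` as-is; an explicit string such as "bfloat16"
# is resolved to the corresponding torch dtype.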
quantization_config = get_quantization_config(model_args)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
revision=model_args.model_revision,
variant=model_args.model_variant,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
low_cpu_mem_usage=True,
token=model_args.token,
).eval()
if model_args.torch_compile:
# torch.compile with a static k/v cache is only supported for models that implement `_setup_cache` (e.g. Gemma and Llama)
if not callable(getattr(model, "_setup_cache", None)):
raise ValueError(
f"Static k/v cache is not compatible with the model {model.__class__.__name__}. Set `--torch_compile=False` "
"for dynamic k/v cache."
)
model.generation_config.cache_implementation = "static"
# compile the forward pass (but not the top-{p,k} sampling)
model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
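# Note: "reduce-overhead" mode captures CUDA graphs, which benefit from the fixed tensor shapes
# that the static k/v cache provides.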
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
padding_side="left",
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.bos_token_id
model.generation_config.pad_token_id = model.generation_config.eos_token_id
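# Decoder-only generation requires left padding (set above); tokenizers without a pad token fall
# back to BOS for padding, and the model's generation config pads with EOS.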
speaker_name = data_args.speaker_name
is_single_speaker = data_args.is_single_speaker
is_new_speaker_prompt = data_args.is_new_speaker_prompt
accent_column_name = data_args.accent_column
def prepare_dataset(sample):
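# Fill the prompt template for this sample: pick the right template variant, substitute the
# speaker-name / annotation placeholders, and tokenize it as a single-turn chat message.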
sample_prompt = PROMPT
if is_single_speaker:
sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT
sample_prompt = sample_prompt.replace("[speaker_name]", speaker_name)
elif speaker_id_column and speaker_ids_to_name.get(str(sample.get(speaker_id_column))):
name = speaker_ids_to_name[str(sample.get(speaker_id_column))]
sample_prompt = SINGLE_SPEAKER_PROMPT if not is_new_speaker_prompt else NEW_SINGLE_SPEAKER_PROMPT
sample_prompt = sample_prompt.replace("[speaker_name]", name)
elif is_new_speaker_prompt and accent_column_name is not None:
sample_prompt = NEW_PROMPT if sample.get(accent_column_name, "Unindentified") == "Unindentified" else NEW_PROMPT_WITH_ACCENT
elif is_new_speaker_prompt:
sample_prompt = NEW_PROMPT
for key in EXPECTED_COLUMNS:
sample_prompt = sample_prompt.replace(f"[{key}]", sample[key])
if accent_column_name is not None and sample.get(accent_column_name, "Unindentified") != "Unindentified":
sample_prompt = sample_prompt.replace("[accent]", sample[accent_column_name])
sample_prompt = [{"role": "user", "content": sample_prompt}]
token_ids = tokenizer.apply_chat_template(sample_prompt)
sample["input_ids"] = token_ids
return sample
with accelerator.local_main_process_first():
vectorized_datasets = raw_datasets.map(
prepare_dataset, num_proc=data_args.preprocessing_num_workers, desc="Preparing prompts"
)
# Prepare everything with our `accelerator`
model = accelerator.prepare(model)
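# DataCollatorWithPadding pads the variable-length prompts in each batch; the tokenizer was
# created with padding_side="left", as expected for decoder-only generation.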
data_collator = DataCollatorWithPadding(tokenizer)
def generate_step(batch):
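# Run generation on the unwrapped model, then pad the outputs across processes so they can be
# gathered into a single tensor afterwards.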
output_ids = accelerator.unwrap_model(model).generate(
batch["input_ids"],
attention_mask=batch["attention_mask"],
do_sample=model_args.do_sample,
temperature=model_args.temperature,
max_new_tokens=model_args.max_new_tokens,
)
output_ids = accelerator.pad_across_processes(output_ids, dim=1, pad_index=tokenizer.pad_token_id)
return output_ids
def postprocess_dataset(batch):
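# Decode the prompt and the full generation, then strip the prompt prefix so only the newly
# generated description is kept.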
prompt_texts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)
generated_texts = tokenizer.batch_decode(batch["generated_ids"], skip_special_tokens=True)
batch["text_description"] = [generated_text[len(prompt_text) :] for (prompt_text, generated_text) in zip(prompt_texts, generated_texts)]
return batch
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
batch_size=model_args.per_device_eval_batch_size,
collate_fn=data_collator,
num_workers=data_args.dataloader_num_workers,
pin_memory=True,
)
data_loader = accelerator.prepare(data_loader)
total_inference_steps = len(data_loader)
progress_bar = tqdm(
range(total_inference_steps), desc=f"Generating {split}", position=0, disable=not accelerator.is_local_main_process
)
split_output_dir = os.path.join(data_args.output_dir, split)
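# Resume support: `get_last_checkpoint` is expected to return the previously generated ids and
# the step to resume from (an empty list and 0 when starting fresh).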
all_generated_ids, cur_step = get_last_checkpoint(split_output_dir, accelerator.is_local_main_process)
accelerator.wait_for_everyone()
if cur_step > 0:
logger.info(f"Resuming {split} from step {cur_step}")
# efficiently skip the first n batches
data_loader = skip_first_batches(data_loader, cur_step)
progress_bar.update(cur_step)
while cur_step < total_inference_steps:
for batch in data_loader:
generated_ids = generate_step(batch)
generated_ids = accelerator.gather_for_metrics(generated_ids)
if accelerator.is_local_main_process:
all_generated_ids.extend(generated_ids.cpu().numpy())
cur_step += 1
progress_bar.update(1)
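# Periodically checkpoint the gathered generations so an interrupted run can resume via
# `skip_first_batches` above; `rotate_checkpoints` bounds the number of checkpoints kept on disk.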
if (cur_step % data_args.save_steps == 0) or (cur_step == total_inference_steps):
if accelerator.is_main_process:
save_checkpoint(split_output_dir, all_generated_ids, cur_step)
rotate_checkpoints(data_args.save_total_limit, output_dir=split_output_dir)
accelerator.wait_for_everyone()
if accelerator.is_local_main_process:
vectorized_datasets[split] = vectorized_datasets[split].add_column("generated_ids", all_generated_ids)
if accelerator.is_main_process:
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
batched=True,
num_proc=data_args.preprocessing_num_workers,
desc="Postprocessing dataset",
remove_columns=["input_ids", "generated_ids"],
)
accelerator.wait_for_everyone()
if accelerator.is_main_process:
vectorized_datasets.save_to_disk(data_args.output_dir)
if data_args.push_to_hub:
vectorized_datasets.push_to_hub(
data_args.hub_dataset_id,
config_name=data_args.dataset_config_name if data_args.dataset_config_name is not None else "default",
token=model_args.token,
)
accelerator.wait_for_everyone()
accelerator.end_training()