# mm_dst/gpt2_dst/scripts/run_generation.py
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
)
parser.add_argument("--prompt", type=str, default="")
parser.add_argument(
"--prompts_from_file",
type=str,
default=None,
help="""
read prompts from a file and generate, overrides any prompt given on the
command line"""
)
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
)
parser.add_argument(
"--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
)
parser.add_argument("--k", type=int, default=0)
parser.add_argument("--p", type=float, default=0.9)
parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
parser.add_argument("--path_output", type=str, default=None, help="Path to output predictions in a line separated text file.")
args = parser.parse_args()
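    # Run on GPU when available unless --no_cuda is given, and seed all RNGs for reproducibility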
args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
set_seed(args)
if args.prompts_from_file and not os.path.exists(args.prompts_from_file):
raise Exception(f"prompt file '{args.prompts_from_file}' not found")
# Initialize the model and tokenizer
try:
args.model_type = args.model_type.lower()
model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError(
            "the model {} you specified is not supported. You are welcome to add it and open a PR :)".format(
                args.model_type
            )
        )
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path)
model.to(args.device)
args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
logger.info(args)
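    # Prompts come from --prompts_from_file, from --prompt, or interactively from
    # stdin; the loop below runs until prompts are exhausted (or until an empty
    # line / "quit" is entered in interactive mode)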
results = []
prompts = []
if args.prompts_from_file:
with open(args.prompts_from_file) as handle:
prompts = handle.readlines()
while True:
if not prompts:
prompts = [args.prompt if args.prompt else input("Model prompt >>> ")]
if not args.prompt and (
len(prompts) == 0
or prompts[0].strip() == ''
or prompts[0].lower() == 'quit'
):
break # break while True loop
n_prompts = len(prompts)
for i, prompt_text in enumerate(prompts):
# Strip any trailing \n if provided
prompt_text = prompt_text.strip('\n')
# Different models need different input formatting and/or extra arguments
requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
if requires_preprocessing:
prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
preprocessed_prompt_text = prepare_input(
args,
model,
tokenizer,
prompt_text
)
encoded_prompt = tokenizer.encode(
preprocessed_prompt_text,
add_special_tokens=True,
return_tensors="pt",
add_space_before_punct_symbol=True
)
else:
encoded_prompt = tokenizer.encode(
prompt_text,
add_special_tokens=True,
return_tensors="pt"
)
encoded_prompt = encoded_prompt.to(args.device)
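            # Sample a continuation with top-k / nucleus sampling; max_length counts
            # the prompt tokens plus args.length newly generated tokens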
output_sequences = model.generate(
input_ids=encoded_prompt,
max_length=args.length + len(encoded_prompt[0]),
temperature=args.temperature,
top_k=args.k,
top_p=args.p,
repetition_penalty=args.repetition_penalty,
do_sample=True,
num_return_sequences=args.num_return_sequences,
)
# Remove the batch dimension when returning multiple sequences
if len(output_sequences.shape) > 2:
output_sequences.squeeze_()
generated_sequences = []
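            # Decode each returned sequence, truncate at the stop token, and re-attach
            # the raw prompt text in place of the decoded (possibly preprocessed) prompt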
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
print(
"=== GENERATED SEQUENCE {sequence_idx}, {promt_idx}/{n_prompts} ===".format(
sequence_idx=generated_sequence_idx + 1,
promt_idx=i + 1,
n_prompts=n_prompts
)
)
generated_sequence = generated_sequence.tolist()
# Decode text
text = tokenizer.decode(
generated_sequence,
clean_up_tokenization_spaces=True
)
                # Remove all text after the stop token; if the stop token is set but
                # absent, keep the full text (str.find would return -1 and clip it)
                stop_index = text.find(args.stop_token) if args.stop_token else -1
                text = text[:stop_index] if stop_index >= 0 else text
# Add the prompt at the beginning of the sequence. Remove the
# excess text that was used for pre-processing
total_sequence = (
prompt_text + text[
len(tokenizer.decode(
encoded_prompt[0],
clean_up_tokenization_spaces=True
))
:
]
)
generated_sequences.append(total_sequence)
print(total_sequence)
results.append(generated_sequences)
prompts = []
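        # Non-interactive modes (--prompt or --prompts_from_file) run a single pass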
if args.prompt or args.prompts_from_file:
break # break while True loop
if args.path_output is not None:
        # Create the output directory if it does not exist (skip when the output
        # path has no directory component)
        directory = os.path.dirname(args.path_output)
        if directory:
            os.makedirs(directory, exist_ok=True)
# Format results into a line-separated string file
str_results = '\n'.join([
' || '.join(generated_sequences) for generated_sequences in results
])
# Save to a file
with open(args.path_output, 'w') as f_out:
f_out.write(str_results)
return results
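# Example invocation (illustrative; assumes "gpt2" is a key in MODEL_CLASSES and
# the <...> paths are placeholders for local files):
#   python run_generation.py \
#       --model_type=gpt2 \
#       --model_name_or_path=<checkpoint_dir> \
#       --prompts_from_file=<prompts.txt> \
#       --path_output=<predictions.txt> \
#       --length=100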