in src/transformers/commands/add_new_model_like.py [0:0]
def _ask_old_model_type(model_types):
    """
    Prompt until the user enters a `model_type` contained in `model_types`, then return it.

    When the entry is unknown, the close matches found by `difflib.get_close_matches` (if any) are printed as a
    hint before asking again.
    """
    while True:
        old_model_type = input(
            "What is the model you would like to duplicate? Please provide the lowercase `model_type` (e.g. roberta): "
        )
        if old_model_type in model_types:
            return old_model_type

        print(f"{old_model_type} is not a valid model type.")
        near_choices = difflib.get_close_matches(old_model_type, model_types)
        if near_choices:
            # Always join, so that a single suggestion is displayed as plain text and not as a list repr.
            suggestions = " or ".join(near_choices)
            print(f"Did you mean {suggestions}?")


def _ask_new_class_name(old_class, description, default_value):
    """
    Ask for the name of a new `description` class, or return `None` when the old model has no such class.
    """
    if old_class is None:
        return None
    return get_user_field(
        f"What will be the name of the {description} class for this model? ",
        default_value=default_value,
    )


def get_user_input():
    """
    Ask the user for the necessary inputs to add the new model.

    Returns:
        `tuple`: `(old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint,
        create_fast_image_processor)` where

        - `old_model_type` is the model type chosen for duplication,
        - `model_patterns` is the `ModelPatterns` built from the answers for the new model,
        - `add_copied_from` is the answer to the "# Copied from" question,
        - `frameworks` is `None` when all the frameworks of the old model are kept, otherwise a de-duplicated list
          of the requested frameworks among `"pt"`, `"tf"` and `"flax"`,
        - `old_checkpoint` is `None` unless no checkpoint was found for the old model, in which case it is the
          value entered by the user,
        - `create_fast_image_processor` is the answer to the fast image processor question (`False` when the
          question is not asked).
    """
    model_types = list(auto_module.configuration_auto.MODEL_NAMES_MAPPING.keys())

    # Get old model type
    old_model_type = _ask_old_model_type(model_types)

    old_model_info = retrieve_info_for_model(old_model_type)
    old_patterns = old_model_info["model_patterns"]
    old_tokenizer_class = old_patterns.tokenizer_class
    old_image_processor_class = old_patterns.image_processor_class
    old_image_processor_fast_class = old_patterns.image_processor_fast_class
    old_feature_extractor_class = old_patterns.feature_extractor_class
    old_processor_class = old_patterns.processor_class
    old_frameworks = old_model_info["frameworks"]

    old_checkpoint = None
    if not old_patterns.checkpoint:
        old_checkpoint = get_user_field(
            "We couldn't find the name of the base checkpoint for that model, please enter it here."
        )

    model_name = get_user_field(
        "What is the name (with no special casing) for your new model in the paper (e.g. RoBERTa)? "
    )
    # Throwaway patterns, only used to compute the default values proposed in the next questions.
    default_patterns = ModelPatterns(model_name, model_name)

    model_type = get_user_field(
        "What identifier would you like to use for the `model_type` of this model? ",
        default_value=default_patterns.model_type,
    )
    model_lower_cased = get_user_field(
        "What lowercase name would you like to use for the module (folder) of this model? ",
        default_value=default_patterns.model_lower_cased,
    )
    model_camel_cased = get_user_field(
        "What prefix (camel-cased) would you like to use for the model classes of this model (e.g. Roberta)? ",
        default_value=default_patterns.model_camel_cased,
    )
    model_upper_cased = get_user_field(
        "What prefix (upper-cased) would you like to use for the constants relative to this model? ",
        default_value=default_patterns.model_upper_cased,
    )
    config_class = get_user_field(
        "What will be the name of the config class for this model? ", default_value=f"{model_camel_cased}Config"
    )
    checkpoint = get_user_field(
        "Please give a checkpoint identifier (on the model Hub) for this new model (e.g. FacebookAI/roberta-base): "
    )

    # Display names of the processing classes of the old model (first element when a class is given as a tuple).
    old_processing_classes = [
        c if not isinstance(c, tuple) else c[0]
        for c in [
            old_image_processor_class,
            old_image_processor_fast_class,
            old_feature_extractor_class,
            old_tokenizer_class,
            old_processor_class,
        ]
        if c is not None
    ]
    old_processing_classes = ", ".join(old_processing_classes)
    keep_processing = get_user_field(
        f"Will your new model use the same processing class as {old_model_type} ({old_processing_classes}) (yes/no)? ",
        convert_to=convert_to_bool,
        fallback_message="Please answer yes/no, y/n, true/false or 1/0. ",
    )

    if keep_processing:
        image_processor_class = old_image_processor_class
        image_processor_fast_class = old_image_processor_fast_class
        feature_extractor_class = old_feature_extractor_class
        processor_class = old_processor_class
        tokenizer_class = old_tokenizer_class
        create_fast_image_processor = False
    else:
        # Only ask about the kinds of processing classes the old model has; the others stay `None`.
        tokenizer_class = _ask_new_class_name(old_tokenizer_class, "tokenizer", f"{model_camel_cased}Tokenizer")
        image_processor_class = _ask_new_class_name(
            old_image_processor_class, "image processor", f"{model_camel_cased}ImageProcessor"
        )
        image_processor_fast_class = _ask_new_class_name(
            old_image_processor_fast_class, "fast image processor", f"{model_camel_cased}ImageProcessorFast"
        )
        feature_extractor_class = _ask_new_class_name(
            old_feature_extractor_class, "feature extractor", f"{model_camel_cased}FeatureExtractor"
        )
        processor_class = _ask_new_class_name(old_processor_class, "processor", f"{model_camel_cased}Processor")

        if old_image_processor_class is not None and old_image_processor_fast_class is None:
            create_fast_image_processor = get_user_field(
                "A fast image processor can be created from the slow one, but modifications might be needed. "
                "Should we add a fast image processor class for this model (recommended, yes/no)? ",
                convert_to=convert_to_bool,
                default_value="yes",
                fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
            )
        else:
            create_fast_image_processor = False

    model_patterns = ModelPatterns(
        model_name,
        checkpoint,
        model_type=model_type,
        model_lower_cased=model_lower_cased,
        model_camel_cased=model_camel_cased,
        model_upper_cased=model_upper_cased,
        config_class=config_class,
        tokenizer_class=tokenizer_class,
        image_processor_class=image_processor_class,
        image_processor_fast_class=image_processor_fast_class,
        feature_extractor_class=feature_extractor_class,
        processor_class=processor_class,
    )

    add_copied_from = get_user_field(
        "Should we add # Copied from statements when creating the new modeling file (yes/no)? ",
        convert_to=convert_to_bool,
        default_value="yes",
        fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
    )

    all_frameworks = get_user_field(
        "Should we add a version of your new model in all the frameworks implemented by"
        f" {old_model_type} ({old_frameworks}) (yes/no)? ",
        convert_to=convert_to_bool,
        default_value="yes",
        fallback_message="Please answer yes/no, y/n, true/false or 1/0.",
    )
    if all_frameworks:
        frameworks = None
    else:
        valid_frameworks = ("pt", "tf", "flax")
        frameworks = get_user_field(
            "Please enter the list of frameworks you want (pt, tf, flax) separated by spaces",
            # `split()` tolerates repeated/leading/trailing spaces; the emptiness check keeps blank answers
            # invalid (as they were with `split(" ")`).
            is_valid_answer=lambda x: len(x.split()) > 0 and all(p in valid_frameworks for p in x.split()),
            fallback_message="Please enter one or more of pt, tf, flax, separated by spaces.",
        )
        frameworks = list(set(frameworks.split()))

    return (old_model_type, model_patterns, add_copied_from, frameworks, old_checkpoint, create_fast_image_processor)