in utils/create_dummy_models.py [0:0]
def build(config_class, models_to_create, output_dir):
    """Create all models for a certain model type.

    Args:
        config_class (`PretrainedConfig`):
            A subclass of `PretrainedConfig` that is used to determine `models_to_create`.
        models_to_create (`dict`):
            A dictionary containing the processor/model classes for which we want to create instances. It is read
            under the keys `"processor"`, `"pytorch"` and `"tensorflow"`. These models are of the same model type,
            which is associated with `config_class`.
        output_dir (`str`):
            The directory to save all the checkpoints. Each model architecture will be saved in a subdirectory under
            it. Models in different frameworks with the same architecture will be saved in the same subdirectory.

    Returns:
        `dict`: For composite model types (encoder-decoder, etc.), whatever `build_composite_models` returns.
        Otherwise a report with one entry per key of `models_to_create`:
            - `"processor"` maps each converted processor's class name to its class name.
            - `"pytorch"` / `"tensorflow"` map each architecture name to a dict with `"model"` (class name or
              `None`), `"checkpoint"` (directory or `None`) and, on failure, `"error"` as an `(error, trace)` tuple.
        It may also contain `"vocab_size"` (when overridden), `"error"` (set by `fill_result_with_error` on a fatal
        failure) and `"warnings"` (a list of `(message, trace)` tuples); the last two are dropped when empty.
    """
    # Load the datasets once; they are stored in the module-level `data` dict so later calls reuse them.
    if data["training_ds"] is None or data["testing_ds"] is None:
        ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1")
        data["training_ds"] = ds["train"]
        data["testing_ds"] = ds["test"]

    # Composite model types are assembled from their sub-models by a dedicated helper.
    if config_class.model_type in [
        "encoder-decoder",
        "vision-encoder-decoder",
        "speech-encoder-decoder",
        "vision-text-dual-encoder",
    ]:
        return build_composite_models(config_class, output_dir)

    result = {k: {} for k in models_to_create}

    # These will be removed at the end if they are empty
    result["error"] = None
    result["warnings"] = []

    # Build processors
    processor_classes = models_to_create["processor"]

    if len(processor_classes) == 0:
        error = f"No processor class could be found in {config_class.__name__}."
        fill_result_with_error(result, error, None, models_to_create)
        logger.error(result["error"][0])
        return result

    for processor_class in processor_classes:
        try:
            processor = build_processor(config_class, processor_class, allow_no_checkpoint=True)
            if processor is not None:
                result["processor"][processor_class] = processor
        except Exception:
            error = f"Failed to build processor for {processor_class.__name__}."
            trace = traceback.format_exc()
            fill_result_with_error(result, error, trace, models_to_create)
            logger.error(result["error"][0])
            return result

    if len(result["processor"]) == 0:
        error = f"No processor could be built for {config_class.__name__}."
        fill_result_with_error(result, error, None, models_to_create)
        logger.error(result["error"][0])
        return result

    try:
        tiny_config = get_tiny_config(config_class)
    except Exception as e:
        error = f"Failed to get tiny config for {config_class.__name__}: {e}"
        trace = traceback.format_exc()
        fill_result_with_error(result, error, trace, models_to_create)
        logger.error(result["error"][0])
        return result

    # Convert the processors (reduce vocabulary size, smaller image size, etc.)
    processors = list(result["processor"].values())
    processor_output_folder = os.path.join(output_dir, "processors")
    try:
        processors = convert_processors(processors, tiny_config, processor_output_folder, result)
    except Exception:
        # A conversion failure is not fatal: keep the unconverted processors and record a warning.
        error = "Failed to convert the processors."
        trace = traceback.format_exc()
        result["warnings"].append((error, trace))

    if len(processors) == 0:
        error = f"No processor is returned by `convert_processors` for {config_class.__name__}."
        fill_result_with_error(result, error, None, models_to_create)
        logger.error(result["error"][0])
        return result

    try:
        config_overrides = get_config_overrides(config_class, processors)
    except Exception as e:
        error = f"Failure occurs while calling `get_config_overrides`: {e}"
        trace = traceback.format_exc()
        fill_result_with_error(result, error, trace, models_to_create)
        logger.error(result["error"][0])
        return result

    # Just for us to see this easily in the report
    if "vocab_size" in config_overrides:
        result["vocab_size"] = config_overrides["vocab_size"]

    # Update attributes that `vocab_size` involves
    for k, v in config_overrides.items():
        if hasattr(tiny_config, k):
            setattr(tiny_config, k, v)
        # So far, we only have to deal with `text_config`, as `config_overrides` contains text-related attributes only.
        # `FuyuConfig` saves data under both FuyuConfig and its `text_config`. This is not good, but let's just update
        # every involved field to avoid potential failure.
        if (
            hasattr(tiny_config, "text_config")
            and tiny_config.text_config is not None
            and hasattr(tiny_config.text_config, k)
        ):
            setattr(tiny_config.text_config, k, v)
            # If `text_config_dict` exists, we need to update its value here too in order to make
            # `save_pretrained -> from_pretrained` work.
            if hasattr(tiny_config, "text_config_dict"):
                tiny_config.text_config_dict[k] = v

    if result["warnings"]:
        logger.warning(result["warnings"][0][0])

    # update `result["processor"]`
    result["processor"] = {type(p).__name__: p.__class__.__name__ for p in processors}

    for pytorch_arch in models_to_create["pytorch"]:
        result["pytorch"][pytorch_arch.__name__] = {}
        error = None
        try:
            model = build_model(pytorch_arch, tiny_config, output_dir=output_dir)
        except Exception as e:
            model = None
            error = f"Failed to create the pytorch model for {pytorch_arch}: {e}"
            trace = traceback.format_exc()

        result["pytorch"][pytorch_arch.__name__]["model"] = model.__class__.__name__ if model is not None else None
        result["pytorch"][pytorch_arch.__name__]["checkpoint"] = (
            get_checkpoint_dir(output_dir, pytorch_arch) if model is not None else None
        )
        if error is not None:
            result["pytorch"][pytorch_arch.__name__]["error"] = (error, trace)
            logger.error(f"{pytorch_arch.__name__}: {error}")

    for tensorflow_arch in models_to_create["tensorflow"]:
        # Make PT/TF weights compatible
        pt_arch_name = tensorflow_arch.__name__[2:]  # Remove `TF`
        # A TF class without a PyTorch twin must not abort the whole build: fall back to building it from scratch.
        pt_arch = getattr(transformers_module, pt_arch_name, None)

        result["tensorflow"][tensorflow_arch.__name__] = {}
        error = None
        if (
            pt_arch is not None
            and pt_arch.__name__ in result["pytorch"]
            and result["pytorch"][pt_arch.__name__]["checkpoint"] is not None
        ):
            ckpt = get_checkpoint_dir(output_dir, pt_arch)
            # Use the same weights from PyTorch. This checkpoint was saved by the PyTorch model built above, so the
            # TF model must be told to load (and convert) PyTorch weights via `from_pt=True`.
            try:
                model = tensorflow_arch.from_pretrained(ckpt, from_pt=True)
                model.save_pretrained(ckpt)
            except Exception as e:
                # Conversion may fail. Let's not create a model with different weights to avoid confusion (for now).
                model = None
                error = f"Failed to convert the pytorch model to the tensorflow model for {pt_arch}: {e}"
                trace = traceback.format_exc()
        else:
            try:
                model = build_model(tensorflow_arch, tiny_config, output_dir=output_dir)
            except Exception as e:
                model = None
                error = f"Failed to create the tensorflow model for {tensorflow_arch}: {e}"
                trace = traceback.format_exc()

        result["tensorflow"][tensorflow_arch.__name__]["model"] = (
            model.__class__.__name__ if model is not None else None
        )
        result["tensorflow"][tensorflow_arch.__name__]["checkpoint"] = (
            get_checkpoint_dir(output_dir, tensorflow_arch) if model is not None else None
        )
        if error is not None:
            result["tensorflow"][tensorflow_arch.__name__]["error"] = (error, trace)
            logger.error(f"{tensorflow_arch.__name__}: {error}")

    if not result["error"]:
        del result["error"]
    if not result["warnings"]:
        del result["warnings"]

    return result