in doc/doc_utils/jumpstart_doc_utils.py [0:0]
def _select_latest_model_versions(sdk_manifest):
    """Map each model ID in ``sdk_manifest`` to its highest-versioned entry.

    Versions are parsed with ``Version`` (presumably ``packaging.version``)
    instead of being compared as raw strings. When two entries carry the same
    version, the one seen first is kept (the comparison is a strict ``<``).
    """
    latest = {}
    for model in sdk_manifest:
        model_id = model["model_id"]
        current = latest.get(model_id)
        if current is None or Version(current["version"]) < Version(model["version"]):
            latest[model_id] = model
    return latest


def _pretrained_models_intro():
    """Return the lines that open ``pretrainedmodels.rst``, ending with the table's header row."""
    return [
        ".. _all-pretrained-models:\n\n",
        # Substitution used by the "Source" cells of the main table.
        ".. |external-link| raw:: html\n\n",
        '   <i class="fa fa-external-link"></i>\n\n',
        "================================================\n",
        "Built-in Algorithms with pre-trained Model Table\n",
        "================================================\n",
        "\nThe SageMaker Python SDK uses model IDs and model versions to access the necessary\n"
        "utilities for pre-trained models. This table serves to provide the core material plus\n"
        "some extra information that can be useful in selecting the correct model ID and\n"
        "corresponding parameters.\n",
        "\nIf you want to automatically use the latest version of the model, "
        'use "*" for the `model_version` attribute.\n'
        "We highly suggest pinning an exact model version however.\n",
        "\nThese models are also available through the\n"
        "`JumpStart UI in SageMaker Studio "
        "<https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html>`__\n",
        "\n",
        ".. list-table:: Available Models\n",
        "   :widths: 50 20 20 20 30 20\n",
        "   :header-rows: 1\n",
        "   :class: datatable\n",
        "\n",
        # In a list-table every "- " cell must line up with the first cell
        # after "* ", otherwise the row is not a single nested bullet list.
        "   * - Model ID\n",
        "     - Fine Tunable?\n",
        "     - Latest Version\n",
        "     - Min SDK Version\n",
        "     - Problem Type\n",
        "     - Source\n",
    ]


def _modality_table_header():
    """Return the ``list-table`` preamble written once before a per-modality table's rows."""
    return [
        "\n",
        ".. list-table:: Available Models\n",
        "   :widths: 50 20 20 20 20\n",
        "   :header-rows: 1\n",
        "   :class: datatable\n",
        "\n",
        "   * - Model ID\n",
        "     - Fine Tunable?\n",
        "     - Latest Version\n",
        "     - Min SDK Version\n",
        "     - Source\n",
    ]


def _append_lines(path, lines):
    """Append ``lines`` to the text file at ``path``, closing the file even if writing fails."""
    with open(path, "a", encoding="utf-8") as rst_file:
        rst_file.writelines(lines)


def create_jumpstart_model_table():
    """Generate the JumpStart pre-trained model tables for the documentation.

    For every model ID in the JumpStart SDK manifest only the entry with the
    highest version is used. Output is produced by appending (mode ``"a"``)
    to files, so calling this twice without resetting them duplicates tables:

    * ``doc_utils/pretrainedmodels.rst`` (relative to the current working
      directory) receives the page intro, a ``list-table`` with one row per
      model (model ID, ``training_supported`` under "Fine Tunable?", latest
      version, minimum SDK version, problem type, source link), followed by
      the lines returned by ``create_proprietary_model_table()``.
    * For models whose ``(task string, framework)`` pair is a key of
      ``MODALITY_MAP``, the mapped file receives a smaller table of those
      models; its header is written once, before the first matching row.
    """
    latest_models = _select_latest_model_versions(get_jumpstart_sdk_manifest())

    open_weight_content_entries = []
    # Per-modality rst path -> lines to append. Dicts keep insertion order, so
    # each file receives its header and rows in the same order as before.
    modality_file_entries = {}
    for model in latest_models.values():
        model_spec = get_jumpstart_sdk_spec(model["spec_key"])
        model_id = model_spec["model_id"]
        model_url = model_spec["url"]
        model_task = get_model_task(model_id)
        string_model_task = get_string_model_task(model_id)
        model_source = get_model_source(model_url)

        open_weight_content_entries.extend(
            [
                "   * - {}\n".format(model_id),
                "     - {}\n".format(model_spec["training_supported"]),
                "     - {}\n".format(model["version"]),
                "     - {}\n".format(model["min_version"]),
                "     - {}\n".format(model_task),
                "     - `{} <{}>`__ |external-link|\n".format(model_source, model_url),
            ]
        )

        modality_key = (string_model_task, TO_FRAMEWORK[model_source])
        if modality_key not in MODALITY_MAP:
            continue
        modality_path = MODALITY_MAP[modality_key]
        if modality_path not in modality_file_entries:
            # First model routed to this file: start its table with the header.
            modality_file_entries[modality_path] = _modality_table_header()
        modality_file_entries[modality_path].extend(
            [
                "   * - {}\n".format(model_id),
                "     - {}\n".format(model_spec["training_supported"]),
                "     - {}\n".format(model["version"]),
                "     - {}\n".format(model["min_version"]),
                # No |external-link| here, presumably because that
                # substitution is only defined in pretrainedmodels.rst.
                "     - `{} <{}>`__\n".format(model_source, model_url),
            ]
        )

    # Open each per-modality file once instead of once per model.
    for modality_path, lines in modality_file_entries.items():
        _append_lines(modality_path, lines)

    proprietary_content_entries = create_proprietary_model_table()
    with open("doc_utils/pretrainedmodels.rst", "a", encoding="utf-8") as rst_file:
        rst_file.writelines(_pretrained_models_intro())
        rst_file.writelines(open_weight_content_entries)
        rst_file.writelines(proprietary_content_entries)