in eland/cli/eland_import_hub_model.py [0:0]
def main():
    """CLI entry point: import a Hugging Face Hub transformer model into Elasticsearch.

    Workflow: configure logging, lazily import the PyTorch extras (failing with
    install instructions if absent), parse CLI args, connect to the cluster,
    trace/save the model to a temp dir, then upload config, definition and
    vocabulary, optionally starting the deployment.

    Exits the process with status 1 on any recoverable user error
    (missing extras, bad task type, undeterminable input size, model
    already exists without --clear-previous).
    """
    # Configure logging
    logging.basicConfig(format="%(asctime)s %(levelname)s : %(message)s")
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)

    # Import the PyTorch extras lazily so the CLI can print a helpful
    # install hint instead of a raw traceback when they are missing.
    try:
        from eland.ml.pytorch import PyTorchModel
        from eland.ml.pytorch.transformers import (
            SUPPORTED_TASK_TYPES,
            TaskTypeError,
            TransformerModel,
            UnknownModelInputSizeError,
        )
    except ModuleNotFoundError as e:
        logger.error(
            textwrap.dedent(
                f"""\
                \033[31mFailed to run because module '{e.name}' is not available.\033[0m
                This script requires PyTorch extras to run. You can install these by running:
                \033[1m{sys.executable} -m pip install 'eland[pytorch]'
                \033[0m"""
            )
        )
        sys.exit(1)
    # Touch the import so linters don't flag it unused; the set is re-exported
    # for argument validation elsewhere.
    assert SUPPORTED_TASK_TYPES

    # Parse arguments
    args = parse_args()

    # Connect to ES
    logger.info("Establishing connection to Elasticsearch")
    es = get_es_client(args, logger)
    cluster_version = check_cluster_version(es, logger)

    # Trace and save model, then upload it from temp file
    with tempfile.TemporaryDirectory() as tmp_dir:
        logger.info(
            f"Loading HuggingFace transformer tokenizer and model '{args.hub_model_id}'"
        )

        try:
            tm = TransformerModel(
                model_id=args.hub_model_id,
                access_token=args.hub_access_token,
                task_type=args.task_type,
                es_version=cluster_version,
                quantize=args.quantize,
                ingest_prefix=args.ingest_prefix,
                search_prefix=args.search_prefix,
                max_model_input_size=args.max_model_input_length,
            )
            model_path, config, vocab_path = tm.save(tmp_dir)
        except TaskTypeError as err:
            logger.error(
                f"Failed to get model for task type, please provide valid task type via '--task-type' parameter. Caused by {err}"
            )
            sys.exit(1)
        except UnknownModelInputSizeError as err:
            logger.error(
                f"""Could not automatically determine the model's max input size from the model configuration.
                Please provide the max input size via the --max-model-input-length parameter. Caused by {err}"""
            )
            sys.exit(1)

        ptm = PyTorchModel(
            es, args.es_model_id if args.es_model_id else tm.elasticsearch_model_id()
        )
        # Probe for an existing model; ignore_status=404 turns "not found"
        # into a normal response so we can branch on the HTTP status.
        model_exists = (
            es.options(ignore_status=404)
            .ml.get_trained_models(model_id=ptm.model_id)
            .meta.status
            == 200
        )

        if model_exists:
            if args.clear_previous:
                logger.info(f"Stopping deployment for model with id '{ptm.model_id}'")
                ptm.stop()

                logger.info(f"Deleting model with id '{ptm.model_id}'")
                ptm.delete()
            else:
                logger.error(f"Trained model with id '{ptm.model_id}' already exists")
                logger.info(
                    "Run the script with the '--clear-previous' flag if you want to overwrite the existing model."
                )
                sys.exit(1)

        logger.info(f"Creating model with id '{ptm.model_id}'")
        ptm.put_config(config=config)

        logger.info("Uploading model definition")
        ptm.put_model(model_path)

        logger.info("Uploading model vocabulary")
        ptm.put_vocab(vocab_path)

        # Start the deployed model
        if args.start:
            logger.info("Starting model deployment")
            ptm.start()

        logger.info(f"Model successfully imported with id '{ptm.model_id}'")