src/sagemaker/cli/compatibility/v2/modifiers/airflow.py (24 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class to handle argument changes for Airflow functions."""
from __future__ import absolute_import
import ast
from sagemaker.cli.compatibility.v2.modifiers import matching, renamed_params
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
NAMESPACES = ("sagemaker.workflow.airflow", "workflow.airflow", "airflow")
FUNCTIONS = {name: NAMESPACES for name in FUNCTION_NAMES}
class ModelConfigArgModifier(Modifier):
"""A class to handle argument changes for Airflow model config functions."""
def node_should_be_modified(self, node):
"""Function to check Airflow model config and if it contains positional arguments.
Checks if the ``ast.Call`` node creates an Airflow model config and
contains positional arguments. This looks for the following formats:
- ``model_config``
- ``airflow.model_config``
- ``workflow.airflow.model_config``
- ``sagemaker.workflow.airflow.model_config``
where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.
Args:
node (ast.Call): a node that represents a function call. For more,
see https://docs.python.org/3/library/ast.html#abstract-grammar.
Returns:
bool: If the ``ast.Call`` is either a ``model_config`` call or
a ``model_config_from_estimator`` call and has positional arguments.
"""
return matching.matches_any(node, FUNCTIONS) and len(node.args) > 0
def modify_node(self, node):
"""Modifies the ``ast.Call`` node's arguments.
The first argument, the instance type, is turned into a keyword arg,
leaving the second argument, the model, to be the first argument.
Args:
node (ast.Call): a node that represents either a ``model_config`` call or
a ``model_config_from_estimator`` call.
Returns:
ast.AST: the original node, which has been potentially modified.
"""
instance_type = node.args.pop(0)
node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))
return node
class ModelConfigImageURIRenamer(renamed_params.ParamRenamer):
"""A class to rename the ``image`` attribute to ``image_uri`` in Airflow model config functions.
This looks for the following formats:
- ``model_config``
- ``airflow.model_config``
- ``workflow.airflow.model_config``
- ``sagemaker.workflow.airflow.model_config``
where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.
"""
@property
def calls_to_modify(self):
"""A dictionary mapping Airflow model config functions to their respective namespaces."""
return FUNCTIONS
@property
def old_param_name(self):
"""The previous name for the image URI argument."""
return "image"
@property
def new_param_name(self):
"""The new name for the image URI argument."""
return "image_uri"