def _get_instance_types()

in src/sagemaker_algorithm_toolkit/metadata.py [0:0]


def _get_instance_types(region_name="us-east-1", location="US East (N. Virginia)"):
    """Return a mapping of SageMaker ML instance type to its GPU count.

    Queries the AWS pricing service for every product of service code
    "AmazonSageMaker" whose product family is "ML Instance" in ``location``,
    then reads the ``instanceType`` and ``gpu`` attributes of each product.

    Args:
        region_name: Region used to create the boto3 "pricing" client.
        location: Value matched against the product "location" field.

    Returns:
        dict mapping each instance type string to ``int`` of its "gpu"
        attribute. If an instance type appears in several products, the
        last one seen wins.

    Raises:
        KeyError: If a product lacks the "instanceType" or "gpu" attribute.
        ValueError: If a "gpu" attribute cannot be converted to ``int``.
    """
    pricing = boto3.client("pricing", region_name=region_name)

    request = {
        "ServiceCode": "AmazonSageMaker",
        "Filters": [
            {"Type": "TERM_MATCH", "Field": "productFamily", "Value": "ML Instance"},
            {"Type": "TERM_MATCH", "Field": "location", "Value": location},
        ],
    }

    # Collect the price list of EVERY page, including the last one (which has
    # no NextToken). Checking the token before collecting would drop the final
    # page, and would return nothing at all for a single-page response.
    price_list = []
    while True:
        page = pricing.get_products(**request)
        price_list.extend(page.get("PriceList", []))

        next_token = page.get("NextToken")
        if not next_token:
            break
        request["NextToken"] = next_token

    instance_types = {}
    for entry in price_list:
        # Each price list entry is a JSON document serialized as a string.
        attributes = json.loads(entry)["product"]["attributes"]
        instance_types[attributes["instanceType"]] = int(attributes["gpu"])
    return instance_types