launcher/accelerator_devices.py (82 lines of code) (raw):
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
# Number of accelerator devices (GPUs or Trainium chips) on each supported
# EC2 instance type. Instance types absent from this map are reported as
# unknown (None) by get_num_accelerator_devices. Keep the key set in sync
# with coresPerAcceleratorDevice.
acceleratorDevices = dict(
    [
        # P-family GPU instances.
        ("p4d.24xlarge", 8),
        ("p4de.24xlarge", 8),
        ("p5.48xlarge", 8),
        ("p5e.48xlarge", 8),
        ("p5en.48xlarge", 8),
        # Trainium instances: the count is Trainium chips.
        ("trn1.2xlarge", 1),
        ("trn1.32xlarge", 16),
        ("trn1n.32xlarge", 16),
        ("trn2.48xlarge", 16),
        # G5 GPU instances.
        ("g5.xlarge", 1),
        ("g5.2xlarge", 1),
        ("g5.4xlarge", 1),
        ("g5.8xlarge", 1),
        ("g5.12xlarge", 4),
        ("g5.16xlarge", 1),
        ("g5.24xlarge", 4),
        ("g5.48xlarge", 8),
        # G6 GPU instances.
        ("g6.xlarge", 1),
        ("g6.2xlarge", 1),
        ("g6.4xlarge", 1),
        ("g6.8xlarge", 1),
        ("g6.16xlarge", 1),
        ("g6.12xlarge", 4),
        ("g6.24xlarge", 4),
        ("g6.48xlarge", 8),
        # Gr6 GPU instances.
        ("gr6.4xlarge", 1),
        ("gr6.8xlarge", 1),
        # G6e GPU instances.
        ("g6e.xlarge", 1),
        ("g6e.2xlarge", 1),
        ("g6e.4xlarge", 1),
        ("g6e.8xlarge", 1),
        ("g6e.16xlarge", 1),
        ("g6e.12xlarge", 4),
        ("g6e.24xlarge", 4),
        ("g6e.48xlarge", 8),
    ]
)
# Number of cores per accelerator device on each supported EC2 instance type:
# 2 for the Trainium (trn*) instance types, 1 for the NVIDIA GPU instance
# types. Keys, and their order, mirror acceleratorDevices.
coresPerAcceleratorDevice = {
    # P-family GPU instances: one core per GPU.
    **dict.fromkeys(
        (
            "p4d.24xlarge",
            "p4de.24xlarge",
            "p5.48xlarge",
            "p5e.48xlarge",
            "p5en.48xlarge",
        ),
        1,
    ),
    # Trainium instances: two cores per chip.
    # NOTE(review): trn2.48xlarge is given the same value as trn1; confirm this
    # matches the Trainium2 core configuration the launcher expects.
    **dict.fromkeys(
        (
            "trn1.2xlarge",
            "trn1.32xlarge",
            "trn1n.32xlarge",
            "trn2.48xlarge",
        ),
        2,
    ),
    # G5 / G6 / Gr6 / G6e GPU instances: one core per GPU.
    **dict.fromkeys(
        (
            "g5.xlarge",
            "g5.2xlarge",
            "g5.4xlarge",
            "g5.8xlarge",
            "g5.12xlarge",
            "g5.16xlarge",
            "g5.24xlarge",
            "g5.48xlarge",
            "g6.xlarge",
            "g6.2xlarge",
            "g6.4xlarge",
            "g6.8xlarge",
            "g6.16xlarge",
            "g6.12xlarge",
            "g6.24xlarge",
            "g6.48xlarge",
            "gr6.4xlarge",
            "gr6.8xlarge",
            "g6e.xlarge",
            "g6e.2xlarge",
            "g6e.4xlarge",
            "g6e.8xlarge",
            "g6e.16xlarge",
            "g6e.12xlarge",
            "g6e.24xlarge",
            "g6e.48xlarge",
        ),
        1,
    ),
}
def get_num_accelerator_devices(instance_type: str):
    """
    Get the number of accelerator devices on an instance type.

    An accelerator device is either a GPU or a Trainium chip.

    :param instance_type: AWS EC2 instance type, e.g. "p4d.24xlarge"
    :return: number of accelerator devices for the instance type, or None if the
        instance type is not in the accelerator devices map
    """
    # One dict lookup instead of a membership test followed by indexing;
    # dict.get yields None for unknown instance types, as before.
    return acceleratorDevices.get(instance_type)
def get_num_cores_per_accelerator(instance_type: str):
    """
    Get the number of cores per accelerator device on an instance type.

    Currently, Trainium has 2 cores per device while NVIDIA GPUs have 1 core
    per device.

    :param instance_type: AWS EC2 instance type, e.g. "trn1.32xlarge"
    :return: number of cores per accelerator device for the instance type, or
        None if the instance type is not in the cores-per-device map
    """
    # One dict lookup instead of a membership test followed by indexing;
    # dict.get yields None for unknown instance types, as before.
    return coresPerAcceleratorDevice.get(instance_type)