source/train.py
import json
import os


def get_training_world():
    """
    Calculates the number of devices in a SageMaker distributed cluster.
    """
    # Get params of the SageMaker distributed cluster from predefined env variables
    num_gpus = int(os.environ["SM_NUM_GPUS"])
    num_cpus = int(os.environ["SM_NUM_CPUS"])
    hosts = json.loads(os.environ["SM_HOSTS"])
    current_host = os.environ["SM_CURRENT_HOST"]
    # Define the PyTorch training world
    world = {}
    world["number_of_processes"] = num_gpus if num_gpus > 0 else num_cpus
    world["number_of_machines"] = len(hosts)
    world["size"] = world["number_of_processes"] * world["number_of_machines"]
    world["machine_rank"] = hosts.index(current_host)
    world["master_addr"] = hosts[0]
    world["master_port"] = "55555"  # port is defined by SageMaker
world["is_master"] = current_host == sorted(hosts)[0]
    return world
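

# For context, a minimal sketch of how the returned world dict could be wired
# into torch.distributed. The helper below is illustrative (it is not part of
# the original file) and assumes one process per device, with local_rank
# supplied by whatever process launcher starts the training script.
import torch
import torch.distributed as dist


def init_distributed(world, local_rank):
    # Point every process at the master node chosen in get_training_world()
    os.environ["MASTER_ADDR"] = world["master_addr"]
    os.environ["MASTER_PORT"] = world["master_port"]
    # Global rank = machine rank * processes per machine + local process index
    global_rank = world["machine_rank"] * world["number_of_processes"] + local_rank
    dist.init_process_group(
        backend="nccl" if torch.cuda.is_available() else "gloo",
        rank=global_rank,
        world_size=world["size"],
    )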