in src/sagemaker_tensorflow_container/training.py [0:0]
def _build_tf_config_for_mwms(hosts, current_host):
"""Builds a dictionary containing cluster information based on number of workers
for Multi Worker Mirrored distribution strategy.
Args:
hosts (list[str]): List of host names in the cluster
current_host (str): Current host name
Returns:
dict[str: dict]: A dictionary describing the cluster setup for distributed training.
For more information regarding TF_CONFIG:
https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
"""
workers = hosts
def host_addresses(hosts, port=8890):
return ["{}:{}".format(host, port) for host in hosts]
tf_config = {"cluster": {}, "environment": "cloud"}
tf_config["cluster"]["worker"] = host_addresses(workers)
tf_config["task"] = {"index": workers.index(current_host), "type": "worker"}
return tf_config