def _build_tf_config_for_mwms()

in src/sagemaker_tensorflow_container/training.py [0:0]


def _build_tf_config_for_mwms(hosts, current_host):
    """Builds a dictionary containing cluster information based on number of workers
    for Multi Worker Mirrored distribution strategy.

    Args:
        hosts (list[str]): List of host names in the cluster
        current_host (str): Current host name

    Returns:
        dict[str: dict]: A dictionary describing the cluster setup for distributed training.
            For more information regarding TF_CONFIG:
            https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
    """
    workers = hosts

    def host_addresses(hosts, port=8890):
        return ["{}:{}".format(host, port) for host in hosts]

    tf_config = {"cluster": {}, "environment": "cloud"}
    tf_config["cluster"]["worker"] = host_addresses(workers)
    tf_config["task"] = {"index": workers.index(current_host), "type": "worker"}

    return tf_config