def _build_tf_config_for_ps()

in src/sagemaker_tensorflow_container/training.py


def _build_tf_config_for_ps(hosts, current_host, ps_task=False):
    """Builds a dictionary containing cluster information based on number of hosts and number of
    parameter servers.

    Args:
        hosts (list[str]): List of host names in the cluster
        current_host (str): Current host name
        ps_task (bool): Set to True if this config is built for a parameter server process
            (default: False)

    Returns:
        dict[str, dict]: A dictionary describing the cluster setup for distributed training.
            For more information regarding TF_CONFIG:
            https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
    """
    # Assign the first host as the master. The rest of the hosts, if any, will be worker hosts.
    # When there is more than one host, every host also runs a parameter server task.
    masters = hosts[:1]
    workers = hosts[1:]
    ps = hosts if len(hosts) > 1 else None

    def host_addresses(hosts, port=2222):
        return ["{}:{}".format(host, port) for host in hosts]

    tf_config = {"cluster": {"master": host_addresses(masters)}, "environment": "cloud"}

    if ps:
        tf_config["cluster"]["ps"] = host_addresses(ps, port="2223")

    if workers:
        tf_config["cluster"]["worker"] = host_addresses(workers)

    if ps_task:
        if ps is None:
            raise ValueError(
                "Cannot have a ps task if there are no parameter servers in the cluster"
            )
        task_type = "ps"
        task_index = ps.index(current_host)
    elif _is_host_master(hosts, current_host):
        task_type = "master"
        task_index = 0
    else:
        task_type = "worker"
        task_index = workers.index(current_host)

    tf_config["task"] = {"index": task_index, "type": task_type}
    return tf_config
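
A minimal usage sketch (the "algo-N" host names are hypothetical, mirroring SageMaker's naming convention); it shows the TF_CONFIG produced for the second host of a two-host cluster, which acts as a worker:

import json

hosts = ["algo-1", "algo-2"]
tf_config = _build_tf_config_for_ps(hosts, current_host="algo-2")
print(json.dumps(tf_config, indent=2))
# {
#   "cluster": {
#     "master": ["algo-1:2222"],
#     "ps": ["algo-1:2223", "algo-2:2223"],
#     "worker": ["algo-2:2222"]
#   },
#   "environment": "cloud",
#   "task": {"index": 0, "type": "worker"}
# }

Calling the same thing with ps_task=True would instead yield a task of {"index": 1, "type": "ps"}, since each host also runs a parameter server listening on port 2223.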