in easy_rec/python/core/sampler.py [0:0]
def _init_graph(self):
  """Initialize the graph engine ``self._g`` for the current runtime.

  The cluster topology is read from the ``TF_CONFIG`` environment variable
  (a JSON object with ``cluster`` and ``task`` entries):

  * local mode (``TF_CONFIG`` unset): ``self._g.init()`` is called without
    arguments.
  * ps mode (``cluster`` contains ``ps``): ``ps`` tasks are initialized with
    ``job_name='server'``; ``chief``/``master``, ``evaluator`` and ``worker``
    tasks are initialized with ``job_name='client'`` and task indices ``0``,
    ``index + 1`` and ``index + 2`` respectively.
  * worker mode (``cluster`` has no ``ps``): ``chief``/``master`` uses task
    index ``0`` and ``worker`` tasks use ``index + 1``.

  When ``self._is_on_ds`` is truthy, ``gl.set_tracker_mode(0)`` is called and
  an explicit comma-separated host list is passed to the engine; each entry
  is the host name taken from ``TF_CONFIG`` combined with port ``8880 + i``.

  Task types not listed above are left uninitialized. For an ``evaluator``
  task in ps mode, ``self._num_sample`` is replaced by
  ``self._num_eval_sample`` when the latter is a positive number.

  Raises:
    KeyError: if ``TF_CONFIG`` lacks an entry this method reads
      (``cluster``, ``task``, ``task.type``, or ``task.index`` for task
      types that need it).
    ValueError: if ``TF_CONFIG`` is not valid JSON.
  """
  if 'TF_CONFIG' not in os.environ:
    # local mode
    self._g.init()
    return

  tf_config = json.loads(os.environ['TF_CONFIG'])
  cluster_spec = tf_config['cluster']
  task = tf_config['task']
  task_type = task['type']
  is_chief = task_type in ('chief', 'master')
  workers = cluster_spec.get('worker', [])

  if 'ps' in cluster_spec:
    # ps mode: two client slots are reserved ahead of the workers
    # (index 0 for chief/master, index + 1 for evaluator).
    task_count = len(workers) + 2
    ps_hosts = cluster_spec['ps']
    if self._is_on_ds:
      gl.set_tracker_mode(0)
      # Integer port arithmetic keeps the port valid for i >= 10
      # (string concatenation would yield e.g. 88810).
      server_hosts = [
          host.split(':')[0] + ':' + str(8880 + i)
          for i, host in enumerate(ps_hosts)
      ]
      cluster = {
          'server': ','.join(server_hosts),
          'client_count': task_count
      }
    else:
      cluster = {'server_count': len(ps_hosts), 'client_count': task_count}

    if is_chief:
      self._g.init(cluster=cluster, job_name='client', task_index=0)
    elif task_type == 'worker':
      self._g.init(
          cluster=cluster, job_name='client', task_index=task['index'] + 2)
    # TODO(hongsheng.jhs): check cluster has evaluator or not?
    elif task_type == 'evaluator':
      self._g.init(
          cluster=cluster, job_name='client', task_index=task['index'] + 1)
      if self._num_eval_sample is not None and self._num_eval_sample > 0:
        self._num_sample = self._num_eval_sample
    elif task_type == 'ps':
      self._g.init(
          cluster=cluster, job_name='server', task_index=task['index'])
    return

  # worker mode: chief/master is task 0, workers follow.
  init_kwargs = {'task_count': len(workers) + 1}
  if self._is_on_ds:
    gl.set_tracker_mode(0)
    chief_addrs = cluster_spec.get('chief') or cluster_spec['master']
    chief_host = chief_addrs[0].split(':')[0] + ':8880'
    # NOTE(review): worker 0 gets the same port number as the chief; this
    # presumably relies on each task running on its own host - confirm.
    all_hosts = [chief_host] + [
        host.split(':')[0] + ':' + str(8880 + i)
        for i, host in enumerate(workers)
    ]
    # Chief and workers must receive the same comma-separated string.
    init_kwargs['hosts'] = ','.join(all_hosts)

  if is_chief:
    self._g.init(task_index=0, **init_kwargs)
  elif task_type == 'worker':
    self._g.init(task_index=task['index'] + 1, **init_kwargs)
  # TODO(hongsheng.jhs): check cluster has evaluator or not?