in smdebug/tensorflow/base_hook.py [0:0]
def _get_writers(self, tensor_name, tensor_ref) -> List[FileWriter]:
"""
For tensors generated during distributed TF jobs, we map each tensor to a
writer using its device attribute.
If the device attribute is CPU, we map the tensor to all the writers.
For all other strategies and single-worker jobs, we return a list with a single writer.
If save_all_workers is False, we return a writer only if the
chief worker's device is attempting to write.
:param tensor_name:
:param tensor_ref:
:return: List[FileWriter]
"""
if self.distribution_strategy in [
TFDistributionStrategy.PARAMETER_SERVER,
TFDistributionStrategy.HOROVOD,
TFDistributionStrategy.SMDATAPARALLEL,
]:
# when save_all_workers is False, only the chief worker writes;
# other workers fall through to the empty-list return below
if self.save_all_workers or self.worker == self.chief_worker:
return self._get_main_writer()
elif self.distribution_strategy == TFDistributionStrategy.MIRRORED:
if self.device_map:
if tensor_ref is not None and tensor_ref.tf_obj is not None:
worker = tensor_ref.tf_obj.device
else:
# metrics in Keras have no underlying tf object; treat them as CPU
worker = "CPU"
# an empty device string or a CPU device maps the tensor to all writers
if not worker or "CPU" in worker:
if self.save_all_workers:
return list(self.writer_map.values())
else:
return [self.writer_map[self.device_map[self.chief_worker]]]
elif self.save_all_workers or worker == self.chief_worker:
return [self.writer_map[self.device_map[worker]]]
else:
# device_map is empty when training on CPU, where all device strings are CPU
return self._get_main_writer()
elif self.distribution_strategy == TFDistributionStrategy.NONE:
return self._get_main_writer()
else:
raise NotImplementedError(f"Unsupported distribution strategy: {self.distribution_strategy}")
# reached when this worker should not write, e.g. a non-chief worker
# with save_all_workers=False; return no writers
return []
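
Below is a minimal, self-contained sketch (not part of base_hook.py) of the MIRRORED device-to-writer lookup implemented above. The dicts stand in for the hook's device_map and writer_map attributes; all names, device strings, and writer values here are hypothetical placeholders.

# Hypothetical stand-ins for the hook's device_map / writer_map / chief_worker.
device_map = {
    "/replica:0/task:0/device:GPU:0": "worker_0",
    "/replica:0/task:0/device:GPU:1": "worker_1",
}
writer_map = {
    "worker_0": "<FileWriter for GPU:0>",
    "worker_1": "<FileWriter for GPU:1>",
}
chief_worker = "/replica:0/task:0/device:GPU:0"

def pick_writers(device, save_all_workers):
    # Empty or CPU device strings (e.g. Keras metrics) map to every writer,
    # or only to the chief's writer when save_all_workers is False.
    if not device or "CPU" in device:
        if save_all_workers:
            return list(writer_map.values())
        return [writer_map[device_map[chief_worker]]]
    # GPU tensors map to their own device's writer; non-chief devices are
    # skipped unless save_all_workers is True.
    if save_all_workers or device == chief_worker:
        return [writer_map[device_map[device]]]
    return []

print(pick_writers("/replica:0/task:0/device:GPU:1", save_all_workers=False))  # []
print(pick_writers("CPU", save_all_workers=True))  # both writers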