in torch_xla/distributed/cluster.py
def validate(self):
  """Validates the current cluster configuration.

  Raises:
    RuntimeError: If the cluster is misconfigured. For example, an
      exception is raised if the VMs are in different zones, or if not
      all of the CPU workers have the same size (number of CPU cores,
      RAM size). For TPUs, an exception is similarly raised if the
      workers are in different zones or have mismatched machine or
      accelerator types.
  """
  if len(self._client_workers) == 0 or len(self._service_workers) == 0:
    raise RuntimeError(
        'Both client_workers and service_workers must be non-empty')
  if len(self._client_workers) != len(self._service_workers):
    raise RuntimeError(
        'The client_workers and service_workers must have a 1:1 mapping')
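  # Every worker, client or service, must live in the same zone.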
  zones = {worker._zone for worker in self._client_workers}
  zones.update(worker._zone for worker in self._service_workers)
  if len(zones) != 1:
    raise RuntimeError(
        'All workers must be in the same zone, got: {}'.format(zones))
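  # Optionally enforce a homogeneous machine type within each worker pool.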
  if self._check_client_machine_type:
    client_machine_types = {
        worker._machine_type for worker in self._client_workers
    }
    if len(client_machine_types) != 1:
      raise RuntimeError(
          'All client_workers must have the same machine_type, got: {}'.
          format(client_machine_types))
  if self._check_service_machine_type:
    server_machine_types = {
        worker._machine_type for worker in self._service_workers
    }
    if len(server_machine_types) != 1:
      raise RuntimeError(
          'All service_workers must have the same machine_type, got: {}'.
          format(server_machine_types))
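  # All service workers must run the same TPU runtime version.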
  runtime_versions = {
      worker._runtime_version for worker in self._service_workers
  }
  if len(runtime_versions) != 1:
    raise RuntimeError(
        'All service workers must have the same runtime_version, got: {}'.
        format(runtime_versions))
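
For illustration, here is a minimal sketch of how this check behaves. The worker objects are stood in with SimpleNamespace, and _FakeCluster is a hypothetical shim that reuses the validate function above; neither mirrors the real torch_xla worker or cluster constructors.

# Minimal sketch, assuming validate() is available at module level as above.
# SimpleNamespace stands in for the real worker classes (hypothetical).
from types import SimpleNamespace

client_workers = [
    SimpleNamespace(_zone='us-central1-a', _machine_type='n1-standard-16'),
    SimpleNamespace(_zone='us-central1-a', _machine_type='n1-standard-16'),
]
service_workers = [
    SimpleNamespace(
        _zone='us-central1-a', _machine_type='v3-8',
        _runtime_version='pytorch-nightly'),
    SimpleNamespace(
        _zone='us-central1-b', _machine_type='v3-8',
        _runtime_version='pytorch-nightly'),
]


class _FakeCluster:  # hypothetical shim, not the real cluster class

  _check_client_machine_type = True
  _check_service_machine_type = True

  def __init__(self, client_workers, service_workers):
    self._client_workers = client_workers
    self._service_workers = service_workers

  validate = validate  # reuse the module-level function above as a method


try:
  _FakeCluster(client_workers, service_workers).validate()
except RuntimeError as e:
  print(e)  # All workers must be in the same zone, got: {...}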