in mephisto/data_model/task_run.py [0:0]
def get_valid_units_for_worker(self, worker: "Worker") -> List["Unit"]:
    """
    Return the units in this task run that the given worker may take on.

    An empty list is returned when the worker already has at least
    `allowed_concurrent` assigned units in this run, or when their
    assigned units in this run plus their completed units for the task
    reach `maximum_units_per_worker`. A value of 0 turns the
    corresponding limit off.

    Otherwise, every assignment in which the worker already holds a unit
    is dropped, and the remaining units are kept only if they are
    launched, have a non-negative unit index, and are accepted for this
    worker by the blueprint's shared state.
    """
    task_config = self.get_task_config()
    limit_concurrency = task_config.allowed_concurrent != 0

    if limit_concurrency or task_config.maximum_units_per_worker:
        # Units this worker currently has assigned within this run
        currently_active = len(
            self.db.find_units(
                task_run_id=self.db_id,
                worker_id=worker.db_id,
                status=AssignmentState.ASSIGNED,
            )
        )
        if limit_concurrency and currently_active >= task_config.allowed_concurrent:
            logger.debug(
                f"{worker} at maximum concurrent units {currently_active}"
            )
            return []  # currently at the maximum number of concurrent units

        if task_config.maximum_units_per_worker != 0:
            finished_states = AssignmentState.completed()
            # Completed work is counted across the whole task, not just this run
            worker_units = self.db.find_units(
                task_id=self.task_id,
                worker_id=worker.db_id,
            )
            currently_completed = sum(
                1 for done in worker_units if done.db_status in finished_states
            )
            if (
                currently_active + currently_completed
                >= task_config.maximum_units_per_worker
            ):
                logger.debug(
                    f"{worker} at maximum units {currently_active}, {currently_completed}"
                )
                return []  # Currently at the maximum number of units for this task

    # Bucket every unit of the run under the assignment it belongs to
    by_assignment: Dict[str, List["Unit"]] = {}
    for member in self.get_units():
        by_assignment.setdefault(member.assignment_id, []).append(member)

    # Cannot pair with self: skip assignments this worker is already part of
    candidates: List["Unit"] = []
    for group in by_assignment.values():
        if any(member.worker_id == worker.db_id for member in group):
            continue
        candidates.extend(group)

    # Valid units must be launched and must not be special units (negative indices)
    valid_units: List["Unit"] = []
    for candidate in candidates:
        if (
            candidate.get_status() == AssignmentState.LAUNCHED
            and candidate.unit_index >= 0
        ):
            valid_units.append(candidate)
    logger.debug(f"Found {len(valid_units)} available units")

    # Should load cached blueprint for SharedTaskState
    blueprint = self.get_blueprint()
    ret_units: List["Unit"] = []
    for candidate in valid_units:
        if blueprint.shared_state.worker_can_do_unit(worker, candidate):
            ret_units.append(candidate)

    logger.debug(f"This worker is qualified for {len(ret_units)} unit.")
    logger.debug(f"Found {ret_units[:3]} for {worker}.")
    return ret_units