in fairscale/experimental/nn/distributed_pipeline/partition_handler.py [0:0]
def compute(self, pipeline_record: DistributedPipelineRecord, chunk: int) -> None:
    """Runs tasks with synchronization to tensor-pipe streams."""
    checkpoint_stop = self.checkpoint_stop

    # Disable checkpointing if in eval mode.
    if not self.module.training:
        checkpoint_stop = 0

    exc_info: Optional[ExcInfo] = None

    batch = pipeline_record.get_batch(chunk)
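
    # When running on a CUDA stream, make the computation stream wait for the
    # events recorded when this chunk's input tensors were received, so the
    # module never reads data that is still in flight.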
    if is_cuda(self.stream):
        pipeline_record.sync_stream(chunk, as_cuda(self.stream))

    # Decide whether to checkpoint this chunk.
    checkpoint = chunk < checkpoint_stop
    if checkpoint:

        def function(input: TensorOrTensors, chunk_id: int = chunk) -> TensorOrTensors:
            with record_function("chunk%d-rank%d" % (chunk_id, pipeline_record.rank)):
                result = self.module(*input)
                if self.num_outputs is None:
                    result = (result,)
                return tuple(result)
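
        # Checkpointing wraps the forward pass in activation checkpointing:
        # chk.checkpoint runs the forward without keeping intermediate
        # activations, and chk.recompute (installed below as the task's
        # finalize step) hooks the recomputation into the autograd graph so
        # the activations are rebuilt during backward.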
        chk = Checkpointing(function, batch)
        task = Task(self.stream, compute=chk.checkpoint, finalize=chk.recompute)
        del function, chk

    else:

        def compute(
            # Bound as defaults so the closure captures these values at
            # definition time.
            batch: Batch = batch,
            chunk_id: int = chunk,
            rank: int = pipeline_record.rank if pipeline_record is not None else -1,
        ) -> Batch:
            with record_function("chunk%d-rank%d" % (chunk_id, rank)):
                result = self.module(*batch.tensors)
                if self.num_outputs is None:
                    result = (result,)
                return Batch(result, chunk_id)
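
        # A Task pairs the compute closure with an optional finalize callback
        # and is executed by this partition's worker on self.stream; with no
        # checkpointing there is nothing to finalize.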
        task = Task(self.stream, compute=compute, finalize=None)
        del compute
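
    # Hand the task to the partition's worker and block until it reports back
    # with either the computed batch or an exception.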
    self.in_queue.put(task)

    ok, payload = self.out_queue.get()

    # Hold the first exception.
    if exc_info is not None:
        pass
    elif not ok:
        exc_info = cast(ExcInfo, payload)
    else:
        task, batch = cast(Tuple[Task, Batch], payload)
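
        # Run the finalize step (recompute bookkeeping for checkpointed
        # chunks, a no-op otherwise) on this partition's device, then store
        # the chunk's output for the next stage.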
        with use_device(self.device):
            task.finalize(batch)

        pipeline_record.batches[chunk] = batch

    if exc_info is not None:
        raise exc_info[0].with_traceback(exc_info[1], exc_info[2])
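
# --- Illustrative sketch only; not part of partition_handler.py --------------
# The in_queue/out_queue hand-off above implies a worker loop roughly like the
# one below (hypothetical name `_sketch_worker_loop`): pop a Task, run its
# compute closure, and report back (True, (task, batch)) on success or
# (False, sys.exc_info()) on failure -- the (ok, payload) pair that compute()
# unpacks.

import sys
from queue import Queue


def _sketch_worker_loop(in_queue: Queue, out_queue: Queue) -> None:
    while True:
        task = in_queue.get()
        if task is None:  # sentinel used by this sketch to stop the loop
            break
        try:
            batch = task.compute()
        except Exception:
            out_queue.put((False, sys.exc_info()))
            continue
        out_queue.put((True, (task, batch)))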