in pai/pipeline/core.py [0:0]
def _infer_pipeline_graph(cls, steps, inputs, outputs):
    """Infer the DAG of a pipeline from its steps, inputs and outputs.

    Starting from the given steps and from the parent steps of the given
    outputs, the function walks the ``depends`` links of every step to collect
    all steps that belong to the pipeline, then derives the pipeline inputs
    from the inputs of those steps.

    Args:
        steps: Iterable of steps used in the pipeline graph (may be None).
        inputs: Inputs used in the pipeline. If empty or None, the inputs are
            inferred from the collected steps.
        outputs: Outputs definition of the pipeline, either a sequence or a
            dict (only the dict values are used).

    Returns:
        Tuple: ``(sorted_steps, inputs, outputs)`` where ``sorted_steps`` is
        the result of ``cls._topo_sort`` over every collected step, ``inputs``
        is the given or inferred input list and ``outputs`` is the normalized
        outputs.

    Raises:
        ValueError: If ``inputs`` is given but its length differs from the
            number of inferred inputs, or it contains an input that was not
            inferred from the pipeline steps.
    """
    inputs = inputs or []
    outputs = outputs or []
    if isinstance(outputs, dict):
        outputs = list(outputs.values())

    # Materialize once so that any iterable (list, tuple, ...) is accepted;
    # concatenating a non-list with a list would raise TypeError.
    steps = list(steps) if steps else []

    # Seed the traversal with the explicit steps plus the steps that produce
    # the pipeline outputs.
    visited_steps = set(steps)
    visited_steps.update(output.parent for output in outputs if output.parent)

    # Breadth-first walk over the `depends` links to collect every step the
    # seeds (transitively) rely on.
    cur_steps = visited_steps.copy()
    while cur_steps:
        next_steps = set()
        for step in cur_steps:
            # Reset the counter of repeated output artifacts; the counts are
            # set again by `_set_step_artifact_count` below.
            for output in step.outputs.artifacts:
                if output.repeated:
                    output.reset_count()
            for depend in step.depends:
                if depend not in visited_steps:
                    next_steps.add(depend)
                    visited_steps.add(depend)
        cur_steps = next_steps

    # Infer the pipeline inputs from the inputs of every collected step.
    infer_inputs = set()
    for step in visited_steps:
        for ipt in step.inputs:
            infer_inputs |= cls._infer_pipeline_inputs(ipt)
    cls._set_step_artifact_count(visited_steps)

    if inputs:
        # Explicit inputs must match the inferred ones exactly.
        if len(infer_inputs) != len(inputs):
            raise ValueError(
                "Please provide complete pipeline inputs list: expected=%s, given=%s"
                % (len(infer_inputs), len(inputs))
            )
        unexpected = [ipt.name for ipt in inputs if ipt not in infer_inputs]
        if unexpected:
            raise ValueError(
                "Do not provide inputs which is not used in pipeline: %s"
                % unexpected
            )
    else:
        # No explicit inputs: use the inferred ones, parameters first.
        inputs = sorted(
            infer_inputs,
            key=lambda x: 0 if x.variable_category == "parameters" else 1,
        )

    sorted_steps = cls._topo_sort(visited_steps)
    # Only the explicitly provided steps are passed to the step check.
    cls._check_steps(steps)

    return sorted_steps, inputs, outputs