in Sources/CX10/functional_while.cc [182:260]
// Lowers this functional-while node to a single xla::While whose loop state is
// one tuple. With num_carried == placeholders_.size(), the tuple layout is:
//   [0, num_carried)      loop-carried values, seeded from operand(0 .. num_carried-1)
//   num_carried           iteration counter, seeded with zeroLike(operand(num_carried))
//   num_carried + 1       iteration limit, operand(num_carried); passed through unchanged
//   num_carried + 2, ...  operand(num_carried + 1, ...); passed through unchanged
// i.e. an operand k with k > num_carried lives at tuple index k + 1.
// The loop continues while counter < limit, adding oneLike(counter) per trip.
// The node's outputs are the final loop-carried values (tuple slots [0, num_carried)).
// NOTE(review): each xla:: builder call / GetOutputOp below may append HLO
// instructions, so keep the call order stable when editing this function.
XlaOpVector Lower(LoweringContext* loctx) const {
  const size_t num_carried = placeholders_.size();
  const size_t counter_slot = num_carried;
  const size_t limit_slot = num_carried + 1;

  auto body_builder = loctx->builder()->CreateSubBuilder("loop_body");

  // Seed tuple for the loop, built in the enclosing computation.
  std::vector<xla::XlaOp> seed;
  seed.reserve(operands().size() + 1);
  for (size_t k = 0; k < num_carried; ++k) {
    seed.push_back(loctx->GetOutputOp(operand(k)));
  }
  // operand(num_carried) is the limit; a zero of the same kind is inserted in
  // front of it to act as the counter.
  xla::XlaOp limit_value = loctx->GetOutputOp(operand(num_carried));
  seed.push_back(zeroLike(limit_value));
  seed.push_back(limit_value);
  for (size_t k = num_carried + 1; k < operands().size(); ++k) {
    seed.push_back(loctx->GetOutputOp(operand(k)));
  }
  xla::XlaOp initial = xla::Tuple(loctx->builder(), seed);

  // Loop body: unpack the state tuple, lower results_ against it, repack.
  xla::XlaOp body_result;
  {
    auto* body = body_builder.get();

    // Nodes whose values come from the state tuple are pre-marked as emitted;
    // their ops are bound explicitly via AssignOutputOp below (presumably so
    // body_loctx does not try to lower them itself).
    swift_xla::ir::Util::EmissionMap preset;
    for (const auto& ph : placeholders_) {
      preset[ph.node.get()] = swift_xla::ir::Util::kEmitted;
    }
    for (size_t k = num_carried + 1; k < operands().size(); ++k) {
      preset[operand(k).node] = swift_xla::ir::Util::kEmitted;
    }
    preset[index_placeholder_.node.get()] = swift_xla::ir::Util::kEmitted;
    swift_xla::ir::LoweringContext body_loctx(body, loctx->device(),
                                              std::move(preset));

    auto state = xla::Parameter(
        body, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
    auto counter = xla::GetTupleElement(state, counter_slot);
    auto limit = xla::GetTupleElement(state, limit_slot);
    for (size_t k = 0; k < num_carried; ++k) {
      body_loctx.AssignOutputOp(placeholders_[k],
                                xla::GetTupleElement(state, k));
    }
    for (size_t k = num_carried + 1; k < operands().size(); ++k) {
      // Invariant operand k is stored one slot later (after the counter).
      body_loctx.AssignOutputOp(operand(k), xla::GetTupleElement(state, k + 1));
    }
    // The body sees the current counter through index_placeholder_.
    body_loctx.AssignOutputOp(index_placeholder_, counter);

    // Next state, in the same layout as the seed tuple.
    // NOTE(review): assumes results_[i] is the next value of placeholders_[i]
    // (same count, order and shape) -- confirm against the node's constructor.
    std::vector<xla::XlaOp> next_state;
    for (const auto& res : results_) {
      next_state.push_back(body_loctx.GetOutputOp(res));
    }
    next_state.push_back(counter + oneLike(counter));
    next_state.push_back(limit);
    for (size_t k = num_carried + 1; k < operands().size(); ++k) {
      next_state.push_back(body_loctx.GetOutputOp(operand(k)));
    }
    body_result = xla::Tuple(body, next_state);
  }

  // Loop condition: keep iterating while counter < limit.
  auto cond_builder = loctx->builder()->CreateSubBuilder("cond_body");
  xla::XlaOp cond_result;
  {
    auto* cond = cond_builder.get();
    auto state = xla::Parameter(
        cond, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
    auto counter = xla::GetTupleElement(state, counter_slot);
    auto limit = xla::GetTupleElement(state, limit_slot);
    cond_result = xla::Lt(counter, limit);
  }

  // ConsumeValueOrDie aborts if either sub-computation fails to build.
  auto loop = xla::While(
      cond_builder->Build(cond_result).ConsumeValueOrDie(),
      body_builder->Build(body_result).ConsumeValueOrDie(), initial);

  // Only the loop-carried slots are outputs of this node.
  std::vector<xla::XlaOp> outputs;
  outputs.reserve(num_carried);
  for (size_t k = 0; k < num_carried; ++k) {
    outputs.push_back(xla::GetTupleElement(loop, k));
  }
  return ReturnOps(outputs, loctx);
}