XlaOpVector Lower()

in Sources/CX10/functional_while.cc [182:260]


  // Lowers this node to an `xla::While` whose carried state is a single tuple.
  //
  // With N = placeholders_.size(), the tuple is laid out as:
  //   [0, N)    loop state, seeded from operand(0) .. operand(N - 1)
  //   N         counter, seeded with zeroLike(operand(N))
  //   N + 1     operand(N); the condition keeps looping while counter < this
  //   N + 2 ..  the remaining operands, forwarded unchanged by the body
  //
  // Returns the first N elements of the tuple produced by the While op.
  // NOTE(review): the body tuple only matches this layout if results_ holds
  // exactly N entries - confirm against the constructor.
  XlaOpVector Lower(LoweringContext* loctx) const {
    const size_t state_count = placeholders_.size();
    const size_t bound_slot = state_count + 1;
    const size_t operand_count = operands().size();

    auto body_builder = loctx->builder()->CreateSubBuilder("loop_body");

    // Initial value of the carried tuple, built in the enclosing computation.
    xla::XlaOp initial;
    {
      std::vector<xla::XlaOp> seed;
      seed.reserve(operand_count + 1);
      for (size_t slot = 0; slot < state_count; ++slot) {
        seed.push_back(loctx->GetOutputOp(operand(slot)));
      }
      // The counter is inserted ahead of the bound, so every later operand
      // ends up one slot past its operand index.
      auto bound = loctx->GetOutputOp(operand(state_count));
      auto counter_start = zeroLike(bound);
      seed.push_back(counter_start);
      seed.push_back(bound);
      for (size_t slot = bound_slot; slot < operand_count; ++slot) {
        seed.push_back(loctx->GetOutputOp(operand(slot)));
      }

      initial = xla::Tuple(loctx->builder(), seed);
    }

    // Loop body: lower results_ with placeholders bound to tuple elements.
    xla::XlaOp body_result;
    {
      auto* body = body_builder.get();

      // Pre-mark every node that is bound to a tuple element below as emitted
      // before handing the map to the body's lowering context.
      swift_xla::ir::Util::EmissionMap emitted;
      for (size_t slot = 0; slot < placeholders_.size(); ++slot) {
        emitted[placeholders_[slot].node.get()] = swift_xla::ir::Util::kEmitted;
      }
      for (size_t slot = bound_slot; slot < operand_count; ++slot) {
        emitted[operand(slot).node] = swift_xla::ir::Util::kEmitted;
      }
      emitted[index_placeholder_.node.get()] = swift_xla::ir::Util::kEmitted;
      swift_xla::ir::LoweringContext body_loctx(body, loctx->device(),
                                                std::move(emitted));

      auto carried = xla::Parameter(
          body, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto counter = xla::GetTupleElement(carried, state_count);
      auto bound = xla::GetTupleElement(carried, bound_slot);
      for (size_t slot = 0; slot < placeholders_.size(); ++slot) {
        body_loctx.AssignOutputOp(placeholders_[slot],
                                  xla::GetTupleElement(carried, slot));
      }
      for (size_t slot = bound_slot; slot < operand_count; ++slot) {
        body_loctx.AssignOutputOp(operand(slot),
                                  xla::GetTupleElement(carried, slot + 1));
      }
      body_loctx.AssignOutputOp(index_placeholder_, counter);

      // Next-iteration tuple, in the same slot order as `initial`.
      std::vector<xla::XlaOp> next_state;
      for (const auto& result : results_) {
        next_state.push_back(body_loctx.GetOutputOp(result));
      }
      auto step = oneLike(counter);
      next_state.push_back(counter + step);
      next_state.push_back(bound);
      for (size_t slot = bound_slot; slot < operand_count; ++slot) {
        next_state.push_back(body_loctx.GetOutputOp(operand(slot)));
      }
      body_result = xla::Tuple(body, next_state);
    }

    // Loop condition: counter < bound.
    auto cond_builder = loctx->builder()->CreateSubBuilder("cond_body");
    xla::XlaOp cond_result;
    {
      auto* cond = cond_builder.get();
      auto carried = xla::Parameter(
          cond, 0, swift_xla::XlaHelpers::ShapeOfXlaOp(initial), "tuple");
      auto counter = xla::GetTupleElement(carried, state_count);
      auto bound = xla::GetTupleElement(carried, bound_slot);
      cond_result = xla::Lt(counter, bound);
    }

    auto loop = xla::While(
        cond_builder->Build(cond_result).ConsumeValueOrDie(),
        body_builder->Build(body_result).ConsumeValueOrDie(), initial);

    // Only the loop state is exposed; counter, bound and captures are dropped.
    std::vector<xla::XlaOp> outputs;
    outputs.reserve(state_count);
    for (size_t slot = 0; slot < state_count; ++slot) {
      outputs.push_back(xla::GetTupleElement(loop, slot));
    }
    return ReturnOps(outputs, loctx);
  }