ir::Value ApplyUpdate()

in torch_xla/csrc/view.cpp [64:122]


// Builds the IR value obtained by writing update_data.ir_value back through
// the chain of views listed in update_data.view_infos, with ir_value as the
// base value that chain starts from.
//
// The work happens in two passes:
//   1. Forward: apply each view in order (via ApplyViewInfo), remembering
//      every intermediate value.
//   2. Backward: starting from the update value, walk the views in reverse
//      and build, for each one, the node that undoes that view.
//
// An unrecognized view type is reported through XLA_ERROR().
ir::Value ApplyUpdate(ir::Value ir_value,
                      const Alias::UpdateData& update_data) {
  const auto& views = update_data.view_infos;

  // Forward pass. forward_chain[k] is ir_value after the first k views have
  // been applied, so forward_chain[k] is also the input that views[k] sees.
  std::vector<ir::Value> forward_chain{ir_value};
  for (const ViewInfo& view : views) {
    forward_chain.push_back(ApplyViewInfo(forward_chain.back(), view));
  }

  // Backward pass. `updated` starts as the value written through the last
  // view and is rewritten once per view, from the last view to the first.
  ir::Value updated = update_data.ir_value;
  for (size_t idx = views.size(); idx-- > 0;) {
    const ViewInfo& view = views[idx];
    // Dimensions of the shape this view was taken from; evaluated only by the
    // cases that need it.
    const auto source_dims = [&view] {
      return xla::util::ToVector<xla::int64_t>(view.source_shape.dimensions());
    };

    switch (view.view_type) {
      case ViewInfo::Type::kNoOp:
        // Nothing to undo for this view.
        break;
      case ViewInfo::Type::kSelect:
        // Place the update into the selected range of the view's input.
        updated = ir::MakeNode<ir::ops::Unselect>(
            forward_chain[idx], updated, view.select->dim, view.select->start,
            view.select->end, view.select->stride);
        break;
      case ViewInfo::Type::kNarrow:
        // Place the update into the view's input at the recorded indices.
        updated = ir::MakeNode<ir::ops::UpdateSlice>(forward_chain[idx],
                                                     updated, view.indices);
        break;
      case ViewInfo::Type::kPermute:
        // Undo the permutation by applying its inverse.
        updated = ir::MakeNode<ir::ops::Permute>(
            updated, xla::InversePermutation(view.permutation));
        break;
      case ViewInfo::Type::kReshape:
        // Go back to the dimensions of the view's source shape.
        updated = ir::MakeNode<ir::ops::View>(updated, source_dims());
        break;
      case ViewInfo::Type::kResize:
        // Go back to the dimensions of the view's source shape.
        updated = ir::MakeNode<ir::ops::Resize>(updated, source_dims());
        break;
      case ViewInfo::Type::kAsStrided:
        updated = ir::MakeNode<ir::ops::AsStridedViewUpdate>(
            forward_chain[idx], updated, source_dims(),
            view.as_strided->stride, view.as_strided->offset);
        break;
      case ViewInfo::Type::kDiagonal:
        updated = ir::MakeNode<ir::ops::DiagonalViewUpdate>(
            forward_chain[idx], updated, view.diagonal->offset,
            view.diagonal->dim1, view.diagonal->dim2);
        break;
      default:
        XLA_ERROR() << "Invalid view type: "
                    << xla::util::GetEnumValue(view.view_type);
    }
  }
  return updated;
}