in torch_xla/csrc/view.cpp [64:122]
// Propagates an update made through a chain of views back onto the base value.
//
// `ir_value` is the value the views start from, and `update_data.view_infos`
// lists the views applied to it, in application order. `update_data.ir_value`
// is the update value at the end of that chain (presumably the value written
// through the outermost view — confirm against callers). The returned value is
// the base value rebuilt from that update.
//
// Two passes:
//  1. Forward: starting from `ir_value`, apply each view with ApplyViewInfo and
//     keep every intermediate value.
//  2. Backward: starting from `update_data.ir_value`, walk the views in reverse
//     and undo each one:
//     - select / narrow / as_strided / diagonal build an update node
//       (Unselect, UpdateSlice, AsStridedViewUpdate, DiagonalViewUpdate) from
//       the value that existed before that view plus the running update;
//     - permute applies the inverse permutation;
//     - reshape / resize go back to the view's source_shape dimensions;
//     - no-op views leave the running update unchanged.
// An unrecognized view type is reported through XLA_ERROR.
ir::Value ApplyUpdate(ir::Value ir_value,
                      const Alias::UpdateData& update_data) {
  const auto& views = update_data.view_infos;

  // Forward pass: forward_values[k] holds the value that views[k] is applied
  // to; the final entry is the fully-viewed value.
  std::vector<ir::Value> forward_values;
  forward_values.push_back(ir_value);
  for (const ViewInfo& view : views) {
    forward_values.push_back(ApplyViewInfo(forward_values.back(), view));
  }
  // NOTE(review): the backward pass only reads forward_values[0 .. n-1]; the
  // last entry is computed but never consumed.

  // Backward pass: undo the views from the last-applied to the first.
  ir::Value update = update_data.ir_value;
  for (size_t pos = views.size(); pos-- > 0;) {
    const ViewInfo& view = views[pos];
    // Input of `view` during the forward pass. forward_values is no longer
    // grown, so this reference stays valid for the whole iteration.
    const ir::Value& base = forward_values[pos];
    // Dimensions of the view's input shape, used to undo shape-only views.
    auto source_dims = [&view]() {
      return xla::util::ToVector<xla::int64_t>(
          view.source_shape.dimensions());
    };

    switch (view.view_type) {
      case ViewInfo::Type::kNoOp:
        // Identity view: nothing to undo.
        break;
      case ViewInfo::Type::kSelect:
        update = ir::MakeNode<ir::ops::Unselect>(
            base, update, view.select->dim, view.select->start,
            view.select->end, view.select->stride);
        break;
      case ViewInfo::Type::kNarrow:
        update =
            ir::MakeNode<ir::ops::UpdateSlice>(base, update, view.indices);
        break;
      case ViewInfo::Type::kPermute:
        update = ir::MakeNode<ir::ops::Permute>(
            update, xla::InversePermutation(view.permutation));
        break;
      case ViewInfo::Type::kReshape:
        update = ir::MakeNode<ir::ops::View>(update, source_dims());
        break;
      case ViewInfo::Type::kResize:
        update = ir::MakeNode<ir::ops::Resize>(update, source_dims());
        break;
      case ViewInfo::Type::kAsStrided:
        update = ir::MakeNode<ir::ops::AsStridedViewUpdate>(
            base, update, source_dims(), view.as_strided->stride,
            view.as_strided->offset);
        break;
      case ViewInfo::Type::kDiagonal:
        update = ir::MakeNode<ir::ops::DiagonalViewUpdate>(
            base, update, view.diagonal->offset, view.diagonal->dim1,
            view.diagonal->dim2);
        break;
      default:
        XLA_ERROR() << "Invalid view type: "
                    << xla::util::GetEnumValue(view.view_type);
    }
  }
  return update;
}