in src/operator/tensor/matrix_op-inl.h [50:171]
inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
CHECK_EQ(out_attrs->size(), 1U);
  CHECK(param_.target_shape.ndim() > 0 ||
        param_.shape.ndim() > 0) << "target_shape or shape must be present.";
const TShape &dshape = (*in_attrs)[0];
if (dshape.ndim() == 0) return false;
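  // The `shape` parameter supports special values; the examples below assume
  // a (2, 3, 4) input:
  //    0  keep the corresponding input dim       shape=(4, 0, 2)   -> (4, 3, 2)
  //   -1  infer one dim from the remaining size  shape=(6, 1, -1)  -> (6, 1, 4)
  //   -2  copy all remaining input dims          shape=(-2, 1)     -> (2, 3, 4, 1)
  //   -3  merge two consecutive input dims       shape=(-3, -2)    -> (6, 4)
  //   -4  split one input dim into the next two
  //       shape entries (either may be -1)       shape=(-4, 1, 2, -2) -> (1, 2, 3, 4)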
if (param_.shape.ndim() != 0) {
std::vector<int> dshape_vec;
std::vector<int> param_shape_vec(param_.shape.begin(), param_.shape.end());
for (index_t i = 0; i < dshape.ndim(); ++i) {
dshape_vec.push_back(dshape[i]);
}
std::vector<int> tmp;
size_t src_idx = 0;
int inf_idx = -1;
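    // With reverse=true the special values are matched against the input shape
    // from right to left: reverse both vectors here, run the same left-to-right
    // pass, then reverse the result back below. E.g. a (10, 5, 4) input with
    // shape=(-1, 0) yields (40, 5) normally but (50, 4) with reverse=true.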
if (param_.reverse) {
std::reverse(dshape_vec.begin(), dshape_vec.end());
std::reverse(param_shape_vec.begin(), param_shape_vec.end());
}
auto dshape_len = dshape_vec.size();
auto params_len = param_shape_vec.size();
for (index_t i = 0; i < params_len; ++i) {
int proposed_dim = param_shape_vec[i];
if (proposed_dim == 0) {
// keep same
CHECK_LT(src_idx, dshape_len);
tmp.push_back(dshape_vec[src_idx++]);
      } else if (proposed_dim == -1) {
        // infer this dim; record its position in the output shape, which can
        // differ from i once -2/-3/-4 entries have changed the length of tmp
        CHECK_LT(inf_idx, 0) << "One and only one dim can be inferred";
        inf_idx = static_cast<int>(tmp.size());
        tmp.push_back(1);
        src_idx++;
} else if (proposed_dim == -2) {
// copy all remaining dims from source
        while (src_idx < dshape_len) {
          tmp.push_back(dshape_vec[src_idx++]);
        }
} else if (proposed_dim == -3) {
// merge two dims from source
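        // e.g. a (2, 3, 4) input with shape=(-3, 0) merges 2*3 into 6 -> (6, 4)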
CHECK_LT(src_idx, dshape_len-1);
size_t d1 = dshape_vec[src_idx++];
size_t d2 = dshape_vec[src_idx++];
size_t dn = d1 * d2;
tmp.push_back(dn);
} else if (proposed_dim == -4) {
// split the source dim s into two dims
// read the left dim and then the right dim (either can be -1)
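        // e.g. a (6, 4) input with shape=(-4, 2, 3, 0) yields (2, 3, 4):
        // d0=6 splits into d1=2 and d2=3, and the trailing 0 keeps the 4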
CHECK_LT(i + 2, params_len);
CHECK_LT(src_idx, dshape_len);
size_t d0 = dshape_vec[src_idx++];
int d1 = param_shape_vec[++i];
int d2 = param_shape_vec[++i];
CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
if (d1 == -1) d1 = d0 / d2;
if (d2 == -1) d2 = d0 / d1;
CHECK_EQ(d1 * d2, static_cast<int>(d0)) <<
"Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
tmp.push_back(d1);
tmp.push_back(d2);
} else {
// greater than 0, new shape
tmp.push_back(proposed_dim);
src_idx++;
}
}
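    // Second pass: resolve a top-level -1. tmp[inf_idx] still holds the
    // placeholder 1, so the product over tmp is the size of all known dims
    // and the division below recovers the missing extent.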
if (inf_idx >= 0) {
if (dshape.Size() > 0) {
int new_size = 1;
for (int x : tmp) new_size *= x;
tmp[inf_idx] = dshape.Size() / new_size;
} else {
tmp[inf_idx] = 0;
}
}
if (param_.reverse) {
std::reverse(param_shape_vec.begin(), param_shape_vec.end());
std::reverse(dshape_vec.begin(), dshape_vec.end());
std::reverse(tmp.begin(), tmp.end());
}
TShape oshape(tmp.begin(), tmp.end());
    CHECK_EQ(oshape.Size(), dshape.Size())
        << "Target shape size is different from source. "
        << "Target: " << oshape
        << "\nSource: " << dshape;
out_attrs->clear();
out_attrs->push_back(oshape);
} else {
LOG(INFO) << "Using target_shape will be deprecated.";
TShape oshape = param_.target_shape;
int neg_count = 0;
index_t inf_idx = 0;
index_t start_idx = param_.keep_highest ? 1 : 0;
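    // keep_highest pins the leading (batch) dim to the input's; only the
    // remaining entries are scanned for a single inferable 0 below.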
if (param_.keep_highest) {
oshape[0] = dshape[0];
}
for (index_t i = start_idx; i < oshape.ndim(); ++i) {
if (oshape[i] == 0) {
neg_count++;
inf_idx = i;
}
}
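    // Setting the unknown dim to 1 first makes oshape.Size() the product of
    // the known dims, so the division recovers the missing extent.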
if (neg_count == 1) {
oshape[inf_idx] = 1;
oshape[inf_idx] = dshape.Size() / oshape.Size();
}
    CHECK_EQ(oshape.Size(), dshape.Size())
        << "Target shape size is different from source. "
        << "Target: " << param_.target_shape.Size()
        << "\nSource: " << dshape.Size();
out_attrs->clear();
out_attrs->push_back(oshape);
}
return true;
}
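
// A minimal sketch of how this function is consumed: operators register it as
// their nnvm shape-inference hook, so the pass fills out_attrs from in_attrs
// when the graph is bound (the registration lives in matrix_op.cc; shown here
// for illustration only):
//
//   NNVM_REGISTER_OP(Reshape)
//   .set_attr<nnvm::FInferShape>("FInferShape", ReshapeShape)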