void validatePointwiseArgs()

in cpp/ops/TensorMath.cpp [54:83]


void validatePointwiseArgs(const MathArg<FloatType<kWidth>::T>& a,
                           const MathArg<FloatType<kWidth>::T>& b,
                           CLTensor<FloatType<kWidth>::T>& out) {
  // Validates the operands of a pointwise op against its output tensor.
  //
  // All checks go through CL_ASSERT (presumably fatal on failure; TODO
  // confirm against the macro's definition). They run in this order:
  //   1. `out` must be contiguous.
  //   2. For each operand that carries a tensor (`a.t` / `b.t` set):
  //        - the tensor must be contiguous;
  //        - if `useScalar` is set, it must hold exactly one element;
  //        - otherwise it must hold as many elements as `out`.
  //   3. When both operands are full (non-scalar) tensors, `a.t` and `b.t`
  //      must also satisfy isSameSize().
  // Operands without a tensor are not checked here.
  //
  // The CL_ASSERT expressions are spelled exactly as before, in case the
  // macro stringifies its argument into the failure message.
  CL_ASSERT(out.isContiguous());

  // Left operand.
  if (a.t) {
    CL_ASSERT(a.t->isContiguous());
    if (!a.useScalar) {
      // Full tensor: element count must match the output's.
      CL_ASSERT(a.t->numElements() == out.numElements());
    } else {
      // Tensor used as a scalar: must hold exactly one element.
      CL_ASSERT(a.t->numElements() == 1);
    }
  }

  // Right operand. The a/b agreement check lives in b's full-tensor branch
  // because it only applies when b is non-scalar, and it must still run
  // after all of b's own checks.
  if (b.t) {
    CL_ASSERT(b.t->isContiguous());
    if (!b.useScalar) {
      CL_ASSERT(b.t->numElements() == out.numElements());
      if (a.t && !a.useScalar) {
        // Both operands are full tensors: equal element counts alone are
        // not enough, their sizes must agree as well.
        CL_ASSERT(a.t->isSameSize(*b.t));
      }
    } else {
      CL_ASSERT(b.t->numElements() == 1);
    }
  }
}