in lingvo/core/ops/beam_search_step_op_kernels.cc [482:645]
// Validates the ranks, cross-input shape agreement and (for op_version == 2)
// the dtype of this op's inputs before any of their contents are read.
//
// Inputs examined (indices into ctx->input()):
//   0      scores             rank 2
//   1      atten_probs        rank 2; dim 0 == scores dim 0
//   2      best_scores        rank 1; must be non-empty, and
//                             scores dim 0 == best_scores dim 0 *
//                             num_hyps_per_beam_
//   3      cumulative_scores  rank 1; dim 0 == scores dim 0
//   4      in_scores          rank 2; dims 0 and 1 == in_hyps dims 0 and 1
//   5      in_hyps            rank 2; dim 1 == scores dim 0
//   6      in_prev_hyps       rank 2; dims 0 and 1 == in_hyps dims 0 and 1
//   7      in_done_hyps       rank 2; dims 0 and 1 == in_hyps dims 0 and 1
//   8      in_atten_probs     rank 3; dims 0 and 1 == in_hyps dims 0 and 1,
//                             dim 2 == atten_probs dim 1
//   9      in_beam_done       (op_version == 2 only) DT_BOOL, rank 1,
//                             dim 0 == best_scores dim 0
//   10/11  cur_step           rank 0 (scalar); read from index 11 when
//                             op_version == 2, otherwise index 10
//
// Every check uses OP_REQUIRES, which records an InvalidArgument error on
// `ctx` and returns from this function at the first failure. That is why
// all rank checks come first: the later dim_size() calls only run on
// tensors whose rank has already been verified.
// NOTE(review): this function returns void, so callers presumably inspect
// ctx->status() before continuing — confirm at the call site.
void SanityCheckInputs(OpKernelContext* ctx) {
  const Tensor& scores = ctx->input(0);
  const Tensor& atten_probs = ctx->input(1);
  const Tensor& best_scores = ctx->input(2);
  const Tensor& cumulative_scores = ctx->input(3);
  const Tensor& in_scores = ctx->input(4);
  const Tensor& in_hyps = ctx->input(5);
  const Tensor& in_prev_hyps = ctx->input(6);
  const Tensor& in_done_hyps = ctx->input(7);
  const Tensor& in_atten_probs = ctx->input(8);
  // Version 2 inserts in_beam_done at index 9, shifting cur_step by one.
  const Tensor& cur_step = ctx->input(op_version == 2 ? 11 : 10);

  // ---- Rank checks (must precede any dim_size() access). ----
  OP_REQUIRES(
      ctx, scores.dims() == 2,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. scores.dims() == 2. Got ",
          scores.dims()));
  OP_REQUIRES(
      ctx, atten_probs.dims() == 2,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. atten_probs.dims() == 2. Got ",
          atten_probs.dims()));
  OP_REQUIRES(
      ctx, best_scores.dims() == 1,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. best_scores.dims() == 1. Got ",
          best_scores.dims()));
  OP_REQUIRES(ctx, cumulative_scores.dims() == 1,
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "cumulative_scores.dims() == 1. Got ",
                                      cumulative_scores.dims()));
  OP_REQUIRES(
      ctx, in_scores.dims() == 2,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. in_scores.dims() == 2. Got ",
          in_scores.dims()));
  OP_REQUIRES(ctx, in_hyps.dims() == 2,
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dims() == 2. Got ",
                                      in_hyps.dims()));
  OP_REQUIRES(
      ctx, in_prev_hyps.dims() == 2,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. in_prev_hyps.dims() == 2. Got ",
          in_prev_hyps.dims()));
  OP_REQUIRES(
      ctx, in_done_hyps.dims() == 2,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. in_done_hyps.dims() == 2. Got ",
          in_done_hyps.dims()));
  OP_REQUIRES(ctx, in_atten_probs.dims() == 3,
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_atten_probs.dims() == 3. Got ",
                                      in_atten_probs.dims()));
  OP_REQUIRES(
      ctx, cur_step.dims() == 0,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. cur_step.dims() == 0. Got ",
          cur_step.dims()));

  // ---- Per-step tensors: agreement along dim 0. ----
  OP_REQUIRES(ctx, scores.dim_size(0) == atten_probs.dim_size(0),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "scores.dim_size(0) == "
                                      "atten_probs.dim_size(0). Got ",
                                      scores.dim_size(0), " and ",
                                      atten_probs.dim_size(0)));
  OP_REQUIRES(ctx, scores.dim_size(0) == cumulative_scores.dim_size(0),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "scores.dim_size(0) == "
                                      "cumulative_scores.dim_size(0). Got ",
                                      scores.dim_size(0), " and ",
                                      cumulative_scores.dim_size(0)));
  // best_scores.dim_size(0) is used as a divisor by the two checks below.
  // Integer division/modulo by zero is undefined behavior (typically a
  // SIGFPE crash), so reject an empty best_scores with a proper error first.
  OP_REQUIRES(ctx, best_scores.dim_size(0) > 0,
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "best_scores.dim_size(0) > 0. Got ",
                                      best_scores.dim_size(0)));
  OP_REQUIRES(ctx, scores.dim_size(0) % best_scores.dim_size(0) == 0,
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "scores.dim_size(0) % "
                                      "best_scores.dim_size(0) == 0. Got ",
                                      scores.dim_size(0), " and ",
                                      best_scores.dim_size(0)));
  OP_REQUIRES(
      ctx, scores.dim_size(0) / best_scores.dim_size(0) == num_hyps_per_beam_,
      errors::InvalidArgument(
          "Failed tensor shape sanity check. "
          "scores.dim_size(0) / best_scores.dim_size(0) "
          "== num_hyps_per_beam_. Got ",
          scores.dim_size(0), " and ", best_scores.dim_size(0),
          " where num_hyps_per_beam_ = ", num_hyps_per_beam_));

  // ---- Accumulated in_* tensors: dim 1 ties them to scores, and they must
  // agree with in_hyps on dims 0 and 1. ----
  OP_REQUIRES(ctx, scores.dim_size(0) == in_hyps.dim_size(1),
              errors::InvalidArgument(
                  "Failed tensor shape sanity check. "
                  "scores.dim_size(0) == in_hyps.dim_size(1). Got ",
                  scores.dim_size(0), " and ", in_hyps.dim_size(1)));
  OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_scores.dim_size(0),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(0) == "
                                      "in_scores.dim_size(0). Got ",
                                      in_hyps.dim_size(0), " and ",
                                      in_scores.dim_size(0)));
  OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_prev_hyps.dim_size(0),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(0) == "
                                      "in_prev_hyps.dim_size(0). Got ",
                                      in_hyps.dim_size(0), " and ",
                                      in_prev_hyps.dim_size(0)));
  OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_done_hyps.dim_size(0),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(0) == "
                                      "in_done_hyps.dim_size(0). Got ",
                                      in_hyps.dim_size(0), " and ",
                                      in_done_hyps.dim_size(0)));
  OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_atten_probs.dim_size(0),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(0) == "
                                      "in_atten_probs.dim_size(0). Got ",
                                      in_hyps.dim_size(0), " and ",
                                      in_atten_probs.dim_size(0)));
  OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_scores.dim_size(1),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(1) == "
                                      "in_scores.dim_size(1). Got ",
                                      in_hyps.dim_size(1), " and ",
                                      in_scores.dim_size(1)));
  OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_prev_hyps.dim_size(1),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(1) == "
                                      "in_prev_hyps.dim_size(1). Got ",
                                      in_hyps.dim_size(1), " and ",
                                      in_prev_hyps.dim_size(1)));
  OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_done_hyps.dim_size(1),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(1) == "
                                      "in_done_hyps.dim_size(1). Got ",
                                      in_hyps.dim_size(1), " and ",
                                      in_done_hyps.dim_size(1)));
  OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_atten_probs.dim_size(1),
              errors::InvalidArgument("Failed tensor shape sanity check. "
                                      "in_hyps.dim_size(1) == "
                                      "in_atten_probs.dim_size(1). Got ",
                                      in_hyps.dim_size(1), " and ",
                                      in_atten_probs.dim_size(1)));
  // The trailing axis of the accumulated attention probs must match the
  // current step's atten_probs.
  OP_REQUIRES(
      ctx, atten_probs.dim_size(1) == in_atten_probs.dim_size(2),
      errors::InvalidArgument(
          "Failed tensor shape sanity check. "
          "atten_probs.dim_size(1) == in_atten_probs.dim_size(2). Got ",
          atten_probs.dim_size(1), " and ", in_atten_probs.dim_size(2)));

  // ---- Version-2-only input. ----
  if (op_version == 2) {
    const Tensor& in_beam_done = ctx->input(9);
    OP_REQUIRES(ctx, in_beam_done.dtype() == DT_BOOL,
                errors::InvalidArgument("Failed tensor type sanity check. "
                                        "in_beam_done is tf.bool. Got ",
                                        in_beam_done.dtype()));
    OP_REQUIRES(ctx, in_beam_done.dims() == 1,
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_beam_done.dims() == 1. Got ",
                                        in_beam_done.dims()));
    OP_REQUIRES(ctx, in_beam_done.dim_size(0) == best_scores.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_beam_done.dim_size(0) == "
                                        "best_scores.dim_size(0). Got ",
                                        in_beam_done.dim_size(0), " and ",
                                        best_scores.dim_size(0)));
  }
}