void SanityCheckInputs()

in lingvo/core/ops/beam_search_step_op_kernels.cc [482:645]


  // Validates the ranks and the mutual shape consistency of the op inputs.
  //
  // Inputs read from `ctx`, by index:
  //   0: scores            1: atten_probs      2: best_scores
  //   3: cumulative_scores 4: in_scores        5: in_hyps
  //   6: in_prev_hyps      7: in_done_hyps     8: in_atten_probs
  //   9: in_beam_done (only inspected when op_version == 2)
  //   cur_step is read from index 11 when op_version == 2, otherwise index 10.
  //
  // Every check goes through OP_REQUIRES: on the first failing condition an
  // InvalidArgument status is reported on `ctx` and this function returns
  // immediately, so later checks may rely on earlier ones having passed.
  // NOTE(review): since this function returns void, the caller presumably has
  // to inspect ctx->status() afterwards -- confirm at the call site.
  void SanityCheckInputs(OpKernelContext* ctx) {
    const Tensor& scores = ctx->input(0);
    const Tensor& atten_probs = ctx->input(1);
    const Tensor& best_scores = ctx->input(2);
    const Tensor& cumulative_scores = ctx->input(3);
    const Tensor& in_scores = ctx->input(4);
    const Tensor& in_hyps = ctx->input(5);
    const Tensor& in_prev_hyps = ctx->input(6);
    const Tensor& in_done_hyps = ctx->input(7);
    const Tensor& in_atten_probs = ctx->input(8);
    const Tensor& cur_step = ctx->input(op_version == 2 ? 11 : 10);

    // --- Rank checks. These must precede any dim_size(i) call below, since
    // the dimension checks index into the shapes validated here.
    OP_REQUIRES(
        ctx, scores.dims() == 2,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. scores.dims() == 2. Got ",
            scores.dims()));
    OP_REQUIRES(
        ctx, atten_probs.dims() == 2,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. atten_probs.dims() == 2. Got ",
            atten_probs.dims()));
    OP_REQUIRES(
        ctx, best_scores.dims() == 1,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. best_scores.dims() == 1. Got ",
            best_scores.dims()));
    OP_REQUIRES(ctx, cumulative_scores.dims() == 1,
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "cumulative_scores.dims() == 1. Got ",
                                        cumulative_scores.dims()));
    OP_REQUIRES(
        ctx, in_scores.dims() == 2,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. in_scores.dims() == 2. Got ",
            in_scores.dims()));
    OP_REQUIRES(ctx, in_hyps.dims() == 2,
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dims() == 2. Got ",
                                        in_hyps.dims()));

    OP_REQUIRES(
        ctx, in_prev_hyps.dims() == 2,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. in_prev_hyps.dims() == 2. Got ",
            in_prev_hyps.dims()));
    OP_REQUIRES(
        ctx, in_done_hyps.dims() == 2,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. in_done_hyps.dims() == 2. Got ",
            in_done_hyps.dims()));
    OP_REQUIRES(ctx, in_atten_probs.dims() == 3,
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_atten_probs.dims() == 3. Got ",
                                        in_atten_probs.dims()));
    OP_REQUIRES(
        ctx, cur_step.dims() == 0,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. cur_step.dims() == 0. Got ",
            cur_step.dims()));

    // --- scores.dim_size(0) must agree with the per-step inputs.
    OP_REQUIRES(ctx, scores.dim_size(0) == atten_probs.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "scores.dim_size(0) == "
                                        "atten_probs.dim_size(0). Got ",
                                        scores.dim_size(0), " and ",
                                        atten_probs.dim_size(0)));
    OP_REQUIRES(ctx, scores.dim_size(0) == cumulative_scores.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "scores.dim_size(0) == "
                                        "cumulative_scores.dim_size(0). Got ",
                                        scores.dim_size(0), " and ",
                                        cumulative_scores.dim_size(0)));

    // best_scores.dim_size(0) is used as a divisor in the next two checks.
    // An empty best_scores would make the integer '%' and '/' below divide by
    // zero (undefined behaviour), so reject it with a regular error first.
    OP_REQUIRES(ctx, best_scores.dim_size(0) > 0,
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "best_scores.dim_size(0) > 0. Got ",
                                        best_scores.dim_size(0)));
    // Together these two checks require
    //   scores.dim_size(0) == num_hyps_per_beam_ * best_scores.dim_size(0).
    OP_REQUIRES(ctx, scores.dim_size(0) % best_scores.dim_size(0) == 0,
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "scores.dim_size(0) % "
                                        "best_scores.dim_size(0) == 0. Got ",
                                        scores.dim_size(0), " and ",
                                        best_scores.dim_size(0)));
    OP_REQUIRES(
        ctx, scores.dim_size(0) / best_scores.dim_size(0) == num_hyps_per_beam_,
        errors::InvalidArgument(
            "Failed tensor shape sanity check. "
            "scores.dim_size(0) / best_scores.dim_size(0) "
            "== num_hyps_per_beam_. Got ",
            scores.dim_size(0), " and ", best_scores.dim_size(0),
            " where num_hyps_per_beam_ = ", num_hyps_per_beam_));
    OP_REQUIRES(ctx, scores.dim_size(0) == in_hyps.dim_size(1),
                errors::InvalidArgument(
                    "Failed tensor shape sanity check. "
                    "scores.dim_size(0) == in_hyps.dim_size(1). Got ",
                    scores.dim_size(0), " and ", in_hyps.dim_size(1)));

    // --- Dimension 0 of every in_* tensor must match in_hyps.dim_size(0).
    OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_scores.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(0) == "
                                        "in_scores.dim_size(0). Got ",
                                        in_hyps.dim_size(0), " and ",
                                        in_scores.dim_size(0)));
    OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_prev_hyps.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(0) == "
                                        "in_prev_hyps.dim_size(0). Got ",
                                        in_hyps.dim_size(0), " and ",
                                        in_prev_hyps.dim_size(0)));
    OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_done_hyps.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(0) == "
                                        "in_done_hyps.dim_size(0). Got ",
                                        in_hyps.dim_size(0), " and ",
                                        in_done_hyps.dim_size(0)));
    OP_REQUIRES(ctx, in_hyps.dim_size(0) == in_atten_probs.dim_size(0),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(0) == "
                                        "in_atten_probs.dim_size(0). Got ",
                                        in_hyps.dim_size(0), " and ",
                                        in_atten_probs.dim_size(0)));

    // --- Dimension 1 of every in_* tensor must match in_hyps.dim_size(1).
    OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_scores.dim_size(1),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(1) == "
                                        "in_scores.dim_size(1). Got ",
                                        in_hyps.dim_size(1), " and ",
                                        in_scores.dim_size(1)));
    OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_prev_hyps.dim_size(1),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(1) == "
                                        "in_prev_hyps.dim_size(1). Got ",
                                        in_hyps.dim_size(1), " and ",
                                        in_prev_hyps.dim_size(1)));
    OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_done_hyps.dim_size(1),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(1) == "
                                        "in_done_hyps.dim_size(1). Got ",
                                        in_hyps.dim_size(1), " and ",
                                        in_done_hyps.dim_size(1)));
    OP_REQUIRES(ctx, in_hyps.dim_size(1) == in_atten_probs.dim_size(1),
                errors::InvalidArgument("Failed tensor shape sanity check. "
                                        "in_hyps.dim_size(1) == "
                                        "in_atten_probs.dim_size(1). Got ",
                                        in_hyps.dim_size(1), " and ",
                                        in_atten_probs.dim_size(1)));

    // --- The last dimension of in_atten_probs must match atten_probs.
    OP_REQUIRES(
        ctx, atten_probs.dim_size(1) == in_atten_probs.dim_size(2),
        errors::InvalidArgument(
            "Failed tensor shape sanity check. "
            "atten_probs.dim_size(1) == in_atten_probs.dim_size(2). Got ",
            atten_probs.dim_size(1), " and ", in_atten_probs.dim_size(2)));

    // --- Extra input that is only validated for op_version == 2.
    if (op_version == 2) {
      const Tensor& in_beam_done = ctx->input(9);
      OP_REQUIRES(ctx, in_beam_done.dtype() == DT_BOOL,
                  errors::InvalidArgument("Failed tensor type sanity check. "
                                          "in_beam_done is tf.bool. Got ",
                                          in_beam_done.dtype()));
      OP_REQUIRES(ctx, in_beam_done.dims() == 1,
                  errors::InvalidArgument("Failed tensor shape sanity check. "
                                          "in_beam_done.dims() == 1. Got ",
                                          in_beam_done.dims()));
      OP_REQUIRES(ctx, in_beam_done.dim_size(0) == best_scores.dim_size(0),
                  errors::InvalidArgument("Failed tensor shape sanity check. "
                                          "in_beam_done.dim_size(0) == "
                                          "best_scores.dim_size(0). Got ",
                                          in_beam_done.dim_size(0), " and ",
                                          best_scores.dim_size(0)));
    }
  }