in tensorflow_recommenders_addons/dynamic_embedding/core/ops/cuckoo_hashtable_ops.cc [47:113]
Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
const string& key_dtype_attr,
const string& value_dtype_attr,
bool is_lookup,
ShapeAndType* output_shape_and_type) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->size() != 2) {
output_shape_and_type->shape = c->UnknownShape();
output_shape_and_type->dtype = DT_INVALID;
} else {
const ShapeAndType& key_shape_and_type = (*handle_data)[0];
const ShapeAndType& value_shape_and_type = (*handle_data)[1];
DataType key_dtype;
TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
if (key_shape_and_type.dtype != key_dtype) {
return errors::InvalidArgument(
"Trying to read value with wrong dtype. "
"Expected ",
DataTypeString(key_shape_and_type.dtype), " got ",
DataTypeString(key_dtype));
}
DataType value_dtype;
TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
if (value_shape_and_type.dtype != value_dtype) {
return errors::InvalidArgument(
"Trying to read value with wrong dtype. "
"Expected ",
DataTypeString(value_shape_and_type.dtype), " got ",
DataTypeString(value_dtype));
}
output_shape_and_type->dtype = value_shape_and_type.dtype;
if (is_lookup) {
if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
int keys_rank = c->Rank(keys);
int key_suffix_rank = c->Rank(key_shape_and_type.shape);
if (keys_rank < key_suffix_rank) {
return errors::InvalidArgument(
"Expected keys to have suffix ",
c->DebugString(key_shape_and_type.shape),
" but saw shape: ", c->DebugString(keys));
}
for (int d = 0; d < key_suffix_rank; d++) {
// Ensure the suffix of keys match what's in the Table.
DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
TF_RETURN_IF_ERROR(
c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
}
std::vector<DimensionHandle> keys_prefix_vec;
keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
keys_prefix_vec.push_back(c->Dim(keys, d));
}
ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix,
value_shape_and_type.shape,
&output_shape_and_type->shape));
} else {
output_shape_and_type->shape = c->UnknownShape();
}
} else {
TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape,
&output_shape_and_type->shape));
}
}
return Status::OK();
}