easy_rec/python/layers/utils.py (186 lines of code) (raw):

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions used by layers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from google.protobuf import struct_pb2
from google.protobuf.descriptor import FieldDescriptor
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import variables

try:
  from tensorflow.python.ops import kv_variable_ops
except ImportError:
  kv_variable_ops = None

# Maps '<id(collection list)>#<item name>' to the index of the JSON item
# stored in that collection (see `_collection_item_key`).
ColumnNameInCollection = {}

# Attributes that `Parameter.__init__` stores on the instance itself; they
# must never be resolved through the wrapped `params` object.
_PARAMETER_INSTANCE_ATTRS = ('params', 'is_struct', '_l2_reg')


def _tensor_to_map(tensor):
  """Describe a dense tensor as a JSON-serializable dict.

  Args:
    tensor: object exposing `name`, `shape` and `dtype`.

  Returns:
    dict with keys 'node_path', 'shape' and 'dtype'. 'shape' is None when
    `tensor.shape` is falsy, otherwise `tensor.shape.as_list()`.
  """
  shape = tensor.shape.as_list() if tensor.shape else None
  return {
      'node_path': tensor.name,
      'shape': shape,
      'dtype': tensor.dtype.name
  }


def _tensor_to_tensorinfo(tensor):
  """Describe a dense or sparse tensor as a JSON-serializable dict.

  Args:
    tensor: a `SparseTensor` or any object accepted by `_tensor_to_map`.

  Returns:
    dict with 'is_dense' plus either the dense description or, for sparse
    tensors, descriptions of 'values', 'indices' and 'dense_shape'.
  """
  if isinstance(tensor, sparse_tensor.SparseTensor):
    tensor_info = {}
    tensor_info['is_dense'] = False
    tensor_info['values'] = _tensor_to_map(tensor.values)
    tensor_info['indices'] = _tensor_to_map(tensor.indices)
    tensor_info['dense_shape'] = _tensor_to_map(tensor.dense_shape)
    return tensor_info

  tensor_info = {'is_dense': True}
  tensor_info.update(_tensor_to_map(tensor))
  return tensor_info


def add_tensor_to_collection(collection_name, name, tensor):
  """Store (or merge) the description of `tensor` under item `name`.

  Args:
    collection_name: name of the graph collection to write to.
    name: name of the item inside the collection.
    tensor: tensor to describe.
  """
  tensor_info = _tensor_to_tensorinfo(tensor)
  tensor_info['name'] = name
  update_attr_to_collection(collection_name, tensor_info)


def append_tensor_to_collection(collection_name, name, key, tensor):
  """Append the description of `tensor` to list `key` of item `name`.

  Args:
    collection_name: name of the graph collection to write to.
    name: name of the item inside the collection.
    key: attribute of the item holding the list of tensor descriptions.
    tensor: tensor to describe.
  """
  tensor_info = _tensor_to_tensorinfo(tensor)
  append_attr_to_collection(collection_name, name, key, tensor_info)


def _collection_item_key(col, name):
  """Build the `ColumnNameInCollection` key for item `name` in list `col`."""
  return '%d#%s' % (id(col), name)


def _process_item(collection_name, name, func):
  """Apply `func` to the item called `name`, creating the item if needed.

  Items are kept in the collection as JSON strings; the position of each
  item is remembered in `ColumnNameInCollection`.

  Args:
    collection_name: name of the graph collection holding the items.
    name: name of the item to create or modify.
    func: callable taking the item dict and mutating it in place.

  Raises:
    Exception: if the remembered index is out of range for the collection,
      or the item stored at that index has a different name.
  """
  col = ops.get_collection_ref(collection_name)

  # add id(col) because col may re-new sometimes
  key = _collection_item_key(col, name)

  if key not in ColumnNameInCollection:
    item = {}
    func(item)
    # `func` may not set 'name' (append_attr_to_collection does not). Record
    # it so the name check below passes on the next call instead of failing
    # with KeyError. Added after `func` so existing key order is unchanged.
    item.setdefault('name', name)
    col.append(json.dumps(item))
    ColumnNameInCollection[key] = len(col) - 1
    return

  idx = ColumnNameInCollection[key]
  if idx >= len(col):
    raise Exception(
        'Find column name in collection failed: index out of range')

  item = json.loads(col[idx])
  if item.get('name') != name:
    raise Exception(
        'Find column name in collection failed: item name not match')
  func(item)
  col[idx] = json.dumps(item)


def append_attr_to_collection(collection_name, name, key, value):
  """Append `value` to the list stored under `key` of item `name`.

  The list is created when the item does not have `key` yet.

  Args:
    collection_name: name of the graph collection holding the items.
    name: name of the item to modify.
    key: attribute of the item that holds a list.
    value: JSON-serializable value to append.
  """

  def append(item_found):
    item_found.setdefault(key, []).append(value)

  _process_item(collection_name, name, append)


def update_attr_to_collection(collection_name, attrs):
  """Merge `attrs` into the item called `attrs['name']`.

  Args:
    collection_name: name of the graph collection holding the items.
    attrs: dict of JSON-serializable attributes; must contain 'name'.
  """

  def update(item_found):
    item_found.update(attrs)

  _process_item(collection_name, attrs['name'], update)


def unique_name_in_collection(collection_name, name):
  """Return `name`, or `name_<i>`, that is not yet registered.

  Args:
    collection_name: name of the graph collection to check against.
    name: preferred item name.

  Returns:
    `name` if it is unused, otherwise the first of `name_1`, `name_2`, ...
    that has no entry in `ColumnNameInCollection` for this collection.
  """
  col = ops.get_collection_ref(collection_name)
  unique_name = name
  index = 0
  while _collection_item_key(col, unique_name) in ColumnNameInCollection:
    index += 1
    unique_name = '%s_%d' % (name, index)
  return unique_name


def gen_embedding_attrs(column=None,
                        variable=None,
                        bucket_size=None,
                        combiner=None,
                        is_embedding_var=None):
  """Build the attribute dict describing an embedding column.

  Args:
    column: object exposing `name`; required despite the None default.
    variable: embedding weights exposing `name`; required despite the None
      default.
    bucket_size: stored as-is under 'bucket_size'.
    combiner: stored as-is under 'combiner'.
    is_embedding_var: initial value of 'is_embedding_var'; overridden in
      every branch below except the partitioned EmbeddingVariable one.

  Returns:
    dict with 'name', 'bucket_size', 'combiner', 'is_embedding_var',
    'weights_op_path' and, for EmbeddingVariables, 'embedding_var_keys' and
    'embedding_var_values' (lists when the variable is partitioned).
  """
  attrs = dict()
  attrs['name'] = column.name
  attrs['bucket_size'] = bucket_size
  attrs['combiner'] = combiner
  attrs['is_embedding_var'] = is_embedding_var
  attrs['weights_op_path'] = variable.name

  if not kv_variable_ops:
    # kv_variable_ops could not be imported: EmbeddingVariable unavailable.
    attrs['is_embedding_var'] = False
    return attrs

  if isinstance(variable, kv_variable_ops.EmbeddingVariable):
    attrs['is_embedding_var'] = True
    attrs['embedding_var_keys'] = variable._shared_name + '-keys'
    attrs['embedding_var_values'] = variable._shared_name + '-values'
  elif (isinstance(variable, variables.PartitionedVariable)) and \
      (isinstance(variable._get_variable_list()[0],
                  kv_variable_ops.EmbeddingVariable)):
    # NOTE(review): 'is_embedding_var' keeps the caller-supplied value here
    # (None by default) although the partitions are EmbeddingVariables;
    # confirm with the consumers of this dict whether True is expected.
    attrs['embedding_var_keys'] = [v._shared_name + '-keys' for v in variable]
    attrs['embedding_var_values'] = [
        v._shared_name + '-values' for v in variable
    ]
  else:
    attrs['is_embedding_var'] = False
  return attrs


def mark_input_src(name, src_desc):
  """Add a JSON record {'name', 'src'} to the RANK_SERVICE_INPUT_SRC collection.

  Args:
    name: stored under 'name'.
    src_desc: JSON-serializable value stored under 'src'.
  """
  record = json.dumps({'name': name, 'src': src_desc})
  ops.add_to_collection(ops.GraphKeys.RANK_SERVICE_INPUT_SRC, record)


def is_proto_message(pb_obj, field):
  """Tell whether `field` of `pb_obj` is a message-typed protobuf field.

  Args:
    pb_obj: any object; objects without `DESCRIPTOR` yield False.
    field: field name to look up in the descriptor.

  Returns:
    True only if `pb_obj` declares `field` with type TYPE_MESSAGE.
  """
  if not hasattr(pb_obj, 'DESCRIPTOR'):
    return False
  fields_by_name = pb_obj.DESCRIPTOR.fields_by_name
  if field not in fields_by_name:
    return False
  return fields_by_name[field].type == FieldDescriptor.TYPE_MESSAGE


class Parameter(object):
  """Uniform read access to either a protobuf message or a Struct.

  Attributes:
    params: the wrapped configuration object.
    is_struct: True if `params` is accessed like a mapping (Struct), False
      if it is accessed like a protobuf message.
  """

  def __init__(self, params, is_struct, l2_reg=None):
    """Wrap `params`.

    Args:
      params: configuration object (Struct-like or protobuf message).
      is_struct: whether `params` is Struct-like.
      l2_reg: optional regularizer handed down to nested Parameters.
    """
    self.params = params
    self.is_struct = is_struct
    self._l2_reg = l2_reg

  @staticmethod
  def make_from_pb(config):
    """Wrap the protobuf message `config` (no regularizer)."""
    return Parameter(config, False)

  def get_pb_config(self):
    """Return the wrapped protobuf message; invalid for Struct parameters."""
    assert not self.is_struct, 'Struct parameter can not convert to pb config'
    return self.params

  @property
  def l2_regularizer(self):
    """Regularizer passed on to nested Parameters."""
    return self._l2_reg

  @l2_regularizer.setter
  def l2_regularizer(self, value):
    self._l2_reg = value

  def _resolve(self, key):
    """Look `key` up in the wrapped object.

    Struct: a missing key yields None; a nested Struct is wrapped into a
    Parameter. Message: the field is read with getattr and message-typed
    fields are wrapped into a Parameter.
    """
    if self.is_struct:
      if key not in self.params:
        return None
      value = self.params[key]
      if isinstance(value, struct_pb2.Struct):
        return Parameter(value, True, self._l2_reg)
      return value

    value = getattr(self.params, key)
    if is_proto_message(self.params, key):
      return Parameter(value, False, self._l2_reg)
    return value

  def __getattr__(self, key):
    """Resolve attributes that are not found on the instance via `params`.

    Raises:
      AttributeError: for the instance's own attribute names and for dunder
        names, besides whatever the wrapped message raises.
    """
    # Python calls __getattr__ only after normal lookup failed. For our own
    # attributes that means the instance is not initialised yet (e.g. while
    # being rebuilt by copy/pickle); touching self.params or self.is_struct
    # here would recurse without bound. Dunder probes such as __deepcopy__
    # or __setstate__ must not be answered by the wrapped configuration.
    if key in _PARAMETER_INSTANCE_ATTRS or (key.startswith('__') and
                                            key.endswith('__')):
      raise AttributeError(
          "'%s' object has no attribute '%s'" % (type(self).__name__, key))
    return self._resolve(key)

  def __getitem__(self, key):
    """Item access; same lookup as attribute access, for any key."""
    return self._resolve(key)

  def get_or_default(self, key, def_val):
    """Return the value of `key`, or `def_val` when it is not set.

    Args:
      key: field / key name.
      def_val: value returned when `key` is absent or unset. For Struct
        parameters a float value is converted with `type(def_val)` unless
        `def_val` is None.

    Returns:
      The configured value or `def_val`.
    """
    if self.is_struct:
      if key not in self.params:
        return def_val
      value = self.params[key]
      if def_val is not None and isinstance(value, float):
        return type(def_val)(value)
      return value

    # pb message
    value = getattr(self.params, key, def_val)
    if hasattr(value, '__len__'):  # repeated
      return value if len(value) > 0 else def_val
    try:
      if self.params.HasField(key):
        return value
    except ValueError:
      # Best effort: when HasField rejects the field, fall back to def_val.
      pass
    return def_val  # maybe not equal to the default value of msg field

  def check_required(self, keys):
    """Ensure every key in `keys` is present; only checks Struct parameters.

    Args:
      keys: a key or a list/tuple of keys.

    Raises:
      KeyError: if a key is missing from a Struct parameter.
    """
    if not self.is_struct:
      return
    if not isinstance(keys, (list, tuple)):
      keys = [keys]
    for key in keys:
      if key not in self.params:
        raise KeyError('%s must be set in params' % key)

  def has_field(self, key):
    """Tell whether `key` is set.

    Struct parameters use membership; messages delegate to `HasField`, so
    any exception `HasField` raises for `key` propagates to the caller.
    """
    if self.is_struct:
      return key in self.params
    return self.params.HasField(key)