inline bool CheckReorderToOpMem()

in tensorflow/tensorflow/core/util/mkl_util.h [1632:2153]


  // Reorders user_memory_ into the layout given by op_md (op_pd before v1.x),
  // writing the result into the caller-supplied buffer reorder_data_handle,
  // and runs the reorder right away instead of adding it to a net.
  //
  // Returns true if IsReorderNeeded() reported that a reorder is needed; in
  // that case reorder_memory_ is set to a newly allocated memory wrapping
  // reorder_data_handle. Returns false otherwise, leaving reorder_memory_
  // untouched.
  //
  // NOTE(review): reorder_memory_ is overwritten with a `new` allocation and
  // any previous value is not freed here; presumably it is released
  // elsewhere (destructor?) -- confirm.
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  void* reorder_data_handle,
                                  const engine& engine) {
    DCHECK(reorder_data_handle);
    DCHECK(user_memory_);
    if (IsReorderNeeded(op_md)) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      // primitive reuse don't allow two same reorder prim in
      // one stream, so submit it immediately
      reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
      stream cpu_stream(engine);
      // NOTE(review): unlike the v0.x branch, this builds a new reorder
      // primitive on every call instead of using FindOrCreateReorder<T>.
      // No cpu_stream.wait() follows the execute, whereas
      // InsertReorderToUserMem() does wait. Confirm the reorder has finished
      // before the reordered data is consumed.
      reorder(*user_memory_, *reorder_memory_)
          .execute(cpu_stream, *user_memory_, *reorder_memory_);
#else
  // v0.x variant with the same contract. It takes a memory primitive
  // descriptor, uses a cached reorder primitive (FindOrCreateReorder), and
  // waits on an eager stream.
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  void* reorder_data_handle) {
    CHECK_NOTNULL(reorder_data_handle);
    CHECK_NOTNULL(user_memory_);
    if (IsReorderNeeded(op_pd)) {
      std::vector<primitive> net;
      reorder_memory_ = new memory(op_pd, reorder_data_handle);
      net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
      stream(stream::kind::eager).submit(net).wait();
#endif  // ENABLE_MKLDNN_V1
      return true;
    }
    return false;
  }

/// Overload of CheckReorderToOpMem whose reorder destination is a Tensor.
///
/// The buffer of `reorder_tensor` (from GetTensorBuffer) is forwarded to the
/// overload that takes a raw data handle.
///
/// @input: op_md (op_pd before v1.x) - memory descriptor (memory primitive
///         descriptor before v1.x) of the given input of an operation
/// @input: reorder_tensor - Tensor whose buffer receives the reorder output.
///         The buffer size is NOT checked; the caller must make sure it is
///         large enough.
/// @input: net - net to which the reorder primitive is added if needed
///         (passed by pointer before v1.x).
/// @input: net_args - (v1.x only) list to which the user and reorder
///         memories are added if needed. Each entry is a key-value pair of
///         the form <argument-type, mkldnn::memory>.
/// @input: engine - (v1.x only) MKL-DNN's abstraction of a computational
///         device
/// @return: true if a reorder of the input is needed; false otherwise.
#ifdef ENABLE_MKLDNN_V1
  inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                  Tensor* reorder_tensor,
                                  std::vector<primitive>& net,
                                  std::vector<MemoryArgsMap>& net_args,
                                  const engine& engine) {
    DCHECK(reorder_tensor);
    auto* reorder_buffer = GetTensorBuffer(reorder_tensor);
    return CheckReorderToOpMem(op_md, reorder_buffer, net, net_args, engine);
  }
#else
  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                                  Tensor* reorder_tensor,
                                  std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(reorder_tensor);
    auto* reorder_buffer = GetTensorBuffer(reorder_tensor);
    return CheckReorderToOpMem(op_pd, reorder_buffer, net);
  }
#endif  // ENABLE_MKLDNN_V1

  /// Overload of CheckReorderToOpMem that writes the reorder output into the
  /// buffer of `reorder_tensor`. It forwards to the data-handle overload,
  /// passing *cpu_engine_ for v1.x.
  ///
  /// TODO: this is meant as a faster path with a reorder primitive cache,
  /// compared with CheckReorderToOpMem(op_md, reorder_tensor, net, net_args,
  /// engine); the slow path will be removed in the future.
  ///
  /// @input: op_pd - memory primitive descriptor (memory descriptor for
  ///         v1.x) of the given input of an operation
  /// @input: reorder_tensor - Tensor whose buffer receives the reorder output
  /// @return: true if a reorder of the input is needed; false otherwise.
  inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
                                  Tensor* reorder_tensor) {
    DCHECK(reorder_tensor);
    auto* reorder_buffer = GetTensorBuffer(reorder_tensor);
#ifdef ENABLE_MKLDNN_V1
    return CheckReorderToOpMem(op_pd, reorder_buffer, *cpu_engine_);
#else
    return CheckReorderToOpMem(op_pd, reorder_buffer);
#endif  // ENABLE_MKLDNN_V1
  }

  /// Prepares the temporary buffer needed for output reordering.
  ///
  /// This mirrors the input-reorder helpers (CheckReorderToOpMem), with one
  /// difference: no reorder primitive is added here. The output reorder can
  /// only run after the operation has executed. What is needed now is a
  /// temporary buffer to hold the operation's output until it is fed to the
  /// reorder primitive (see InsertReorderToUserMem).
  ///
  /// When IsReorderNeeded(op_pd) is true, reorder_memory_ is allocated with
  /// MEMORY_CONSTRUCTOR_WITHOUT_DATA on *cpu_engine_.
  ///
  /// @input - memory primitive descriptor (memory descriptor for v1.x) for
  ///          the given output of an operation
  /// @return: true in case reorder of output is needed; false, otherwise.
  inline bool PrepareReorderToUserMemIfReq(const MEMORY_PRIMITIVE_DESC& op_pd) {
    DCHECK(user_memory_);
    if (!IsReorderNeeded(op_pd)) return false;

    // TODO(nhasabni): can we remove dynamic memory allocation?
    reorder_memory_ =
        new MEMORY_CONSTRUCTOR_WITHOUT_DATA(op_pd, *cpu_engine_);
    return true;
  }

/// Adds the output reorder primitive to a net.
///
/// This completes output reordering. It appends a reorder primitive that
/// copies from the temporary buffer holding the operation's output
/// (reorder_memory_) into the user-specified output memory (user_memory_).
/// The primitive is only added to the net here, not executed.
///
/// @input: net - net to which the reorder primitive is added (passed by
///         pointer before v1.x)
/// @input: net_args - (v1.x only) list to which the reorder's source and
///         destination memories are added. Each entry is a key-value pair of
///         the form <argument-type, mkldnn::memory>.
#ifdef ENABLE_MKLDNN_V1
  inline void InsertReorderToUserMem(std::vector<primitive>& net,
                                     std::vector<MemoryArgsMap>& net_args) {
    DCHECK(user_memory_);
    DCHECK(reorder_memory_);
    // Source: temporary output buffer; destination: user's output memory.
    net.push_back(CreateReorder(reorder_memory_, user_memory_));
    net_args.push_back({{MKLDNN_ARG_FROM, *reorder_memory_},
                        {MKLDNN_ARG_TO, *user_memory_}});
  }
#else
  inline void InsertReorderToUserMem(std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(reorder_memory_);
    // Source: temporary output buffer; destination: user's output memory.
    auto reorder_prim = CreateReorder(reorder_memory_, user_memory_);
    net->push_back(reorder_prim);
  }
#endif  // ENABLE_MKLDNN_V1

  /// Executes the output reorder immediately: copies from the temporary
  /// buffer (reorder_memory_) into the user's output memory (user_memory_).
  /// It uses a cached reorder primitive (FindOrCreateReorder) and waits for
  /// completion before returning.
  ///
  /// TODO: this is a faster path with reorder primitive cache compared with
  ///       InsertReorderToUserMem(net, net_args); the slow path will be
  ///       removed in the future.
  inline void InsertReorderToUserMem() {
    DCHECK(user_memory_);
    DCHECK(reorder_memory_);
    // Primitive reuse does not allow two identical reorder primitives in one
    // stream, so the reorder is submitted and waited on immediately.
#ifdef ENABLE_MKLDNN_V1
    DCHECK(cpu_engine_);
    // cpu_engine_ is held by pointer (it is dereferenced everywhere else in
    // this class), and mkldnn::stream is constructed from an engine
    // reference, so it must be dereferenced here too.
    stream cpu_stream(*cpu_engine_);
    primitive reorder_prim =
        FindOrCreateReorder<T>(reorder_memory_, user_memory_);
    reorder_prim.execute(cpu_stream,
                         MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
                                       {MKLDNN_ARG_TO, *user_memory_}});
    cpu_stream.wait();
#else
    std::vector<primitive> net;
    net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
    stream(stream::kind::eager).submit(net).wait();
#endif  // ENABLE_MKLDNN_V1
  }
};

/// Base class for operations with reuse of primitives.
///
/// Cached primitives are stored and deleted through MklPrimitive* (see
/// LRUCache<MklPrimitive>), hence the virtual destructor.
class MklPrimitive {
 public:
  virtual ~MklPrimitive() = default;

  // Dummy data which MKL DNN never operates on
  unsigned char* DummyData{nullptr};
};

// An empty dims vector: a named constant for "no dimensions".
const mkldnn::memory::dims NONE_DIMS{};

//
// LRUCache is a class which implements LRU (Least Recently Used) cache.
// The implementation is similar to that of
//    tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part because the cache is supposed to be
// used as thread local (for instance, MklPrimitive caching).
//
// The LRU list keeps keys ordered by access time. The most recently accessed
// (or inserted) key is at the head of the list, and the least recently
// accessed key is at the tail.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the LRU item will
// be removed and replaced by a new one from SetOp call.
//
// Ownership: the cache owns every pointer passed to SetOp. It deletes the
// object when its entry is evicted, replaced or cleared, or when the cache
// itself is destroyed.
//
template <typename T>
class LRUCache {
 public:
  explicit LRUCache(size_t capacity) : capacity_(capacity) {}

  // Returns the object cached under `key`, or nullptr if there is none.
  // A hit marks the entry as the most recently used. The cache keeps
  // ownership of the returned object.
  T* GetOp(const string& key) {
    auto it = cache_.find(key);
    if (it == cache_.end()) {
      return nullptr;
    }

    // Move to the front of the LRU list as the most recently accessed.
    // splice() relinks the existing node, so the stored iterator remains
    // valid and no allocation is needed.
    lru_list_.splice(lru_list_.begin(), lru_list_, it->second.lru_iterator);
    return it->second.op;
  }

  // Caches `op` under `key` as the most recently used entry and takes
  // ownership of it. If `key` is already present, its previous object is
  // deleted (unless it is `op` itself) and replaced. Otherwise, when the
  // cache is full, the least recently used entry is evicted first.
  void SetOp(const string& key, T* op) {
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      Entry& entry = it->second;
      if (entry.op != op) {
        delete entry.op;
        entry.op = op;
      }
      lru_list_.splice(lru_list_.begin(), lru_list_, entry.lru_iterator);
      return;
    }

    if (lru_list_.size() >= capacity_) {
      Delete();
    }

    // Insert an entry to the front of the LRU list
    lru_list_.push_front(key);
    cache_.emplace(key, Entry(op, lru_list_.begin()));
  }

  // Removes every entry, deleting the cached objects.
  void Clear() {
    cache_.clear();
    lru_list_.clear();
  }

 private:
  struct Entry {
    // The entry's value; owned by the entry.
    T* op;

    // A list iterator pointing to the entry's key in the LRU list.
    std::list<string>::iterator lru_iterator;

    Entry(T* value, std::list<string>::iterator it)
        : op(value), lru_iterator(it) {}

    // Move constructor: transfers ownership of `op` and leaves `source`
    // empty so that its destructor does not delete the object.
    Entry(Entry&& source) noexcept
        : op(source.op), lru_iterator(source.lru_iterator) {
      source.op = nullptr;
    }

    // An Entry owns `op`; copying it would lead to a double delete.
    Entry(const Entry&) = delete;
    Entry& operator=(const Entry&) = delete;

    ~Entry() { delete op; }
  };

  // Remove the least recently accessed entry from LRU list, which
  // is the tail of lru_list_. Update cache_ correspondingly.
  // Returns false if the cache was empty.
  bool Delete() {
    if (lru_list_.empty()) return false;
    // Erase from the map while the list node (and so the key) is still alive.
    cache_.erase(lru_list_.back());
    lru_list_.pop_back();
    return true;
  }

  // Cache capacity
  size_t capacity_;

  // The cache, a map from string key to a LRU entry.
  std::unordered_map<string, Entry> cache_;

  // The LRU list of entries.
  // The front of the list contains the key of the most recently accessed
  // entry, while the back of the list is the least recently accessed entry.
  std::list<string> lru_list_;
};

/// Base factory giving access to a per-thread LRU cache of MklPrimitive
/// objects. Each instantiation MklPrimitiveFactory<T> has its own
/// thread_local cache.
template <typename T>
class MklPrimitiveFactory {
 public:
  MklPrimitiveFactory() {}

  ~MklPrimitiveFactory() {}

  /// Returns the primitive cached under `key`, or nullptr on a miss.
  MklPrimitive* GetOp(const string& key) {
    return MklPrimitiveFactory<T>::GetLRUCache().GetOp(key);
  }

  /// Caches `op` under `key`; the LRU cache takes ownership of `op`.
  void SetOp(const string& key, MklPrimitive* op) {
    MklPrimitiveFactory<T>::GetLRUCache().SetOp(key, op);
  }

  /// Returns true on legacy CPUs that support neither AVX512F nor AVX2.
  /// On such devices MKL-DNN GEMM will be used.
  static inline bool IsLegacyPlatform() {
    return !(port::TestCPUFeature(port::CPUFeature::AVX512F) ||
             port::TestCPUFeature(port::CPUFeature::AVX2));
  }

  /// Returns whether primitive memory optimization is enabled. The value
  /// comes from the TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE environment variable,
  /// defaulting to true.
  static inline bool IsPrimitiveMemOptEnabled() {
    bool enabled = true;
    TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
                                   &enabled));
    return enabled;
  }

 private:
  /// Returns this thread's cache for this instantiation of the factory.
  static inline LRUCache<MklPrimitive>& GetLRUCache() {
    constexpr int kCapacity = 1024;  // cache capacity
    static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
    return lru_cache_;
  }
};

// utility class for creating keys of MKL primitive pool.
class FactoryKeyCreator {
 public:
  FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }

  ~FactoryKeyCreator() {}

  void AddAsKey(const string& str) { Append(str); }

  void AddAsKey(const mkldnn::memory::dims& dims) {
    for (unsigned int i = 0; i < dims.size(); i++) {
      AddAsKey<int>(dims[i]);
    }
  }

  template <typename T>
  void AddAsKey(const T data) {
    auto buffer = reinterpret_cast<const char*>(&data);
    Append(StringPiece(buffer, sizeof(T)));
  }

  string GetKey() { return key_; }

 private:
  string key_;
  const char delimiter = 'x';
  const int kMaxKeyLength = 256;
  void Append(StringPiece s) {
    key_.append(string(s));
    key_.append(1, delimiter);
  }
};

// Picks the MKL-DNN memory format preferred on this CPU:
//   - AVX512F: 16-channel blocked layout (nChw16c / nCdhw16c);
//   - AVX2 with `channel` a multiple of 8: 8-channel blocked layout (2D only);
//   - otherwise: the plain nchw / ncdhw layout.
static inline MEMORY_FORMAT get_desired_format(int channel, bool is_2d = true) {
  if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
    return is_2d ? MEMORY_FORMAT::nChw16c : MEMORY_FORMAT::nCdhw16c;
  }

  const bool use_avx2_blocking =
      port::TestCPUFeature(port::CPUFeature::AVX2) && (channel % 8) == 0;
  if (use_avx2_blocking && is_2d) {
    return MEMORY_FORMAT::nChw8c;
  }

  // No AVX2 blocked layout for 3D yet, so 3D always falls back to ncdhw.
  return is_2d ? MEMORY_FORMAT::nchw : MEMORY_FORMAT::ncdhw;
}

class MklReorderPrimitive : public MklPrimitive {
 public:
  explicit MklReorderPrimitive(const memory* from, const memory* to) {
    Setup(from, to);
  }
  ~MklReorderPrimitive() {}

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

  void SetMemory(const memory* from, const memory* to) {
    context_.src_mem->set_data_handle(from->get_data_handle());
    context_.dst_mem->set_data_handle(to->get_data_handle());
  }

 private:
  struct ReorderContext {
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    ReorderContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

  engine cpu_engine_ = engine(ENGINE_CPU, 0);

  void Setup(const memory* from, const memory* to) {
    context_.src_mem.reset(
        new MEMORY_CONSTRUCTOR_WITH_MEM_PD(from, cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new MEMORY_CONSTRUCTOR_WITH_MEM_PD(to, cpu_engine_, DummyData));
    context_.reorder_prim = std::make_shared<mkldnn::reorder>(
        reorder(*context_.src_mem, *context_.dst_mem));
  }
};

/// Factory (singleton per T) that caches MklReorderPrimitive objects, keyed
/// by the data types, dims and strides of the source and destination
/// memories.
template <typename T>
class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  /// Returns a reorder primitive from `from`'s layout to `to`'s layout. A
  /// cached one is reused when available; otherwise a new one is created
  /// and cached. The result is bound to the data handles of `from` and `to`.
  /// The returned object is owned by the cache and may be deleted when it
  /// is later evicted.
  static MklReorderPrimitive* Get(const memory* from, const memory* to) {
    auto& factory = MklReorderPrimitiveFactory<T>::GetInstance();
    // The key is built once and used for both the lookup and, on a miss,
    // the insertion.
    const string key = CreateKey(from, to);
    auto* reorder_prim = static_cast<MklReorderPrimitive*>(factory.GetOp(key));
    if (reorder_prim == nullptr) {
      reorder_prim = new MklReorderPrimitive(from, to);
      factory.SetOp(key, reorder_prim);
    }
    reorder_prim->SetMemory(from, to);
    return reorder_prim;
  }

  static MklReorderPrimitiveFactory& GetInstance() {
    static MklReorderPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklReorderPrimitiveFactory() {}
  ~MklReorderPrimitiveFactory() {}

  // Builds the cache key from the data types, dims and strides of both
  // memories (plus the format enums for v0.x).
  static string CreateKey(const memory* from, const memory* to) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(from).data;
    auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(to).data;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
#ifdef ENABLE_MKLDNN_V1
    memory::dims from_strides(
        from_desc.format_desc.blocking.strides,
        &from_desc.format_desc.blocking.strides[from_desc.ndims]);
    memory::dims to_strides(
        to_desc.format_desc.blocking.strides,
        &to_desc.format_desc.blocking.strides[to_desc.ndims]);
#else
    // v0.x stores blocking strides as a 2D array; only the first row is used.
    const int KIdxFirstStride = 0;
    memory::dims from_strides(
        from_desc.layout_desc.blocking.strides[KIdxFirstStride],
        &from_desc.layout_desc.blocking
             .strides[KIdxFirstStride][from_desc.ndims]);
    memory::dims to_strides(
        to_desc.layout_desc.blocking.strides[KIdxFirstStride],
        &to_desc.layout_desc.blocking.strides[KIdxFirstStride][to_desc.ndims]);
#endif  // ENABLE_MKLDNN_V1
    key_creator.AddAsKey(prefix);
#ifndef ENABLE_MKLDNN_V1
    // `format_kind` is not added in v1.x since it will always set to
    // `mkldnn_blocked`
    key_creator.AddAsKey(static_cast<int>(from_desc.format));
#endif  // !ENABLE_MKLDNN_V1
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides);
#ifndef ENABLE_MKLDNN_V1
    key_creator.AddAsKey(static_cast<int>(to_desc.format));
#endif  // !ENABLE_MKLDNN_V1
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides);
    return key_creator.GetKey();
  }
};

/// Finds (or creates) a reorder primitive from the memory pointed to by
/// `from` to the memory pointed to by `to`. A primitive already cached in
/// the pool is reused; otherwise a new one is created and cached.
/// Returns the primitive by value.
template <typename T>
inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
  CHECK_NOTNULL(from);
  CHECK_NOTNULL(to);
  auto* cached_reorder = MklReorderPrimitiveFactory<T>::Get(from, to);
  std::shared_ptr<primitive> prim = cached_reorder->GetPrimitive();
  return *prim;
}

// utility function to determine if it is conv 1x1 and stride != 1
// for purpose of temporarily disabling primitive reuse
inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
                                memory::dims strides) {
  if (filter_dims.size() != 4 || strides.size() != 2) return false;

  return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
          ((strides[0] != 1) || (strides[1] != 1)));
}

// Undefine the MKL-DNN version-compatibility helper macros used above (they
// expand differently for v0.x and v1.x builds) so they do not leak into code
// that includes this header.
#undef ENGINE_CPU
#undef GET_MEMORY_DESC_FROM_MEM_PTR
#undef GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR
#undef MEMORY_CONSTRUCTOR
#undef MEMORY_CONSTRUCTOR_WITH_MEM_PD
#undef MEMORY_CONSTRUCTOR_WITHOUT_DATA
#undef MEMORY_FORMAT
#undef MKL_TENSOR_FORMAT
#undef MKL_TENSOR_FORMAT_BLOCKED
#undef MKL_TENSOR_FORMAT_INVALID
#undef MKL_TENSOR_FORMAT_NCDHW
#undef MKL_TENSOR_FORMAT_NDHWC
#undef MKL_TENSOR_FORMAT_NHWC
#undef MKL_TENSOR_FORMAT_NCHW
#undef MKL_TENSOR_FORMAT_UNDEF
#undef MEMORY_DATA_TYPE_UNDEF
#undef MEMORY_PRIMITIVE_DESC
#undef TENSOR_FORMAT
#undef TENSOR_FORMAT_NHWC

}  // namespace tensorflow