in tensorflow/tensorflow/core/util/mkl_util.h [1632:2153]
#ifdef ENABLE_MKLDNN_V1
inline bool CheckReorderToOpMem(const memory::desc& op_md,
void* reorder_data_handle,
const engine& engine) {
DCHECK(reorder_data_handle);
DCHECK(user_memory_);
if (IsReorderNeeded(op_md)) {
// TODO(nhasabni): can we remove dynamic memory allocation?
// Primitive reuse does not allow two identical reorder primitives in
// one stream, so submit the reorder immediately.
reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
stream cpu_stream(engine);
reorder(*user_memory_, *reorder_memory_)
.execute(cpu_stream, *user_memory_, *reorder_memory_);
#else
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
void* reorder_data_handle) {
CHECK_NOTNULL(reorder_data_handle);
CHECK_NOTNULL(user_memory_);
if (IsReorderNeeded(op_pd)) {
std::vector<primitive> net;
reorder_memory_ = new memory(op_pd, reorder_data_handle);
net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
stream(stream::kind::eager).submit(net).wait();
#endif // ENABLE_MKLDNN_V1
return true;
}
return false;
}
/// Another overloaded version of CheckReorderToOpMem that accepts a Tensor
/// in which the output of the reorder is to be stored.
///
/// @input: op_md - memory primitive descriptor (memory descriptor for v1.x)
/// of the given input of an operation
/// @reorder_tensor - Tensor whose buffer is to be used to store the output
/// of the reorder. The primitive does not check whether the
/// buffer is large enough to hold the output.
/// @input: net - net to which to add reorder primitive in case it is needed.
/// @input: net_args - net to which user and reorder memories are added if
/// needed. Each entry is a key-value pair of the form
/// <argument-type, mkldnn::memory>.
/// @input: engine - MKL-DNN's abstraction of a computational device
/// @return: true in case reorder of input is needed; false, otherwise.
#ifdef ENABLE_MKLDNN_V1
inline bool CheckReorderToOpMem(const memory::desc& op_md,
Tensor* reorder_tensor,
std::vector<primitive>& net,
std::vector<MemoryArgsMap>& net_args,
const engine& engine) {
DCHECK(reorder_tensor);
return CheckReorderToOpMem(op_md, GetTensorBuffer(reorder_tensor), net,
net_args, engine);
}
#else
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
Tensor* reorder_tensor,
std::vector<primitive>* net) {
CHECK_NOTNULL(net);
CHECK_NOTNULL(reorder_tensor);
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
}
#endif // ENABLE_MKLDNN_V1
/// TODO: This overload is a faster path that uses the reorder-primitive
/// cache, compared with CheckReorderToOpMem(op_md, reorder_tensor, net,
/// net_args, engine). The slow path will be removed in the future.
inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
Tensor* reorder_tensor) {
DCHECK(reorder_tensor);
#ifdef ENABLE_MKLDNN_V1
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
*cpu_engine_);
#else
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
#endif // ENABLE_MKLDNN_V1
}
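/// Example (illustrative sketch, not part of the original header): a typical
/// call from inside an MKL op, assuming an MklDnnData<float> instance `src`
/// whose user memory was already set via SetUsrMem, a descriptor `op_pd` for
/// the op's expected input layout, and a hypothetical temporary tensor
/// `tmp_tensor` sized to hold the reordered input:
///
///   if (src.CheckReorderToOpMem(op_pd, &tmp_tensor)) {
///     // The input layout differed from the op's expected layout; the
///     // reordered input now lives in tmp_tensor's buffer.
///   }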
/// Function to handle output reorder
///
/// This function performs functionality very similar to the input-reordering
/// function above. The only difference is that this function does not add
/// the reorder primitive to the net, because the reorder primitive for an
/// output can be added only after the operation has executed. We still need
/// to prepare a temporary buffer in case an output reorder is needed; this
/// temporary buffer holds the output of the operation before it is fed to
/// the reorder primitive.
///
/// @input - memory primitive descriptor (memory descriptor for v1.x) for the
/// given output of an operation
/// @return: true in case reorder of output is needed; false, otherwise.
inline bool PrepareReorderToUserMemIfReq(const MEMORY_PRIMITIVE_DESC& op_pd) {
DCHECK(user_memory_);
if (IsReorderNeeded(op_pd)) {
// TODO(nhasabni): can we remove dynamic memory allocation?
reorder_memory_ =
new MEMORY_CONSTRUCTOR_WITHOUT_DATA(op_pd, *cpu_engine_);
return true;
}
return false;
}
/// Function to actually insert reorder primitive in the net
///
/// This function completes the remaining part of output reordering. It
/// inserts a reorder primitive from the temporary buffer that holds the
/// output into the user-specified output buffer.
///
/// @input: net - net to which to add reorder primitive
/// @input: net_args - net to which user and reorder memories are added if
/// needed. Each entry is a key-value pair of the form
/// <argument-type, mkldnn::memory>.
#ifdef ENABLE_MKLDNN_V1
inline void InsertReorderToUserMem(std::vector<primitive>& net,
std::vector<MemoryArgsMap>& net_args) {
DCHECK(user_memory_);
DCHECK(reorder_memory_);
net.push_back(CreateReorder(reorder_memory_, user_memory_));
net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
{MKLDNN_ARG_TO, *user_memory_}});
}
#else
inline void InsertReorderToUserMem(std::vector<primitive>* net) {
CHECK_NOTNULL(net);
CHECK_NOTNULL(user_memory_);
CHECK_NOTNULL(reorder_memory_);
net->push_back(CreateReorder(reorder_memory_, user_memory_));
}
#endif // ENABLE_MKLDNN_V1
/// TODO: This overload is a faster path that uses the reorder-primitive
/// cache, compared with InsertReorderToUserMem(net, net_args). The slow
/// path will be removed in the future.
inline void InsertReorderToUserMem() {
DCHECK(user_memory_);
DCHECK(reorder_memory_);
#ifdef ENABLE_MKLDNN_V1
DCHECK(cpu_engine_);
stream cpu_stream(*cpu_engine_);
#endif // ENABLE_MKLDNN_V1
// Primitive reuse does not allow two identical reorder primitives in
// one stream, so submit the reorder immediately.
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::vector<MemoryArgsMap> net_args;
net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
net_args.push_back(MemoryArgsMap{{MKLDNN_ARG_FROM, *reorder_memory_},
{MKLDNN_ARG_TO, *user_memory_}});
DCHECK_EQ(net.size(), net_args.size());
for (size_t i = 0; i < net.size(); ++i) {
net.at(i).execute(cpu_stream, net_args.at(i));
}
cpu_stream.wait();
#else
net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
stream(stream::kind::eager).submit(net).wait();
#endif // ENABLE_MKLDNN_V1
}
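/// Example (illustrative sketch, not part of the original header): the
/// two-phase output-reorder protocol, assuming an MklDnnData<float> instance
/// `dst` whose user memory points at the user's output buffer:
///
///   bool needs_reorder = dst.PrepareReorderToUserMemIfReq(op_pd);
///   // ... execute the op, writing into the temporary reorder buffer
///   // when needs_reorder is true ...
///   if (needs_reorder) dst.InsertReorderToUserMem();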
};
/// Base class for operations with reuse of primitives
class MklPrimitive {
public:
virtual ~MklPrimitive() {}
// Dummy data which MKL-DNN never operates on
unsigned char* DummyData = nullptr;
};
const mkldnn::memory::dims NONE_DIMS = {};
//
// LRUCache is a class which implements LRU (Least Recently Used) cache.
// The implementation is similar to that of
// tensorflow/core/platform/cloud/expiring_lru_cache.h
// without its thread-safe part because the cache is supposed to be
// used as thread local (for instance, MklPrimitive caching).
//
// The LRU list maintains objects in order of access recency: the most
// recently accessed object is at the head of the list, and the least
// recently accessed object is at the tail.
//
// This class is used to maintain an upper bound on the total number of
// cached items. When the cache reaches its capacity, the least recently
// used item is removed to make room for the new one from the SetOp call.
//
template <typename T>
class LRUCache {
public:
explicit LRUCache(size_t capacity) {
capacity_ = capacity;
Clear();
}
T* GetOp(const string& key) {
auto it = cache_.find(key);
if (it == cache_.end()) {
return nullptr;
}
// Move to the front of LRU list as the most recently accessed.
lru_list_.erase(it->second.lru_iterator);
lru_list_.push_front(it->first);
it->second.lru_iterator = lru_list_.begin();
return it->second.op;
}
void SetOp(const string& key, T* op) {
if (lru_list_.size() >= capacity_) {
Delete();
}
// Insert an entry to the front of the LRU list
lru_list_.push_front(key);
Entry entry(op, lru_list_.begin());
cache_.emplace(std::make_pair(key, std::move(entry)));
}
void Clear() {
if (lru_list_.empty()) return;
// Clean up the cache
cache_.clear();
lru_list_.clear();
}
private:
struct Entry {
// The entry's value.
T* op;
// A list iterator pointing to the entry's position in the LRU list.
std::list<string>::iterator lru_iterator;
// Constructor
Entry(T* op, std::list<string>::iterator it) {
this->op = op;
this->lru_iterator = it;
}
// Move constructor
Entry(Entry&& source) noexcept
    : lru_iterator(std::move(source.lru_iterator)) {
  op = source.op;
  source.op = nullptr;
}
// Destructor
~Entry() {
if (op != nullptr) delete op;
}
};
// Remove the least recently accessed entry from LRU list, which
// is the tail of lru_list_. Update cache_ correspondingly.
bool Delete() {
if (lru_list_.empty()) return false;
string key = lru_list_.back();
lru_list_.pop_back();
cache_.erase(key);
return true;
}
// Cache capacity
size_t capacity_;
// The cache, a map from string key to a LRU entry.
std::unordered_map<string, Entry> cache_;
// The LRU list of entries.
// The front of the list contains the key of the most recently accessed
// entry, while the back of the list is the least recently accessed entry.
std::list<string> lru_list_;
};
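// Example (illustrative sketch): a capacity-2 cache in which inserting a
// third entry evicts the least recently used key. `DummyOp` is a
// hypothetical type; the cache takes ownership of the pointers.
//
//   LRUCache<DummyOp> cache(2);
//   cache.SetOp("a", new DummyOp());
//   cache.SetOp("b", new DummyOp());
//   cache.GetOp("a");                  // moves "a" to the front
//   cache.SetOp("c", new DummyOp());   // evicts "b", the LRU entry
//   DCHECK(cache.GetOp("b") == nullptr);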
template <typename T>
class MklPrimitiveFactory {
public:
MklPrimitiveFactory() {}
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
return lru_cache.GetOp(key);
}
void SetOp(const string& key, MklPrimitive* op) {
auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
lru_cache.SetOp(key, op);
}
/// Function to decide whether HW has AVX512 or AVX2.
/// For legacy devices (without AVX512 and AVX2),
/// MKL-DNN GEMM will be used.
static inline bool IsLegacyPlatform() {
return (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
!port::TestCPUFeature(port::CPUFeature::AVX2));
}
/// Function to check whether primitive memory optimization is enabled
static inline bool IsPrimitiveMemOptEnabled() {
bool is_primitive_mem_opt_enabled = true;
TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
&is_primitive_mem_opt_enabled));
return is_primitive_mem_opt_enabled;
}
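// Example (illustrative): primitive memory optimization is on by default;
// exporting TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE=false in the environment makes
// IsPrimitiveMemOptEnabled() return false.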
private:
static inline LRUCache<MklPrimitive>& GetLRUCache() {
static const int kCapacity = 1024; // cache capacity
static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
return lru_cache_;
}
};
// Utility class for creating keys for the MKL primitive pool.
class FactoryKeyCreator {
public:
FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
~FactoryKeyCreator() {}
void AddAsKey(const string& str) { Append(str); }
void AddAsKey(const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AddAsKey<int>(dims[i]);
}
}
template <typename T>
void AddAsKey(const T data) {
auto buffer = reinterpret_cast<const char*>(&data);
Append(StringPiece(buffer, sizeof(T)));
}
string GetKey() { return key_; }
private:
string key_;
const char delimiter = 'x';
const int kMaxKeyLength = 256;
void Append(StringPiece s) {
key_.append(string(s));
key_.append(1, delimiter);
}
};
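// Example (illustrative sketch): building a cache key from a string prefix,
// a dims vector, and a scalar. The string is appended verbatim, numeric
// fields as raw bytes, and every field is followed by the 'x' delimiter.
//
//   FactoryKeyCreator key_creator;
//   key_creator.AddAsKey(string("conv2d_fwd"));
//   key_creator.AddAsKey(memory::dims({1, 3, 224, 224}));
//   key_creator.AddAsKey<float>(0.5f);
//   string key = key_creator.GetKey();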
static inline MEMORY_FORMAT get_desired_format(int channel, bool is_2d = true) {
MEMORY_FORMAT fmt_desired = MEMORY_FORMAT::any;
if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
fmt_desired = is_2d ? MEMORY_FORMAT::nChw16c : MEMORY_FORMAT::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
fmt_desired = is_2d ? MEMORY_FORMAT::nChw8c
: MEMORY_FORMAT::ncdhw; // no avx2 support for 3d yet.
} else {
fmt_desired = is_2d ? MEMORY_FORMAT::nchw : MEMORY_FORMAT::ncdhw;
}
return fmt_desired;
}
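// Example (illustrative): on an AVX512F machine get_desired_format(64)
// returns MEMORY_FORMAT::nChw16c; on an AVX2-only machine it returns
// MEMORY_FORMAT::nChw8c when the channel count is divisible by 8, and
// falls back to MEMORY_FORMAT::nchw otherwise.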
class MklReorderPrimitive : public MklPrimitive {
public:
explicit MklReorderPrimitive(const memory* from, const memory* to) {
Setup(from, to);
}
~MklReorderPrimitive() {}
std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
void SetMemory(const memory* from, const memory* to) {
context_.src_mem->set_data_handle(from->get_data_handle());
context_.dst_mem->set_data_handle(to->get_data_handle());
}
private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
std::shared_ptr<primitive> reorder_prim;
ReorderContext()
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_;
engine cpu_engine_ = engine(ENGINE_CPU, 0);
void Setup(const memory* from, const memory* to) {
context_.src_mem.reset(
new MEMORY_CONSTRUCTOR_WITH_MEM_PD(from, cpu_engine_, DummyData));
context_.dst_mem.reset(
new MEMORY_CONSTRUCTOR_WITH_MEM_PD(to, cpu_engine_, DummyData));
context_.reorder_prim = std::make_shared<mkldnn::reorder>(
reorder(*context_.src_mem, *context_.dst_mem));
}
};
template <typename T>
class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklReorderPrimitive* Get(const memory* from, const memory* to) {
auto reorderPrim = static_cast<MklReorderPrimitive*>(
MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
if (reorderPrim == nullptr) {
reorderPrim = new MklReorderPrimitive(from, to);
MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
reorderPrim);
}
reorderPrim->SetMemory(from, to);
return reorderPrim;
}
static MklReorderPrimitiveFactory& GetInstance() {
static MklReorderPrimitiveFactory instance_;
return instance_;
}
private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
static string CreateKey(const memory* from, const memory* to) {
string prefix = "reorder";
FactoryKeyCreator key_creator;
auto const& from_desc = GET_MEMORY_DESC_FROM_MEM_PTR(from).data;
auto const& to_desc = GET_MEMORY_DESC_FROM_MEM_PTR(to).data;
const int KIdxFirstStride = 0;
memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
memory::dims from_strides(
#ifdef ENABLE_MKLDNN_V1
from_desc.format_desc.blocking.strides,
&from_desc.format_desc.blocking.strides[from_desc.ndims]);
#else
from_desc.layout_desc.blocking.strides[KIdxFirstStride],
&from_desc.layout_desc.blocking
.strides[KIdxFirstStride][from_desc.ndims]);
#endif // ENABLE_MKLDNN_V1
memory::dims to_strides(
#ifdef ENABLE_MKLDNN_V1
to_desc.format_desc.blocking.strides,
&to_desc.format_desc.blocking.strides[to_desc.ndims]);
#else
to_desc.layout_desc.blocking.strides[KIdxFirstStride],
&to_desc.layout_desc.blocking.strides[KIdxFirstStride][to_desc.ndims]);
#endif // ENABLE_MKLDNN_V1
key_creator.AddAsKey(prefix);
#ifndef ENABLE_MKLDNN_V1
// `format_kind` is not added in v1.x since it will always be set to
// `mkldnn_blocked`.
key_creator.AddAsKey(static_cast<int>(from_desc.format));
#endif // !ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
key_creator.AddAsKey(from_dims);
key_creator.AddAsKey(from_strides);
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(to_desc.format));
#endif // !ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
key_creator.AddAsKey(to_dims);
key_creator.AddAsKey(to_strides);
return key_creator.GetKey();
}
MklPrimitive* GetReorder(const memory* from, const memory* to) {
string key = CreateKey(from, to);
return this->GetOp(key);
}
void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
string key = CreateKey(from, to);
this->SetOp(key, op);
}
};
/// Function to find (or create) a reorder from the memory pointed to by
/// `from` to the memory pointed to by `to`. It creates the primitive or
/// retrieves it from the pool if it is cached.
/// Returns the primitive.
template <typename T>
inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
CHECK_NOTNULL(from);
CHECK_NOTNULL(to);
MklReorderPrimitive* reorder_prim =
MklReorderPrimitiveFactory<T>::Get(from, to);
return *reorder_prim->GetPrimitive();
}
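// Example (illustrative sketch): reordering `src_mem` into `dst_mem` through
// the reorder-primitive cache (v0.x API shown, as in the slow paths above;
// v1.x executes the primitive on a stream with an argument map instead).
//
//   std::vector<primitive> net;
//   net.push_back(FindOrCreateReorder<float>(&src_mem, &dst_mem));
//   stream(stream::kind::eager).submit(net).wait();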
// Utility function to determine whether a convolution is 1x1 with
// stride != 1, for the purpose of temporarily disabling primitive reuse.
inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
memory::dims strides) {
if (filter_dims.size() != 4 || strides.size() != 2) return false;
return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
((strides[0] != 1) || (strides[1] != 1)));
}
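// Example (illustrative, assuming OIHW-ordered filter dims): with
// filter_dims = {256, 256, 1, 1} and strides = {2, 2} this returns true;
// with strides = {1, 1} it returns false.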
#undef ENGINE_CPU
#undef GET_MEMORY_DESC_FROM_MEM_PTR
#undef GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR
#undef MEMORY_CONSTRUCTOR
#undef MEMORY_CONSTRUCTOR_WITH_MEM_PD
#undef MEMORY_CONSTRUCTOR_WITHOUT_DATA
#undef MEMORY_FORMAT
#undef MKL_TENSOR_FORMAT
#undef MKL_TENSOR_FORMAT_BLOCKED
#undef MKL_TENSOR_FORMAT_INVALID
#undef MKL_TENSOR_FORMAT_NCDHW
#undef MKL_TENSOR_FORMAT_NDHWC
#undef MKL_TENSOR_FORMAT_NHWC
#undef MKL_TENSOR_FORMAT_NCHW
#undef MKL_TENSOR_FORMAT_UNDEF
#undef MEMORY_DATA_TYPE_UNDEF
#undef MEMORY_PRIMITIVE_DESC
#undef TENSOR_FORMAT
#undef TENSOR_FORMAT_NHWC
} // namespace tensorflow