cpp-package/include/mxnet-cpp/monitor.hpp (86 lines of code) (raw):

/*! * Copyright (c) 2017 by Contributors * \file monitor.hpp * \brief monitor implementation * \author Xin Li */ #ifndef CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_ #define CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_ #include <cmath> #include <sstream> #include <algorithm> #include <vector> #include <string> #include "mxnet-cpp/monitor.h" namespace mxnet { namespace cpp { inline NDArray _default_monitor_func(const NDArray &x) { return Operator("norm").PushInput(x).Invoke()[0] / std::sqrt(x.Size()); } inline Monitor::Monitor(int interval, std::regex pattern, StatFunc stat_func) : interval(interval), pattern(pattern), stat_func(stat_func), step(0) { } inline void Monitor::install(Executor *exe) { MXExecutorSetMonitorCallback(exe->handle_, static_cast<ExecutorMonitorCallback>(&Monitor::executor_callback), this); exes.push_back(exe); } inline void Monitor::tic() { if (step % interval == 0) { activated = true; stats.clear(); } } inline std::vector<Monitor::Stat> Monitor::toc() { std::vector<Monitor::Stat> results; if (activated) { activated = false; for (auto* exe : exes) { for (auto& arg : exe->arg_arrays) { arg.WaitToRead(); } for (auto& aux : exe->aux_arrays) { aux.WaitToRead(); } for (auto &pair : exe->arg_dict()) { if (std::regex_match(pair.first, pattern)) { stats.emplace_back(step, pair.first, stat_func(pair.second)); } } for (auto &pair : exe->aux_dict()) { if (std::regex_match(pair.first, pattern)) { stats.emplace_back(step, pair.first, stat_func(pair.second)); } } } results.swap(stats); } ++step; return results; } inline void Monitor::toc_print() { auto results = toc(); std::vector<float> data(1); for (auto& stat : results) { NDArray ndarray = std::get<2>(stat); std::string str; if (ndarray.Size() == 1) { if (ndarray.GetContext().GetDeviceType() != DeviceType::kGPU) { data[0] = ndarray.GetData()[0]; } else { ndarray.SyncCopyToCPU(&data); } str = std::to_string(data[0]); } else { std::ostringstream out; out << ndarray; str = out.str(); } LG << "Batch: " << std::get<0>(stat) << ' ' << std::get<1>(stat) << ' ' << str; } } inline void Monitor::executor_callback(const char *name, NDArrayHandle handle, void *monitor_ptr) { Monitor *monitor = static_cast<Monitor*>(monitor_ptr); if (monitor->activated && std::regex_match(name, monitor->pattern)) { monitor->stats.emplace_back(monitor->step, name, monitor->stat_func(NDArray(handle))); } } } // namespace cpp } // namespace mxnet #endif // CPP_PACKAGE_INCLUDE_MXNET_CPP_MONITOR_HPP_