Code style for test util (#7723)

* Enabled code style for ngraph test util

* remove some methods

* Fixed backends code style
This commit is contained in:
Ilya Churaev
2021-09-29 06:31:37 +03:00
committed by GitHub
parent d074eea063
commit d2878e4012
81 changed files with 2735 additions and 3843 deletions

View File

@@ -9,7 +9,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -21,7 +21,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -9,7 +9,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -9,7 +9,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"
@@ -2933,4 +2932,4 @@ NGRAPH_TEST(${BACKEND_NAME}, deformable_convolution_opset8_2D_neg_offsets_groups
DeformableConvolutionOpset8Test(inputs, inputs_shape, offsets, offsets_shape, filter,
filter_shape, mask, mask_shape, outputs, outputs_shape, strides, padding,
dilations, group, deformable_group, tolerance_bits, true);
}
}

View File

@@ -11,7 +11,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -30,7 +30,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"

View File

@@ -18,7 +18,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"

View File

@@ -32,7 +32,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_control.hpp"

View File

@@ -9,7 +9,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -9,7 +9,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -11,7 +11,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -9,7 +9,6 @@
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/test_engines.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

View File

@@ -18,7 +18,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -8,7 +8,6 @@
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -18,7 +18,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -8,7 +8,6 @@
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -8,7 +8,6 @@
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -8,7 +8,6 @@
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -18,7 +18,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -18,7 +18,6 @@
#include "ngraph/ngraph.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -8,7 +8,6 @@
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -8,7 +8,6 @@
#include "runtime/backend.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/known_element_types.hpp"
#include "util/ndarray.hpp"
#include "util/test_control.hpp"
#include "util/test_tools.hpp"

View File

@@ -52,3 +52,6 @@ install(TARGETS ngraph_backend
add_subdirectory(interpreter)
add_subdirectory(ie)
file(GLOB_RECURSE all_backends_src "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp")
add_clang_format_target(ngraph_backend_clang FOR_SOURCES ${all_backends_src})

View File

@@ -3,15 +3,15 @@
//
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#if defined(WINAPI_FAMILY) && !WINAPI_PARTITION_DESKTOP
#error "Only WINAPI_PARTITION_DESKTOP is supported, because of LoadLibrary[A|W]"
#endif
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# if defined(WINAPI_FAMILY) && !WINAPI_PARTITION_DESKTOP
# error "Only WINAPI_PARTITION_DESKTOP is supported, because of LoadLibrary[A|W]"
# endif
#elif defined(__linux) || defined(__APPLE__)
#include <dlfcn.h>
# include <dlfcn.h>
#endif
#include <sstream>
@@ -29,8 +29,7 @@ std::mutex runtime::Backend::m_mtx;
std::string runtime::Backend::s_backend_shared_library_search_directory;
// This finds the full path of the containing shared library
static string find_my_pathname()
{
static string find_my_pathname() {
#ifdef _WIN32
HMODULE hModule = GetModuleHandleW(SHARED_LIB_PREFIX L"ngraph" SHARED_LIB_SUFFIX);
WCHAR wpath[MAX_PATH];
@@ -48,77 +47,63 @@ static string find_my_pathname()
dladdr(reinterpret_cast<void*>(ngraph::to_lower), &dl_info);
return dl_info.dli_fname;
#else
#error "Unsupported OS"
# error "Unsupported OS"
#endif
}
runtime::Backend::~Backend() {}
std::shared_ptr<runtime::Backend> runtime::Backend::create(const string& t,
bool must_support_dynamic)
{
std::shared_ptr<runtime::Backend> runtime::Backend::create(const string& t, bool must_support_dynamic) {
// Rewrite backend name BACKEND_OPTION to BACKEND:OPTION
string type = t;
auto pos = type.find('_');
if (pos != string::npos)
{
if (pos != string::npos) {
type = type.replace(pos, 1, ":");
}
auto inner_backend = BackendManager::create_backend(type);
if (!must_support_dynamic || inner_backend->supports_dynamic_tensors())
{
if (!must_support_dynamic || inner_backend->supports_dynamic_tensors()) {
return inner_backend;
}
else
{
} else {
return make_shared<runtime::dynamic::DynamicBackend>(inner_backend);
}
return inner_backend;
}
vector<string> runtime::Backend::get_registered_devices()
{
vector<string> runtime::Backend::get_registered_devices() {
return BackendManager::get_registered_backends();
}
std::shared_ptr<ngraph::runtime::Tensor>
runtime::Backend::create_dynamic_tensor(const ngraph::element::Type& /* element_type */,
const PartialShape& /* shape */)
{
std::shared_ptr<ngraph::runtime::Tensor> runtime::Backend::create_dynamic_tensor(
const ngraph::element::Type& /* element_type */,
const PartialShape& /* shape */) {
throw std::invalid_argument("This backend does not support dynamic tensors");
}
bool runtime::Backend::is_supported(const Node& /* node */) const
{
bool runtime::Backend::is_supported(const Node& /* node */) const {
// The default behavior is that a backend does not support any ops. If this is not the case
// then override this method and enhance.
return false;
}
std::shared_ptr<runtime::Executable> runtime::Backend::load(istream& /* input_stream */)
{
std::shared_ptr<runtime::Executable> runtime::Backend::load(istream& /* input_stream */) {
throw runtime_error("load operation unimplemented.");
}
void runtime::Backend::set_backend_shared_library_search_directory(const string& path)
{
void runtime::Backend::set_backend_shared_library_search_directory(const string& path) {
std::lock_guard<std::mutex> lock(runtime::Backend::m_mtx);
s_backend_shared_library_search_directory = path;
}
const string& runtime::Backend::get_backend_shared_library_search_directory()
{
if (s_backend_shared_library_search_directory.empty())
{
const string& runtime::Backend::get_backend_shared_library_search_directory() {
if (s_backend_shared_library_search_directory.empty()) {
s_backend_shared_library_search_directory = find_my_pathname();
}
return s_backend_shared_library_search_directory;
}
bool runtime::Backend::set_config(const map<string, string>& /* config */, string& error)
{
bool runtime::Backend::set_config(const map<string, string>& /* config */, string& error) {
error = "set_config not supported";
return false;
}

View File

@@ -15,20 +15,17 @@
#include "ngraph/util.hpp"
#include "performance_counter.hpp"
namespace ngraph
{
namespace runtime
{
class Tensor;
class Backend;
}
}
namespace ngraph {
namespace runtime {
class Tensor;
class Backend;
} // namespace runtime
} // namespace ngraph
/// \brief Interface to a generic backend.
///
/// Backends are responsible for function execution and value allocation.
class BACKEND_API ngraph::runtime::Backend
{
class BACKEND_API ngraph::runtime::Backend {
public:
virtual ~Backend();
/// \brief Create a new Backend object
@@ -41,8 +38,7 @@ public:
/// DynamicWrapperBackend. This feature is EXPERIMENTAL.
/// \returns shared_ptr to a new Backend or nullptr if the named backend
/// does not exist.
static std::shared_ptr<Backend> create(const std::string& type,
bool must_support_dynamic = false);
static std::shared_ptr<Backend> create(const std::string& type, bool must_support_dynamic = false);
/// \brief Query the list of registered devices
/// \returns A vector of all registered devices.
@@ -59,8 +55,8 @@ public:
/// \param element_type The type of the tensor element
/// \param shape The shape of the tensor
/// \returns shared_ptr to a new backend-specific tensor
virtual std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type, const Shape& shape) = 0;
virtual std::shared_ptr<ngraph::runtime::Tensor> create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) = 0;
/// \brief Create a tensor specific to this backend
/// \param element_type The type of the tensor element
@@ -69,15 +65,15 @@ public:
/// must be sufficient to contain the tensor. The lifetime of the buffer is the
/// responsibility of the caller.
/// \returns shared_ptr to a new backend-specific tensor
virtual std::shared_ptr<ngraph::runtime::Tensor> create_tensor(
const ngraph::element::Type& element_type, const Shape& shape, void* memory_pointer) = 0;
virtual std::shared_ptr<ngraph::runtime::Tensor> create_tensor(const ngraph::element::Type& element_type,
const Shape& shape,
void* memory_pointer) = 0;
/// \brief Create a tensor of C type T specific to this backend
/// \param shape The shape of the tensor
/// \returns shared_ptr to a new backend specific tensor
template <typename T>
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(const Shape& shape)
{
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(const Shape& shape) {
return create_tensor(element::from<T>(), shape);
}
@@ -87,11 +83,13 @@ public:
/// \param shape The shape of the tensor
/// \returns shared_ptr to a new backend-specific tensor
/// \throws std::invalid_argument if the backend does not support dynamic tensors
virtual std::shared_ptr<ngraph::runtime::Tensor>
create_dynamic_tensor(const ngraph::element::Type& element_type, const PartialShape& shape);
virtual std::shared_ptr<ngraph::runtime::Tensor> create_dynamic_tensor(const ngraph::element::Type& element_type,
const PartialShape& shape);
/// \returns `true` if this backend supports dynamic tensors, else `false`.
virtual bool supports_dynamic_tensors() { return false; }
virtual bool supports_dynamic_tensors() {
return false;
}
/// \brief Compiles a Function.
/// \param func The function to compile
/// \returns compiled function or nullptr on failure
@@ -122,7 +120,9 @@ public:
/// \brief Get the version of the backend
/// The default value of 0.0.0 is chosen to be a parsable version number
virtual std::string get_version() const { return "0.0.0"; }
virtual std::string get_version() const {
return "0.0.0";
}
private:
// mutex to modify s_backend_shared_library_search_directory thread safe

View File

@@ -3,12 +3,12 @@
//
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
#else
#include <dlfcn.h>
# include <dlfcn.h>
#endif
#include <sstream>
@@ -24,56 +24,47 @@ using namespace std;
using namespace ngraph;
#ifdef _WIN32
#define CLOSE_LIBRARY(a) FreeLibrary(a)
#define DLSYM(a, b) GetProcAddress(a, b)
#define DLERROR() ""
# define CLOSE_LIBRARY(a) FreeLibrary(a)
# define DLSYM(a, b) GetProcAddress(a, b)
# define DLERROR() ""
#else
#define CLOSE_LIBRARY(a) dlclose(a)
#define DLSYM(a, b) dlsym(a, b)
string DLERROR()
{
# define CLOSE_LIBRARY(a) dlclose(a)
# define DLSYM(a, b) dlsym(a, b)
string DLERROR() {
const char* error = dlerror();
return error == nullptr ? "" : error;
}
#endif
unordered_map<string, runtime::BackendConstructor>& runtime::BackendManager::get_registry()
{
unordered_map<string, runtime::BackendConstructor>& runtime::BackendManager::get_registry() {
static unordered_map<string, BackendConstructor> s_registered_backend;
return s_registered_backend;
}
void runtime::BackendManager::register_backend(const string& name, BackendConstructor new_backend)
{
void runtime::BackendManager::register_backend(const string& name, BackendConstructor new_backend) {
get_registry()[name] = new_backend;
}
vector<string> runtime::BackendManager::get_registered_backends()
{
vector<string> runtime::BackendManager::get_registered_backends() {
vector<string> rc;
for (const auto& p : get_registry())
{
for (const auto& p : get_registry()) {
rc.push_back(p.first);
}
for (const auto& p : get_registered_device_map())
{
if (find(rc.begin(), rc.end(), p.first) == rc.end())
{
for (const auto& p : get_registered_device_map()) {
if (find(rc.begin(), rc.end(), p.first) == rc.end()) {
rc.push_back(p.first);
}
}
return rc;
}
shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(std::string config)
{
shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(std::string config) {
string type = config;
string options;
// strip off attributes, IE:CPU becomes IE
auto colon = type.find(":");
if (colon != type.npos)
{
if (colon != type.npos) {
options = type.substr(colon + 1);
type = type.substr(0, colon);
}
@@ -81,34 +72,23 @@ shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(std::string
auto& registry = get_registry();
auto it = registry.find(type);
string error;
if (it == registry.end())
{
if (it == registry.end()) {
DL_HANDLE handle = open_shared_library(type);
if (!handle)
{
if (!handle) {
error = DLERROR();
}
else
{
DLERROR(); // Clear any pending errors
string register_function_name =
string("ngraph_register_") + to_lower(type) + "_backend";
auto register_function =
reinterpret_cast<void (*)()>(DLSYM(handle, register_function_name.c_str()));
if (register_function)
{
} else {
DLERROR(); // Clear any pending errors
string register_function_name = string("ngraph_register_") + to_lower(type) + "_backend";
auto register_function = reinterpret_cast<void (*)()>(DLSYM(handle, register_function_name.c_str()));
if (register_function) {
register_function();
it = registry.find(type);
}
else
{
} else {
error = DLERROR();
CLOSE_LIBRARY(handle);
stringstream ss;
ss << "Failed to find symbol 'get_backend_constructor_pointer' in backend library."
<< endl;
if (error.size() > 0)
{
ss << "Failed to find symbol 'get_backend_constructor_pointer' in backend library." << endl;
if (error.size() > 0) {
ss << "\nError: " << error;
}
error = ss.str();
@@ -116,12 +96,10 @@ shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(std::string
}
}
if (it == registry.end())
{
if (it == registry.end()) {
stringstream ss;
ss << "Backend '" << type << "' not registered.";
if (error.size() > 0)
{
if (error.size() > 0) {
ss << "\n Error: " << DLERROR();
}
throw runtime_error(ss.str());
@@ -129,39 +107,34 @@ shared_ptr<runtime::Backend> runtime::BackendManager::create_backend(std::string
return it->second(options);
}
DL_HANDLE runtime::BackendManager::open_shared_library(string type)
{
DL_HANDLE runtime::BackendManager::open_shared_library(string type) {
DL_HANDLE handle = nullptr;
string lib_prefix = SHARED_LIB_PREFIX;
string lib_suffix = SHARED_LIB_SUFFIX;
// strip off attributes, IE:CPU becomes IE
auto colon = type.find(":");
if (colon != type.npos)
{
if (colon != type.npos) {
type = type.substr(0, colon);
}
string library_name = lib_prefix + to_lower(type) + "_backend" + lib_suffix;
string my_directory =
file_util::get_directory(Backend::get_backend_shared_library_search_directory());
string my_directory = file_util::get_directory(Backend::get_backend_shared_library_search_directory());
string library_path = file_util::path_join(my_directory, library_name);
#ifdef _WIN32
SetDllDirectoryA((LPCSTR)my_directory.c_str());
handle = LoadLibraryA(library_path.c_str());
#elif defined(__APPLE__) || defined(__linux__)
DLERROR(); // Clear any pending errors
DLERROR(); // Clear any pending errors
handle = dlopen(library_path.c_str(), RTLD_NOW | RTLD_GLOBAL);
#else
#error "Unsupported OS"
# error "Unsupported OS"
#endif
string error = DLERROR();
if (!handle)
{
if (!handle) {
stringstream ss;
ss << "Unable to find backend '" << type << "' as file '" << library_path << "'";
if (error.size() > 0)
{
if (error.size() > 0) {
ss << "\nOpen error message '" << error << "'";
}
throw runtime_error(ss.str());
@@ -169,20 +142,16 @@ DL_HANDLE runtime::BackendManager::open_shared_library(string type)
return handle;
}
map<string, string> runtime::BackendManager::get_registered_device_map()
{
map<string, string> runtime::BackendManager::get_registered_device_map() {
map<string, string> rc;
string my_directory =
file_util::get_directory(Backend::get_backend_shared_library_search_directory());
string my_directory = file_util::get_directory(Backend::get_backend_shared_library_search_directory());
vector<string> backend_list;
auto f = [&](const string& file, bool is_dir) {
if (!is_dir)
{
if (!is_dir) {
string name = file_util::get_file_name(file);
string backend_name;
if (is_backend_name(name, backend_name))
{
if (is_backend_name(name, backend_name)) {
rc.insert({to_upper(backend_name), file});
}
}
@@ -191,20 +160,15 @@ map<string, string> runtime::BackendManager::get_registered_device_map()
return rc;
}
bool runtime::BackendManager::is_backend_name(const string& file, string& backend_name)
{
bool runtime::BackendManager::is_backend_name(const string& file, string& backend_name) {
bool rc = false;
string name = file_util::get_file_name(file);
string lib_prefix = SHARED_LIB_PREFIX;
string lib_suffix = SHARED_LIB_SUFFIX;
if ((name.size() > lib_prefix.size() + lib_suffix.size()) &
!name.compare(0, lib_prefix.size(), lib_prefix))
{
if (!name.compare(name.size() - lib_suffix.size(), lib_suffix.size(), lib_suffix))
{
if ((name.size() > lib_prefix.size() + lib_suffix.size()) & !name.compare(0, lib_prefix.size(), lib_prefix)) {
if (!name.compare(name.size() - lib_suffix.size(), lib_suffix.size(), lib_suffix)) {
auto pos = name.find("_backend");
if (pos != name.npos)
{
if (pos != name.npos) {
backend_name = name.substr(lib_prefix.size(), pos - lib_prefix.size());
rc = true;
}

View File

@@ -12,31 +12,27 @@
#include <vector>
#ifdef _WIN32
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#define DL_HANDLE HMODULE
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# define DL_HANDLE HMODULE
#else
#define DL_HANDLE void*
# define DL_HANDLE void*
#endif
#include "backend_visibility.hpp"
#include "ngraph/ngraph_visibility.hpp"
namespace ngraph
{
namespace runtime
{
class Backend;
class BackendManager;
using BackendConstructor =
std::function<std::shared_ptr<ngraph::runtime::Backend>(const std::string& config)>;
}
}
namespace ngraph {
namespace runtime {
class Backend;
class BackendManager;
using BackendConstructor = std::function<std::shared_ptr<ngraph::runtime::Backend>(const std::string& config)>;
} // namespace runtime
} // namespace ngraph
class ngraph::runtime::BackendManager
{
class ngraph::runtime::BackendManager {
friend class Backend;
public:
@@ -45,8 +41,7 @@ public:
/// \param name The name of the registering backend in UPPER CASE.
/// \param backend_constructor A BackendConstructor which will be called to
//// construct an instance of the registered backend.
static BACKEND_API void register_backend(const std::string& name,
BackendConstructor backend_constructor);
static BACKEND_API void register_backend(const std::string& name, BackendConstructor backend_constructor);
/// \brief Query the list of registered devices
/// \returns A vector of all registered devices.

View File

@@ -4,8 +4,8 @@
#include "ngraph/visibility.hpp"
#ifdef ngraph_backend_EXPORTS // defined if we are building the ngraph_backend as shared library
#define BACKEND_API NGRAPH_HELPER_DLL_EXPORT
#ifdef ngraph_backend_EXPORTS // defined if we are building the ngraph_backend as shared library
# define BACKEND_API NGRAPH_HELPER_DLL_EXPORT
#else
#define BACKEND_API NGRAPH_HELPER_DLL_IMPORT
# define BACKEND_API NGRAPH_HELPER_DLL_IMPORT
#endif

View File

@@ -3,21 +3,18 @@
//
#include "cache.hpp"
#include "ngraph/env_util.hpp"
using namespace ngraph;
using namespace std;
// Constructor
runtime::LRUCache::LRUCache()
{
runtime::LRUCache::LRUCache() {
int32_t cache_size = getenv_int("NGRAPH_CACHE_SIZE");
if (cache_size <= 0)
{
m_cache_size = 1024; // TODO(nbpatel): Figure out a default size for the cache
}
else
{
if (cache_size <= 0) {
m_cache_size = 1024; // TODO(nbpatel): Figure out a default size for the cache
} else {
m_cache_size = cache_size;
}
@@ -26,30 +23,25 @@ runtime::LRUCache::LRUCache()
}
// Destructor
runtime::LRUCache::~LRUCache()
{
runtime::LRUCache::~LRUCache() {
m_list.clear();
m_map.clear();
m_clone_function_map.clear();
}
void runtime::LRUCache::convert_shape_to_string(const vector<int>& shape, ostringstream& key)
{
if (!shape.empty())
{
void runtime::LRUCache::convert_shape_to_string(const vector<int>& shape, ostringstream& key) {
if (!shape.empty()) {
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(key, ", "));
}
}
void runtime::LRUCache::add_entry(const vector<int>& shape,
shared_ptr<runtime::Executable> exec,
shared_ptr<Function> func)
{
shared_ptr<Function> func) {
std::lock_guard<std::mutex> guard(m_mutex);
ostringstream key;
// check if the list is empty
if (m_list.size() == static_cast<size_t>(m_cache_size))
{
if (m_list.size() == static_cast<size_t>(m_cache_size)) {
ostringstream key;
convert_shape_to_string(m_list.back(), key);
m_list.pop_back();
@@ -62,37 +54,28 @@ void runtime::LRUCache::add_entry(const vector<int>& shape,
m_clone_function_map.insert({key.str(), func});
}
bool runtime::LRUCache::is_cached(const vector<int>& shape)
{
for (auto itr = m_list.begin(); itr != m_list.end(); itr++)
{
if (*itr == shape)
{
bool runtime::LRUCache::is_cached(const vector<int>& shape) {
for (auto itr = m_list.begin(); itr != m_list.end(); itr++) {
if (*itr == shape) {
return true;
}
}
return false;
}
shared_ptr<runtime::Executable> runtime::LRUCache::get_cached_entry(const vector<int>& shape)
{
shared_ptr<runtime::Executable> runtime::LRUCache::get_cached_entry(const vector<int>& shape) {
std::lock_guard<std::mutex> guard(m_mutex);
ostringstream key;
convert_shape_to_string(shape, key);
// find the entry and return the function
auto it = m_map.find(key.str());
if (it == m_map.end())
{
if (it == m_map.end()) {
throw ngraph_error("Entry not found in cache");
}
else
{
} else {
// update list to push this reference to the front
for (auto itr = m_list.begin(); itr != m_list.end(); itr++)
{
if (*itr == shape)
{
for (auto itr = m_list.begin(); itr != m_list.end(); itr++) {
if (*itr == shape) {
m_list.remove(shape);
m_list.push_front(shape);
break;
@@ -104,15 +87,13 @@ shared_ptr<runtime::Executable> runtime::LRUCache::get_cached_entry(const vector
// Need the clone function to get the output shape so that
// storage can be allocated for output
shared_ptr<Function> runtime::LRUCache::get_cloned_function(const vector<int>& shape)
{
shared_ptr<Function> runtime::LRUCache::get_cloned_function(const vector<int>& shape) {
std::lock_guard<std::mutex> guard(m_mutex);
ostringstream key;
convert_shape_to_string(shape, key);
// find the entry and return the function
auto it = m_clone_function_map.find(key.str());
if (it == m_clone_function_map.end())
{
if (it == m_clone_function_map.end()) {
throw ngraph_error("Cloned function not found");
}
return it->second;

View File

@@ -12,38 +12,34 @@
#include <sstream>
#include <string>
#include <unordered_map>
#include "executable.hpp"
#include "ngraph/function.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
class LRUCache : public std::enable_shared_from_this<LRUCache>
{
public:
using GraphCache = std::unordered_map<std::string, std::shared_ptr<Executable>>;
using ClonedFunctionMap = std::unordered_map<std::string, std::shared_ptr<Function>>;
namespace ngraph {
namespace runtime {
class LRUCache : public std::enable_shared_from_this<LRUCache> {
public:
using GraphCache = std::unordered_map<std::string, std::shared_ptr<Executable>>;
using ClonedFunctionMap = std::unordered_map<std::string, std::shared_ptr<Function>>;
LRUCache();
LRUCache();
virtual ~LRUCache();
virtual ~LRUCache();
void add_entry(const std::vector<int>& shape,
std::shared_ptr<Executable> exec,
std::shared_ptr<Function> func);
bool is_cached(const std::vector<int>& shape);
std::shared_ptr<Executable> get_cached_entry(const std::vector<int>& shape);
void convert_shape_to_string(const std::vector<int>& shape, std::ostringstream& key);
std::shared_ptr<Function> get_cloned_function(const std::vector<int>& shape);
void add_entry(const std::vector<int>& shape, std::shared_ptr<Executable> exec, std::shared_ptr<Function> func);
bool is_cached(const std::vector<int>& shape);
std::shared_ptr<Executable> get_cached_entry(const std::vector<int>& shape);
void convert_shape_to_string(const std::vector<int>& shape, std::ostringstream& key);
std::shared_ptr<Function> get_cloned_function(const std::vector<int>& shape);
private:
int m_cache_size;
GraphCache m_map;
ClonedFunctionMap m_clone_function_map;
std::list<std::vector<int>> m_list;
std::mutex m_mutex;
};
}
}
private:
int m_cache_size;
GraphCache m_map;
ClonedFunctionMap m_clone_function_map;
std::list<std::vector<int>> m_list;
std::mutex m_mutex;
};
} // namespace runtime
} // namespace ngraph

View File

@@ -3,6 +3,7 @@
//
#include "dynamic_backend.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
@@ -21,49 +22,39 @@ using namespace std;
using namespace ngraph;
runtime::dynamic::DynamicBackend::DynamicBackend(shared_ptr<runtime::Backend> wrapped_backend)
: m_wrapped_backend(std::move(wrapped_backend))
{
}
: m_wrapped_backend(std::move(wrapped_backend)) {}
shared_ptr<runtime::Tensor> runtime::dynamic::DynamicBackend::create_tensor()
{
shared_ptr<runtime::Tensor> runtime::dynamic::DynamicBackend::create_tensor() {
return m_wrapped_backend->create_tensor();
}
shared_ptr<runtime::Tensor>
runtime::dynamic::DynamicBackend::create_tensor(const element::Type& type, const Shape& shape)
{
shared_ptr<runtime::Tensor> runtime::dynamic::DynamicBackend::create_tensor(const element::Type& type,
const Shape& shape) {
return m_wrapped_backend->create_tensor(type, shape);
}
shared_ptr<runtime::Tensor> runtime::dynamic::DynamicBackend::create_tensor(
const element::Type& type, const Shape& shape, void* memory_pointer)
{
shared_ptr<runtime::Tensor> runtime::dynamic::DynamicBackend::create_tensor(const element::Type& type,
const Shape& shape,
void* memory_pointer) {
return m_wrapped_backend->create_tensor(type, shape, memory_pointer);
}
std::shared_ptr<runtime::Tensor>
runtime::dynamic::DynamicBackend::create_dynamic_tensor(const element::Type& type,
const PartialShape& shape)
{
std::shared_ptr<runtime::Tensor> runtime::dynamic::DynamicBackend::create_dynamic_tensor(const element::Type& type,
const PartialShape& shape) {
return make_shared<DynamicTensor>(type, shape, m_wrapped_backend);
}
shared_ptr<runtime::Executable>
runtime::dynamic::DynamicBackend::compile(shared_ptr<Function> function,
bool enable_performance_collection)
{
return make_shared<runtime::dynamic::DynamicExecutable>(
function, m_wrapped_backend, enable_performance_collection);
shared_ptr<runtime::Executable> runtime::dynamic::DynamicBackend::compile(shared_ptr<Function> function,
bool enable_performance_collection) {
return make_shared<runtime::dynamic::DynamicExecutable>(function, m_wrapped_backend, enable_performance_collection);
}
runtime::dynamic::DynamicExecutable::DynamicExecutable(shared_ptr<Function> wrapped_function,
shared_ptr<runtime::Backend> wrapped_backend,
bool enable_performance_collection)
: m_wrapped_function(wrapped_function)
, m_wrapped_backend(wrapped_backend)
, m_enable_performance_collection(enable_performance_collection)
{
: m_wrapped_function(wrapped_function),
m_wrapped_backend(wrapped_backend),
m_enable_performance_collection(enable_performance_collection) {
pass::Manager passes;
passes.register_pass<pass::ShapeRelevance>();
passes.run_passes(m_wrapped_function);
@@ -73,30 +64,24 @@ runtime::dynamic::DynamicExecutable::DynamicExecutable(shared_ptr<Function> wrap
// Due to clang++-3.9 bugs, this needs to be a non-static separate function from
// count_dyn_nodes.
bool is_dynamic_op(const std::shared_ptr<Node>& op)
{
bool is_dynamic_op(const std::shared_ptr<Node>& op) {
return ov::is_type<op::Range>(op) || ov::is_type<op::v1::ConvolutionBackpropData>(op) ||
ov::is_type<op::v3::Broadcast>(op);
}
// Helper for a vile hack in DynamicExecutable::call. See body of that function for details.
static size_t count_dyn_nodes(const shared_ptr<ngraph::Function>& f)
{
static size_t count_dyn_nodes(const shared_ptr<ngraph::Function>& f) {
size_t count = 0;
for (auto op : f->get_ops())
{
if (is_dynamic_op(op))
{
for (auto op : f->get_ops()) {
if (is_dynamic_op(op)) {
count++;
}
}
return count;
}
bool runtime::dynamic::DynamicExecutable::call(
const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
bool runtime::dynamic::DynamicExecutable::call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) {
// TODO: Get cached executable out if it exists.
// We will cache on:
// (1) all shapes;
@@ -105,24 +90,18 @@ bool runtime::dynamic::DynamicExecutable::call(
std::vector<int> merged_input_shapes;
std::ostringstream key;
size_t loop_count = 0;
for (auto& input : inputs)
{
if (m_wrapped_function->get_parameters()[loop_count]->is_relevant_to_shapes())
{
for (auto& input : inputs) {
if (m_wrapped_function->get_parameters()[loop_count]->is_relevant_to_shapes()) {
// Caching on values of Shape relevant inputs
int size = input->get_size_in_bytes() / (input->get_element_type().bitwidth() / 8);
std::vector<int64_t> data(size);
input->read(data.data(), input->get_size_in_bytes());
for (size_t i = 0; i < input->get_element_count(); i++)
{
for (size_t i = 0; i < input->get_element_count(); i++) {
merged_input_shapes.emplace_back(data[i]);
}
}
else
{
} else {
// Caching on all remaining shapes
for (size_t i = 0; i < input->get_shape().size(); i++)
{
for (size_t i = 0; i < input->get_shape().size(); i++) {
merged_input_shapes.emplace_back(input->get_shape()[i]);
}
}
@@ -133,44 +112,32 @@ bool runtime::dynamic::DynamicExecutable::call(
loop_count++;
}
std::copy(merged_input_shapes.begin(),
merged_input_shapes.end(),
std::ostream_iterator<int>(key, ", "));
std::copy(merged_input_shapes.begin(), merged_input_shapes.end(), std::ostream_iterator<int>(key, ", "));
if (m_lru->is_cached(merged_input_shapes))
{
if (m_lru->is_cached(merged_input_shapes)) {
std::vector<std::shared_ptr<runtime::Tensor>> wrapped_inputs;
std::vector<std::shared_ptr<runtime::Tensor>> wrapped_outputs;
std::shared_ptr<Function> clone = m_lru->get_cloned_function(merged_input_shapes);
const ResultVector& results = clone->get_results();
for (auto& result : results)
{
for (auto& result : results) {
NGRAPH_CHECK(result->get_output_partial_shape(0).is_static(),
"Shape staticization failed for result node ",
*result);
}
NGRAPH_CHECK(results.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
{
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(outputs[i]))
{
dynamic_tensor->make_storage(results[i]->get_output_element_type(0),
results[i]->get_output_shape(0));
for (size_t i = 0; i < outputs.size(); i++) {
if (auto dynamic_tensor = std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(outputs[i])) {
dynamic_tensor->make_storage(results[i]->get_output_element_type(0), results[i]->get_output_shape(0));
wrapped_outputs.push_back(dynamic_tensor->get_wrapped_tensor());
}
else
{
} else {
wrapped_outputs.push_back(outputs[i]);
}
}
return m_lru->get_cached_entry(merged_input_shapes)->call(wrapped_outputs, inputs);
}
else
{
} else {
NGRAPH_CHECK(m_wrapped_function->get_parameters().size() == inputs.size());
std::vector<std::shared_ptr<runtime::Tensor>> wrapped_inputs;
@@ -188,14 +155,10 @@ bool runtime::dynamic::DynamicExecutable::call(
size_t i = 0;
for (auto& input : inputs)
{
if (m_wrapped_function->get_parameters()[i]->is_relevant_to_shapes())
{
for (auto& input : inputs) {
if (m_wrapped_function->get_parameters()[i]->is_relevant_to_shapes()) {
// TODO(amprocte): Move has_storage() to runtime::Tensor?
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{
if (auto dynamic_tensor = std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input)) {
NGRAPH_CHECK(dynamic_tensor->has_storage());
}
@@ -205,23 +168,16 @@ bool runtime::dynamic::DynamicExecutable::call(
// TODO(amprocte): For host-resident tensors we should be able to skip the read,
// but no API for that yet.
input->read(arg_value_base_pointers[i], input->get_size_in_bytes());
}
else
{
} else {
arg_value_base_pointers[i] = nullptr;
}
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input))
{
if (auto dynamic_tensor = std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(input)) {
NGRAPH_CHECK(dynamic_tensor->has_storage());
arg_element_types.push_back(
dynamic_tensor->get_wrapped_tensor()->get_element_type());
arg_element_types.push_back(dynamic_tensor->get_wrapped_tensor()->get_element_type());
arg_shapes.push_back(dynamic_tensor->get_wrapped_tensor()->get_shape());
wrapped_inputs.push_back(dynamic_tensor->get_wrapped_tensor());
}
else
{
} else {
arg_element_types.push_back(input->get_element_type());
arg_shapes.push_back(input->get_shape());
wrapped_inputs.push_back(input);
@@ -231,8 +187,7 @@ bool runtime::dynamic::DynamicExecutable::call(
}
NGRAPH_SUPPRESS_DEPRECATED_START;
clone = specialize_function(
m_wrapped_function, arg_element_types, arg_shapes, arg_value_base_pointers);
clone = specialize_function(m_wrapped_function, arg_element_types, arg_shapes, arg_value_base_pointers);
NGRAPH_SUPPRESS_DEPRECATED_END;
}
@@ -252,8 +207,7 @@ bool runtime::dynamic::DynamicExecutable::call(
// and DE into one pass.
size_t num_dyn_nodes_last_pass = std::numeric_limits<size_t>::max();
while (num_dyn_nodes_last_pass != 0)
{
while (num_dyn_nodes_last_pass != 0) {
passes.run_passes(clone);
auto num_dyn_nodes_this_pass = count_dyn_nodes(clone);
@@ -272,23 +226,17 @@ bool runtime::dynamic::DynamicExecutable::call(
const ResultVector& results = clone->get_results();
NGRAPH_CHECK(results.size() == outputs.size());
for (size_t i = 0; i < outputs.size(); i++)
{
if (auto dynamic_tensor =
std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(outputs[i]))
{
for (size_t i = 0; i < outputs.size(); i++) {
if (auto dynamic_tensor = std::dynamic_pointer_cast<runtime::dynamic::DynamicTensor>(outputs[i])) {
dynamic_tensor->make_storage(results[i]->get_output_element_type(0),
results[i]->get_output_partial_shape(0));
wrapped_outputs.push_back(dynamic_tensor->get_wrapped_tensor());
}
else
{
} else {
wrapped_outputs.push_back(outputs[i]);
}
}
auto compiled_executable =
m_wrapped_backend->compile(clone, m_enable_performance_collection);
auto compiled_executable = m_wrapped_backend->compile(clone, m_enable_performance_collection);
// Put compiled executable in the cache.
m_lru->add_entry(merged_input_shapes, compiled_executable, clone);
auto result = compiled_executable->call(wrapped_outputs, wrapped_inputs);
@@ -297,78 +245,57 @@ bool runtime::dynamic::DynamicExecutable::call(
}
}
runtime::dynamic::DynamicTensor::DynamicTensor(
const element::Type& element_type,
const PartialShape& shape,
const std::shared_ptr<runtime::Backend>& wrapped_backend)
: Tensor(make_shared<descriptor::Tensor>(element_type, shape, "wrapped_dynamic"))
, m_wrapped_tensor(nullptr)
, m_wrapped_backend(wrapped_backend)
{
}
runtime::dynamic::DynamicTensor::DynamicTensor(const element::Type& element_type,
const PartialShape& shape,
const std::shared_ptr<runtime::Backend>& wrapped_backend)
: Tensor(make_shared<descriptor::Tensor>(element_type, shape, "wrapped_dynamic")),
m_wrapped_tensor(nullptr),
m_wrapped_backend(wrapped_backend) {}
size_t runtime::dynamic::DynamicTensor::get_size_in_bytes() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for size in bytes of a dynamic tensor with no allocated storage");
size_t runtime::dynamic::DynamicTensor::get_size_in_bytes() const {
NGRAPH_CHECK(m_wrapped_tensor != nullptr, "asked for size in bytes of a dynamic tensor with no allocated storage");
// TODO expand size calculation for type with bitwidth less than 8 like:
// m_wrapped_tensor->get_size_in_bytes()
return get_element_count() * get_element_type().size();
}
size_t runtime::dynamic::DynamicTensor::get_element_count() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for element count of a dynamic tensor with no allocated storage");
size_t runtime::dynamic::DynamicTensor::get_element_count() const {
NGRAPH_CHECK(m_wrapped_tensor != nullptr, "asked for element count of a dynamic tensor with no allocated storage");
return shape_size(m_wrapped_tensor->get_shape());
}
const element::Type& runtime::dynamic::DynamicTensor::get_element_type() const
{
if (m_wrapped_tensor == nullptr)
{
const element::Type& runtime::dynamic::DynamicTensor::get_element_type() const {
if (m_wrapped_tensor == nullptr) {
return m_descriptor->get_element_type();
}
else
{
} else {
return m_wrapped_tensor->get_element_type();
}
}
const ngraph::Shape& runtime::dynamic::DynamicTensor::get_shape() const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"asked for shape of a dynamic tensor with no allocated storage");
const ngraph::Shape& runtime::dynamic::DynamicTensor::get_shape() const {
NGRAPH_CHECK(m_wrapped_tensor != nullptr, "asked for shape of a dynamic tensor with no allocated storage");
return m_wrapped_tensor->get_shape();
}
void runtime::dynamic::DynamicTensor::write(const void* p, size_t n)
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"tried to write to a dynamic tensor with no allocated storage");
void runtime::dynamic::DynamicTensor::write(const void* p, size_t n) {
NGRAPH_CHECK(m_wrapped_tensor != nullptr, "tried to write to a dynamic tensor with no allocated storage");
m_wrapped_tensor->write(p, n);
}
void runtime::dynamic::DynamicTensor::read(void* p, size_t n) const
{
NGRAPH_CHECK(m_wrapped_tensor != nullptr,
"tried to read from a dynamic tensor with no allocated storage");
void runtime::dynamic::DynamicTensor::read(void* p, size_t n) const {
NGRAPH_CHECK(m_wrapped_tensor != nullptr, "tried to read from a dynamic tensor with no allocated storage");
m_wrapped_tensor->read(p, n);
}
bool runtime::dynamic::DynamicTensor::has_storage() const
{
bool runtime::dynamic::DynamicTensor::has_storage() const {
return m_wrapped_tensor != nullptr;
}
void runtime::dynamic::DynamicTensor::release_storage()
{
void runtime::dynamic::DynamicTensor::release_storage() {
m_wrapped_tensor = nullptr;
}
void runtime::dynamic::DynamicTensor::make_storage(const element::Type& element_type,
const PartialShape& shape)
{
void runtime::dynamic::DynamicTensor::make_storage(const element::Type& element_type, const PartialShape& shape) {
NGRAPH_CHECK(element_type.is_static(), "make_storage requires a static element type");
NGRAPH_CHECK(get_element_type().is_dynamic() || get_element_type() == element_type,
"tried to make storage with element type ",
@@ -380,18 +307,13 @@ void runtime::dynamic::DynamicTensor::make_storage(const element::Type& element_
shape,
" which is incompatible with dynamic tensor shape ",
get_partial_shape());
if (shape.is_static())
{
if (shape.is_static()) {
m_wrapped_tensor = m_wrapped_backend->create_tensor(element_type, shape.get_shape());
}
else
{
} else {
m_wrapped_tensor = m_wrapped_backend->create_dynamic_tensor(element_type, shape);
}
}
const std::shared_ptr<ngraph::runtime::Tensor>&
runtime::dynamic::DynamicTensor::get_wrapped_tensor() const
{
const std::shared_ptr<ngraph::runtime::Tensor>& runtime::dynamic::DynamicTensor::get_wrapped_tensor() const {
return m_wrapped_tensor;
}

View File

@@ -14,18 +14,15 @@
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
namespace runtime
{
namespace dynamic
{
class DynamicBackend;
class DynamicExecutable;
class DynamicTensor;
}
}
}
namespace ngraph {
namespace runtime {
namespace dynamic {
class DynamicBackend;
class DynamicExecutable;
class DynamicTensor;
} // namespace dynamic
} // namespace runtime
} // namespace ngraph
///
/// \brief Wrapper class used to provide dynamic tensor support on backends
@@ -42,22 +39,21 @@ namespace ngraph
///
/// This class is instantiated by `ngraph::runtime::Backend::create`.
///
class ngraph::runtime::dynamic::DynamicBackend : public Backend
{
class ngraph::runtime::dynamic::DynamicBackend : public Backend {
public:
DynamicBackend(std::shared_ptr<ngraph::runtime::Backend> wrapped_backend);
std::shared_ptr<Tensor> create_tensor() override;
std::shared_ptr<Tensor>
create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape) override;
std::shared_ptr<Tensor> create_dynamic_tensor(const element::Type& type,
const PartialShape& shape) override;
std::shared_ptr<Tensor> create_dynamic_tensor(const element::Type& type, const PartialShape& shape) override;
bool supports_dynamic_tensors() override { return true; }
bool supports_dynamic_tensors() override {
return true;
}
std::shared_ptr<Executable> compile(std::shared_ptr<Function> function,
bool enable_performance_data = false) override;
@@ -79,8 +75,7 @@ private:
///
/// `DynamicExecutable` objects are produced by `DynamicBackend::compile()`.
///
class ngraph::runtime::dynamic::DynamicExecutable : public ngraph::runtime::Executable
{
class ngraph::runtime::dynamic::DynamicExecutable : public ngraph::runtime::Executable {
public:
DynamicExecutable(std::shared_ptr<Function> wrapped_function,
std::shared_ptr<ngraph::runtime::Backend> wrapped_backend,
@@ -91,8 +86,7 @@ public:
private:
std::shared_ptr<ngraph::Function> m_wrapped_function;
std::shared_ptr<ngraph::runtime::Backend> m_wrapped_backend;
std::shared_ptr<ngraph::runtime::LRUCache> m_lru =
std::make_shared<ngraph::runtime::LRUCache>();
std::shared_ptr<ngraph::runtime::LRUCache> m_lru = std::make_shared<ngraph::runtime::LRUCache>();
bool m_enable_performance_collection;
};
@@ -114,8 +108,7 @@ private:
/// called until the storage has been released via `release_storage()`.
/// 4. `release_storage()` unassigns previously assigned storage.
///
class ngraph::runtime::dynamic::DynamicTensor : public ngraph::runtime::Tensor
{
class ngraph::runtime::dynamic::DynamicTensor : public ngraph::runtime::Tensor {
public:
DynamicTensor(const element::Type& element_type,
const PartialShape& shape,

View File

@@ -2,9 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "executable.hpp"
#include <sstream>
#include "executable.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/util.hpp"
@@ -17,147 +18,121 @@ runtime::Executable::Executable() {}
runtime::Executable::~Executable() {}
bool runtime::Executable::call_with_validate(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
const vector<shared_ptr<runtime::Tensor>>& inputs) {
validate(outputs, inputs);
return call(outputs, inputs);
}
void runtime::Executable::validate(const vector<std::shared_ptr<runtime::Tensor>>& outputs,
const vector<std::shared_ptr<runtime::Tensor>>& inputs)
{
const vector<std::shared_ptr<runtime::Tensor>>& inputs) {
const ParameterVector& parameters = get_parameters();
const ResultVector& results = get_results();
if (parameters.size() != inputs.size())
{
if (parameters.size() != inputs.size()) {
stringstream ss;
ss << "Call input count " << inputs.size() << " does not match Function's Parameter count "
<< parameters.size();
throw runtime_error(ss.str());
}
if (results.size() != outputs.size())
{
if (results.size() != outputs.size()) {
stringstream ss;
ss << "Call output count " << outputs.size() << " does not match Function's Result count "
<< results.size();
ss << "Call output count " << outputs.size() << " does not match Function's Result count " << results.size();
throw runtime_error(ss.str());
}
for (size_t i = 0; i < parameters.size(); i++)
{
for (size_t i = 0; i < parameters.size(); i++) {
if (parameters[i]->get_element_type().is_static() &&
parameters[i]->get_element_type() != inputs[i]->get_element_type())
{
parameters[i]->get_element_type() != inputs[i]->get_element_type()) {
stringstream ss;
ss << "Input " << i << " type '" << inputs[i]->get_element_type()
<< "' does not match Parameter type '" << parameters[i]->get_element_type() << "'";
ss << "Input " << i << " type '" << inputs[i]->get_element_type() << "' does not match Parameter type '"
<< parameters[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (!(parameters[i]->get_output_partial_shape(0).relaxes(inputs[i]->get_partial_shape())))
{
if (!(parameters[i]->get_output_partial_shape(0).relaxes(inputs[i]->get_partial_shape()))) {
stringstream ss;
ss << "Input " << i << " shape " << inputs[i]->get_partial_shape()
<< " does not match Parameter shape " << parameters[i]->get_output_partial_shape(0);
ss << "Input " << i << " shape " << inputs[i]->get_partial_shape() << " does not match Parameter shape "
<< parameters[i]->get_output_partial_shape(0);
throw runtime_error(ss.str());
}
}
for (size_t i = 0; i < results.size(); i++)
{
if (outputs[i]->get_element_type().is_static() &&
results[i]->get_element_type().is_static() &&
results[i]->get_element_type() != outputs[i]->get_element_type())
{
for (size_t i = 0; i < results.size(); i++) {
if (outputs[i]->get_element_type().is_static() && results[i]->get_element_type().is_static() &&
results[i]->get_element_type() != outputs[i]->get_element_type()) {
stringstream ss;
ss << "Output " << i << " type '" << outputs[i]->get_element_type()
<< "' does not match Result type '" << results[i]->get_element_type() << "'";
ss << "Output " << i << " type '" << outputs[i]->get_element_type() << "' does not match Result type '"
<< results[i]->get_element_type() << "'";
throw runtime_error(ss.str());
}
if (!outputs[i]->get_partial_shape().relaxes(results[i]->get_output_partial_shape(0)))
{
if (!outputs[i]->get_partial_shape().relaxes(results[i]->get_output_partial_shape(0))) {
stringstream ss;
ss << "Output " << i << " shape " << outputs[i]->get_partial_shape()
<< " does not match max Result shape "
ss << "Output " << i << " shape " << outputs[i]->get_partial_shape() << " does not match max Result shape "
<< results[i]->get_output_partial_shape(0).get_max_shape();
throw runtime_error(ss.str());
}
}
}
const ngraph::ParameterVector& runtime::Executable::get_parameters() const
{
const ngraph::ParameterVector& runtime::Executable::get_parameters() const {
return m_parameters;
}
const ngraph::ResultVector& runtime::Executable::get_results() const
{
const ngraph::ResultVector& runtime::Executable::get_results() const {
return m_results;
}
size_t runtime::Executable::get_preferred_pipeline_depth() const
{
size_t runtime::Executable::get_preferred_pipeline_depth() const {
return 2;
}
void runtime::Executable::set_parameters_and_results(const Function& func)
{
void runtime::Executable::set_parameters_and_results(const Function& func) {
m_parameters = func.get_parameters();
m_results = func.get_results();
}
vector<runtime::PerformanceCounter> runtime::Executable::get_performance_data() const
{
vector<runtime::PerformanceCounter> runtime::Executable::get_performance_data() const {
return vector<PerformanceCounter>();
}
void runtime::Executable::save(std::ostream& /* output_stream */)
{
void runtime::Executable::save(std::ostream& /* output_stream */) {
throw runtime_error("save operation unimplemented.");
}
shared_ptr<runtime::Tensor> runtime::Executable::create_input_tensor(size_t /* input_index */)
{
shared_ptr<runtime::Tensor> runtime::Executable::create_input_tensor(size_t /* input_index */) {
throw runtime_error("create_input_tensor unimplemented");
}
shared_ptr<runtime::Tensor> runtime::Executable::create_input_tensor(size_t /* input_index */,
void* /* memory_pointer */)
{
void* /* memory_pointer */) {
throw runtime_error("create_input_tensor unimplemented");
}
shared_ptr<runtime::Tensor> runtime::Executable::create_output_tensor(size_t /* output_index */)
{
shared_ptr<runtime::Tensor> runtime::Executable::create_output_tensor(size_t /* output_index */) {
throw runtime_error("create_output_tensor unimplemented");
}
shared_ptr<runtime::Tensor> runtime::Executable::create_output_tensor(size_t /* output_index */,
void* /* memory_pointer */)
{
void* /* memory_pointer */) {
throw runtime_error("create_output_tensor unimplemented");
}
vector<shared_ptr<runtime::Tensor>>
runtime::Executable::create_input_tensor(size_t /* input_index */, size_t /* pipeline_depth */)
{
vector<shared_ptr<runtime::Tensor>> runtime::Executable::create_input_tensor(size_t /* input_index */,
size_t /* pipeline_depth */) {
throw runtime_error("create_input_tensor unimplemented");
}
vector<shared_ptr<runtime::Tensor>> runtime::Executable::create_input_tensor(
size_t /* input_index */, size_t /* pipeline_depth */, std::vector<void*> /* memory_pointer */)
{
vector<shared_ptr<runtime::Tensor>> runtime::Executable::create_input_tensor(size_t /* input_index */,
size_t /* pipeline_depth */,
std::vector<void*> /* memory_pointer */) {
throw runtime_error("create_input_tensor unimplemented");
}
vector<shared_ptr<runtime::Tensor>>
runtime::Executable::create_output_tensor(size_t /* output_index */,
size_t /* pipeline_depth */)
{
vector<shared_ptr<runtime::Tensor>> runtime::Executable::create_output_tensor(size_t /* output_index */,
size_t /* pipeline_depth */) {
throw runtime_error("create_output_tensor unimplemented");
}
vector<shared_ptr<runtime::Tensor>> runtime::Executable::create_output_tensor(
size_t /* output_index */, size_t /* pipeline_depth */, std::vector<void*> /* memory_pointer */)
{
vector<shared_ptr<runtime::Tensor>> runtime::Executable::create_output_tensor(size_t /* output_index */,
size_t /* pipeline_depth */,
std::vector<void*> /* memory_pointer */) {
throw runtime_error("create_output_tensor unimplemented");
}

View File

@@ -13,16 +13,13 @@
#include "ngraph/type/element_type.hpp"
#include "performance_counter.hpp"
namespace ngraph
{
namespace runtime
{
class Executable;
}
namespace ngraph {
namespace runtime {
class Executable;
}
} // namespace ngraph
class BACKEND_API ngraph::runtime::Executable
{
class BACKEND_API ngraph::runtime::Executable {
public:
Executable();
virtual ~Executable();
@@ -79,8 +76,7 @@ public:
/// must be sufficient to contain the tensor. The lifetime of the buffer is the
/// responsibility of the caller and must outlive the created Tensor.
/// \returns A Tensor
virtual std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index,
void* memory_pointer);
virtual std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index, void* memory_pointer);
/// \brief Create an output Tensor
/// \param output_index The index position in the output Result vector. This would be the same
@@ -95,8 +91,7 @@ public:
/// must be sufficient to contain the tensor. The lifetime of the buffer is the
/// responsibility of the caller and must outlive the created Tensor.
/// \returns A Tensor
virtual std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index,
void* memory_pointer);
virtual std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index, void* memory_pointer);
/// \brief Create a vector of input Tensors
/// \param input_index The index position in the input Parameter vector. This would be the same
@@ -104,8 +99,8 @@ public:
/// \param pipeline_depth The number of stages in the input pipeline. For double-buffered input
/// you would specify pipeline_depth=2
/// \returns A vector of Tensors, one for each stage of the pipeline
virtual std::vector<std::shared_ptr<runtime::Tensor>>
create_input_tensor(size_t input_index, size_t pipeline_depth);
virtual std::vector<std::shared_ptr<runtime::Tensor>> create_input_tensor(size_t input_index,
size_t pipeline_depth);
/// \brief Create a vector of input Tensors
/// \param input_index The index position in the input Parameter vector. This would be the same
@@ -116,8 +111,9 @@ public:
/// the buffer must be sufficient to contain the tensor. The lifetime of the buffers is the
/// responsibility of the caller and must outlive the created Tensor.
/// \returns A vector of Tensors, one for each stage of the pipeline
virtual std::vector<std::shared_ptr<runtime::Tensor>> create_input_tensor(
size_t input_index, size_t pipeline_depth, std::vector<void*> memory_pointers);
virtual std::vector<std::shared_ptr<runtime::Tensor>> create_input_tensor(size_t input_index,
size_t pipeline_depth,
std::vector<void*> memory_pointers);
/// \brief Create a vector of output Tensors
/// \param output_index The index position in the output Result vector. This would be the same
@@ -125,8 +121,8 @@ public:
/// \param pipeline_depth The number of stages in the output pipeline. For double-buffered
/// output you would specify pipeline_depth=2
/// \returns A vector of Tensors, one for each stage of the pipeline
virtual std::vector<std::shared_ptr<runtime::Tensor>>
create_output_tensor(size_t output_index, size_t pipeline_depth);
virtual std::vector<std::shared_ptr<runtime::Tensor>> create_output_tensor(size_t output_index,
size_t pipeline_depth);
/// \brief Create a vector of output Tensors
/// \param output_index The index position in the output Result vector. This would be the same
@@ -137,8 +133,9 @@ public:
/// the buffer must be sufficient to contain the tensor. The lifetime of the buffers is the
/// responsibility of the caller and must outlive the created Tensor.
/// \returns A vector of Tensors, one for each stage of the pipeline
virtual std::vector<std::shared_ptr<runtime::Tensor>> create_output_tensor(
size_t output_index, size_t pipeline_depth, std::vector<void*> memory_pointers);
virtual std::vector<std::shared_ptr<runtime::Tensor>> create_output_tensor(size_t output_index,
size_t pipeline_depth,
std::vector<void*> memory_pointers);
protected:
/// \brief Called at the end of compile to the values to be returned by get_parameters

View File

@@ -16,50 +16,42 @@
using namespace std;
using namespace ngraph;
runtime::ie::IE_Backend::IE_Backend(const string& configuration_string)
{
runtime::ie::IE_Backend::IE_Backend(const string& configuration_string) {
string config = configuration_string;
// Get device name, after colon if present: IE:CPU -> CPU
auto separator = config.find(":");
if (separator != config.npos)
{
if (separator != config.npos) {
config = config.substr(separator + 1);
}
m_device = config;
}
shared_ptr<runtime::Executable> runtime::ie::IE_Backend::compile(shared_ptr<Function> func, bool)
{
shared_ptr<runtime::Executable> runtime::ie::IE_Backend::compile(shared_ptr<Function> func, bool) {
return make_shared<IE_Executable>(func, m_device);
}
bool runtime::ie::IE_Backend::is_supported(const Node& node) const
{
bool runtime::ie::IE_Backend::is_supported(const Node& node) const {
const auto& opset = get_opset1();
return opset.contains_op_type(&node);
}
shared_ptr<runtime::Tensor>
runtime::ie::IE_Backend::create_dynamic_tensor(const element::Type& type,
const PartialShape& shape)
{
shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_dynamic_tensor(const element::Type& type,
const PartialShape& shape) {
return make_shared<IETensor>(type, shape);
}
shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_tensor()
{
shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_tensor() {
throw runtime_error("IE_Backend::create_tensor() not supported");
}
shared_ptr<runtime::Tensor>
runtime::ie::IE_Backend::create_tensor(const element::Type& element_type, const Shape& shape)
{
shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_tensor(const element::Type& element_type,
const Shape& shape) {
return make_shared<IETensor>(element_type, shape);
}
shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_tensor(
const element::Type& element_type, const Shape& shape, void* data)
{
shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_tensor(const element::Type& element_type,
const Shape& shape,
void* data) {
shared_ptr<runtime::Tensor> tensor = make_shared<IETensor>(element_type, shape);
if (tensor == nullptr)
throw runtime_error("Cannot create IETensor!");
@@ -67,8 +59,8 @@ shared_ptr<runtime::Tensor> runtime::ie::IE_Backend::create_tensor(
return tensor;
}
extern "C" IE_BACKEND_API void ngraph_register_ie_backend()
{
runtime::BackendManager::register_backend(
"IE", [](const string& config) { return make_shared<runtime::ie::IE_Backend>(config); });
extern "C" IE_BACKEND_API void ngraph_register_ie_backend() {
runtime::BackendManager::register_backend("IE", [](const string& config) {
return make_shared<runtime::ie::IE_Backend>(config);
});
}

View File

@@ -18,50 +18,40 @@
class Handle;
namespace ngraph
{
namespace runtime
{
namespace ie
{
class IE_Backend final : public runtime::Backend
{
public:
IE_Backend(const std::string& configuration_string);
virtual ~IE_Backend() = default;
namespace ngraph {
namespace runtime {
namespace ie {
class IE_Backend final : public runtime::Backend {
public:
IE_Backend(const std::string& configuration_string);
virtual ~IE_Backend() = default;
std::shared_ptr<Executable> compile(std::shared_ptr<Function> func,
bool enable_performance_data = false) override;
bool is_supported(const Node& node) const override;
std::shared_ptr<Executable> compile(std::shared_ptr<Function> func, bool enable_performance_data = false) override;
bool is_supported(const Node& node) const override;
std::shared_ptr<ngraph::runtime::Tensor>
create_dynamic_tensor(const ngraph::element::Type& type,
const ngraph::PartialShape& shape) override;
std::shared_ptr<ngraph::runtime::Tensor> create_dynamic_tensor(const ngraph::element::Type& type,
const ngraph::PartialShape& shape) override;
std::shared_ptr<ngraph::runtime::Tensor> create_tensor() override;
std::shared_ptr<ngraph::runtime::Tensor> create_tensor() override;
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) final override;
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(const ngraph::element::Type& element_type,
const Shape& shape) final override;
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(const ngraph::element::Type& element_type,
const Shape& shape,
void* data) final override;
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(const ngraph::element::Type& element_type,
const Shape& shape,
void* data) final override;
template <typename T>
std::shared_ptr<ngraph::runtime::Tensor>
create_tensor(ngraph::element::Type type, ngraph::Shape shape, T* data)
{
auto tensor = std::make_shared<IETensor>(type, shape);
size_t size = shape_size(shape);
tensor->write(data, size * sizeof(T));
return tensor;
}
private:
std::string m_device;
};
}
template <typename T>
std::shared_ptr<ngraph::runtime::Tensor> create_tensor(ngraph::element::Type type, ngraph::Shape shape, T* data) {
auto tensor = std::make_shared<IETensor>(type, shape);
size_t size = shape_size(shape);
tensor->write(data, size * sizeof(T));
return tensor;
}
}
private:
std::string m_device;
};
} // namespace ie
} // namespace runtime
} // namespace ngraph

View File

@@ -4,8 +4,8 @@
#include "ngraph/visibility.hpp"
#ifdef ie_backend_EXPORTS // defined if we are building the ie_backend as shared library
#define IE_BACKEND_API NGRAPH_HELPER_DLL_EXPORT
#ifdef ie_backend_EXPORTS // defined if we are building the ie_backend as shared library
# define IE_BACKEND_API NGRAPH_HELPER_DLL_EXPORT
#else
#define IE_BACKEND_API NGRAPH_HELPER_DLL_IMPORT
# define IE_BACKEND_API NGRAPH_HELPER_DLL_IMPORT
#endif

View File

@@ -4,33 +4,29 @@
#pragma once
#include <ie_core.hpp>
#include <memory>
#include <string>
#include <vector>
#include <ie_core.hpp>
#include "executable.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
namespace runtime
{
namespace ie
{
// A Inference Engine executable object produced by compiling an nGraph function.
class IE_Executable final : public Executable
{
public:
IE_Executable(std::shared_ptr<Function> func, std::string device);
virtual ~IE_Executable() {}
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override final;
namespace ngraph {
namespace runtime {
namespace ie {
// A Inference Engine executable object produced by compiling an nGraph function.
class IE_Executable final : public Executable {
public:
IE_Executable(std::shared_ptr<Function> func, std::string device);
virtual ~IE_Executable() {}
bool call(const std::vector<std::shared_ptr<runtime::Tensor>>& outputs,
const std::vector<std::shared_ptr<runtime::Tensor>>& inputs) override final;
private:
InferenceEngine::CNNNetwork m_network;
std::string m_device;
};
}
}
}
private:
InferenceEngine::CNNNetwork m_network;
std::string m_device;
};
} // namespace ie
} // namespace runtime
} // namespace ngraph

View File

@@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ie_tensor.hpp"
#include <cstring>
#include <memory>
#include <utility>
#include "ie_tensor.hpp"
#include "ngraph/check.hpp"
#include "ngraph/except.hpp"
#include "ngraph/util.hpp"
@@ -15,25 +16,18 @@ using namespace ngraph;
using namespace std;
runtime::ie::IETensor::IETensor(const element::Type& element_type, const PartialShape& shape)
: runtime::Tensor(make_shared<descriptor::Tensor>(element_type, shape, ""))
{
}
: runtime::Tensor(make_shared<descriptor::Tensor>(element_type, shape, "")) {}
runtime::ie::IETensor::IETensor(const element::Type& element_type, const Shape& shape)
: runtime::Tensor(make_shared<descriptor::Tensor>(element_type, shape, ""))
, m_data(shape_size(shape) * element_type.size())
{
}
: runtime::Tensor(make_shared<descriptor::Tensor>(element_type, shape, "")),
m_data(shape_size(shape) * element_type.size()) {}
void runtime::ie::IETensor::write(const void* src, size_t bytes)
{
void runtime::ie::IETensor::write(const void* src, size_t bytes) {
const int8_t* src_ptr = static_cast<const int8_t*>(src);
if (src_ptr == nullptr)
{
if (src_ptr == nullptr) {
return;
}
if (get_partial_shape().is_dynamic())
{
if (get_partial_shape().is_dynamic()) {
m_data = AlignedBuffer(bytes);
}
NGRAPH_CHECK(bytes <= m_data.size(),
@@ -44,11 +38,9 @@ void runtime::ie::IETensor::write(const void* src, size_t bytes)
copy(src_ptr, src_ptr + bytes, m_data.get_ptr<int8_t>());
}
void runtime::ie::IETensor::read(void* dst, size_t bytes) const
{
void runtime::ie::IETensor::read(void* dst, size_t bytes) const {
int8_t* dst_ptr = static_cast<int8_t*>(dst);
if (dst_ptr == nullptr)
{
if (dst_ptr == nullptr) {
return;
}
NGRAPH_CHECK(bytes <= m_data.size(),
@@ -59,7 +51,6 @@ void runtime::ie::IETensor::read(void* dst, size_t bytes) const
copy(m_data.get_ptr<int8_t>(), m_data.get_ptr<int8_t>() + bytes, dst_ptr);
}
const void* runtime::ie::IETensor::get_data_ptr() const
{
const void* runtime::ie::IETensor::get_data_ptr() const {
return m_data.get_ptr();
}

View File

@@ -11,43 +11,39 @@
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
namespace ngraph
{
namespace runtime
{
namespace ie
{
class IE_BACKEND_API IETensor : public ngraph::runtime::Tensor
{
public:
IETensor(const ngraph::element::Type& element_type, const Shape& shape);
IETensor(const ngraph::element::Type& element_type, const PartialShape& shape);
namespace ngraph {
namespace runtime {
namespace ie {
class IE_BACKEND_API IETensor : public ngraph::runtime::Tensor {
public:
IETensor(const ngraph::element::Type& element_type, const Shape& shape);
IETensor(const ngraph::element::Type& element_type, const PartialShape& shape);
///
/// \brief Write bytes directly into the tensor
///
/// \param src Pointer to source of data
/// \param bytes Number of bytes to write, must be integral number of
/// elements.
///
void write(const void* src, size_t bytes) override;
///
/// \brief Write bytes directly into the tensor
///
/// \param src Pointer to source of data
/// \param bytes Number of bytes to write, must be integral number of
/// elements.
///
void write(const void* src, size_t bytes) override;
///
/// \brief Read bytes directly from the tensor
///
/// \param dst Pointer to destination for data
/// \param bytes Number of bytes to read, must be integral number of elements.
///
void read(void* dst, size_t bytes) const override;
///
/// \brief Read bytes directly from the tensor
///
/// \param dst Pointer to destination for data
/// \param bytes Number of bytes to read, must be integral number of elements.
///
void read(void* dst, size_t bytes) const override;
const void* get_data_ptr() const;
const void* get_data_ptr() const;
private:
IETensor(const IETensor&) = delete;
IETensor(IETensor&&) = delete;
IETensor& operator=(const IETensor&) = delete;
AlignedBuffer m_data;
};
} // namespace ie
} // namespace runtime
} // namespace ngraph
private:
IETensor(const IETensor&) = delete;
IETensor(IETensor&&) = delete;
IETensor& operator=(const IETensor&) = delete;
AlignedBuffer m_data;
};
} // namespace ie
} // namespace runtime
} // namespace ngraph

View File

@@ -1229,8 +1229,9 @@ struct InfoForEDROIFeature {
Shape output_rois_shape;
};
InfoForEDROIFeature get_info_for_ed_roi_feature(const std::vector<Shape> input_shapes,
const op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes& attrs) {
InfoForEDROIFeature get_info_for_ed_roi_feature(
const std::vector<Shape> input_shapes,
const op::v6::ExperimentalDetectronROIFeatureExtractor::Attributes& attrs) {
InfoForEDROIFeature result;
size_t output_size = static_cast<size_t>(attrs.output_size);
@@ -1249,13 +1250,12 @@ InfoForEDROIFeature get_info_for_ed_roi_feature(const std::vector<Shape> input_s
return result;
}
} // namespace experimental_roi_feature
} // namespace experimental_roi_feature
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v6::ExperimentalDetectronROIFeatureExtractor>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs)
{
const HostTensorVector& inputs) {
const auto attrs = op->get_attrs();
std::vector<std::vector<float>> input_data;

View File

@@ -6,18 +6,14 @@
#include "int_backend_visibility.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
using EvaluatorsMap =
std::map<ngraph::NodeTypeInfo,
std::function<bool(const std::shared_ptr<ngraph::Node>& node,
const ngraph::HostTensorVector& outputs,
const ngraph::HostTensorVector& inputs)>>;
EvaluatorsMap& get_evaluators_map();
}
}
}
namespace ngraph {
namespace runtime {
namespace interpreter {
using EvaluatorsMap = std::map<ngraph::NodeTypeInfo,
std::function<bool(const std::shared_ptr<ngraph::Node>& node,
const ngraph::HostTensorVector& outputs,
const ngraph::HostTensorVector& inputs)>>;
EvaluatorsMap& get_evaluators_map();
} // namespace interpreter
} // namespace runtime
} // namespace ngraph

View File

@@ -2,10 +2,10 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "int_backend_visibility.hpp"
#include "int_backend.hpp"
#include "backend_manager.hpp"
#include "int_backend.hpp"
#include "int_backend_visibility.hpp"
#include "int_executable.hpp"
#include "ngraph/except.hpp"
#include "ngraph/runtime/host_tensor.hpp"
@@ -14,8 +14,7 @@
using namespace std;
using namespace ngraph;
extern "C" INTERPRETER_BACKEND_API void ngraph_register_interpreter_backend()
{
extern "C" INTERPRETER_BACKEND_API void ngraph_register_interpreter_backend() {
runtime::BackendManager::register_backend("INTERPRETER", [](const std::string& /* config */) {
return std::make_shared<runtime::interpreter::INTBackend>();
});
@@ -24,53 +23,42 @@ extern "C" INTERPRETER_BACKEND_API void ngraph_register_interpreter_backend()
runtime::interpreter::INTBackend::INTBackend() {}
runtime::interpreter::INTBackend::INTBackend(const vector<string>& unsupported_op_name_list)
: m_unsupported_op_name_list{unsupported_op_name_list.begin(), unsupported_op_name_list.end()}
{
}
: m_unsupported_op_name_list{unsupported_op_name_list.begin(), unsupported_op_name_list.end()} {}
shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor()
{
shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor() {
return make_shared<runtime::HostTensor>();
}
shared_ptr<runtime::Tensor>
runtime::interpreter::INTBackend::create_tensor(const element::Type& type, const Shape& shape)
{
shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor(const element::Type& type,
const Shape& shape) {
return make_shared<runtime::HostTensor>(type, shape);
}
shared_ptr<runtime::Tensor>
runtime::interpreter::INTBackend::create_dynamic_tensor(const element::Type& type,
const PartialShape& pshape)
{
shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_dynamic_tensor(const element::Type& type,
const PartialShape& pshape) {
return make_shared<runtime::HostTensor>(type, pshape);
}
shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor(
const element::Type& type, const Shape& shape, void* memory_pointer)
{
shared_ptr<runtime::Tensor> runtime::interpreter::INTBackend::create_tensor(const element::Type& type,
const Shape& shape,
void* memory_pointer) {
return make_shared<runtime::HostTensor>(type, shape, memory_pointer);
}
shared_ptr<runtime::Executable>
runtime::interpreter::INTBackend::compile(shared_ptr<Function> function,
bool enable_performance_collection)
{
shared_ptr<runtime::Executable> runtime::interpreter::INTBackend::compile(shared_ptr<Function> function,
bool enable_performance_collection) {
return make_shared<INTExecutable>(function, enable_performance_collection);
}
bool runtime::interpreter::INTBackend::is_supported(const Node& node) const
{
bool runtime::interpreter::INTBackend::is_supported(const Node& node) const {
return m_unsupported_op_name_list.find(node.description()) == m_unsupported_op_name_list.end();
}
bool runtime::interpreter::INTBackend::set_config(const map<string, string>& config, string& error)
{
bool runtime::interpreter::INTBackend::set_config(const map<string, string>& config, string& error) {
bool rc = false;
auto it = config.find("test_echo");
error = "";
if (it != config.end())
{
if (it != config.end()) {
error = it->second;
rc = true;
}

View File

@@ -10,26 +10,21 @@
#include <string>
#include <vector>
#include "int_backend_visibility.hpp"
#include "backend.hpp"
#include "backend_manager.hpp"
#include "int_backend_visibility.hpp"
#include "ngraph/runtime/tensor.hpp"
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
class INTBackend;
class INTExecutable;
}
}
}
namespace ngraph {
namespace runtime {
namespace interpreter {
class INTBackend;
class INTExecutable;
} // namespace interpreter
} // namespace runtime
} // namespace ngraph
class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTBackend : public Backend
{
class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTBackend : public Backend {
public:
INTBackend();
INTBackend(const std::vector<std::string>& unsupported_op_name_list);
@@ -39,12 +34,10 @@ public:
std::shared_ptr<Tensor> create_tensor() override;
std::shared_ptr<Tensor>
create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape, void* memory_pointer) override;
std::shared_ptr<Tensor> create_tensor(const element::Type& type, const Shape& shape) override;
std::shared_ptr<Tensor> create_dynamic_tensor(const element::Type& type,
const PartialShape& shape) override;
std::shared_ptr<Tensor> create_dynamic_tensor(const element::Type& type, const PartialShape& shape) override;
std::shared_ptr<Executable> compile(std::shared_ptr<Function> function,
bool enable_performance_data = false) override;

View File

@@ -8,10 +8,10 @@
// INTERPRETER_API is used for the public API symbols. It either DLL imports or DLL exports
// (or does nothing for static build)
#ifdef INTERPRETER_BACKEND_EXPORTS // defined if we are building the INTERPRETER DLL (instead of
#ifdef INTERPRETER_BACKEND_EXPORTS // defined if we are building the INTERPRETER DLL (instead of
// using
// it)
#define INTERPRETER_BACKEND_API NGRAPH_HELPER_DLL_EXPORT
# define INTERPRETER_BACKEND_API NGRAPH_HELPER_DLL_EXPORT
#else
#define INTERPRETER_BACKEND_API NGRAPH_HELPER_DLL_IMPORT
#endif // INTERPRETER_DLL_EXPORTS
# define INTERPRETER_BACKEND_API NGRAPH_HELPER_DLL_IMPORT
#endif // INTERPRETER_DLL_EXPORTS

View File

@@ -3,7 +3,9 @@
//
#include "int_executable.hpp"
#include <cstring>
#include "backend_manager.hpp"
#include "evaluates_map.hpp"
#include "ngraph/except.hpp"
@@ -17,28 +19,22 @@ using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START
class TemporaryOverrideOutputs
{
class TemporaryOverrideOutputs {
std::shared_ptr<Node> node;
std::vector<PartialShape> orig_shapes;
public:
TemporaryOverrideOutputs(std::shared_ptr<Node> node,
const std::vector<std::shared_ptr<HostTensor>>& args)
: node(node)
{
for (size_t i = 0; i < args.size(); ++i)
{
TemporaryOverrideOutputs(std::shared_ptr<Node> node, const std::vector<std::shared_ptr<HostTensor>>& args)
: node(node) {
for (size_t i = 0; i < args.size(); ++i) {
auto output = node->get_input_source_output(i);
orig_shapes.push_back(output.get_partial_shape());
output.get_tensor().set_partial_shape(args[i]->get_shape());
}
}
~TemporaryOverrideOutputs()
{
for (size_t i = 0; i < orig_shapes.size(); ++i)
{
~TemporaryOverrideOutputs() {
for (size_t i = 0; i < orig_shapes.size(); ++i) {
auto output = node->get_input_source_output(i);
output.get_tensor().set_partial_shape(orig_shapes[i]);
}
@@ -47,36 +43,30 @@ public:
runtime::interpreter::INTExecutable::INTExecutable(const shared_ptr<Function>& function,
bool enable_performance_collection)
: m_is_compiled{true}
, m_performance_counters_enabled{enable_performance_collection}
{
: m_is_compiled{true},
m_performance_counters_enabled{enable_performance_collection} {
m_function = clone_function(*function);
for (auto node : m_function->get_ordered_ops())
{
for (auto node : m_function->get_ordered_ops()) {
m_nodes.push_back(node);
}
set_parameters_and_results(*m_function);
}
bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::Tensor>>& outputs,
const vector<shared_ptr<runtime::Tensor>>& inputs)
{
const vector<shared_ptr<runtime::Tensor>>& inputs) {
// convert inputs to HostTensor
vector<shared_ptr<HostTensor>> func_inputs;
for (const auto& tensor : inputs)
{
for (const auto& tensor : inputs) {
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_inputs.push_back(host_tensor);
}
if (m_nan_check_enabled)
{
if (m_nan_check_enabled) {
perform_nan_check(func_inputs);
}
// convert outputs to HostTensor
vector<shared_ptr<HostTensor>> func_outputs;
for (const auto& tensor : outputs)
{
for (const auto& tensor : outputs) {
auto host_tensor = static_pointer_cast<runtime::HostTensor>(tensor);
func_outputs.push_back(host_tensor);
}
@@ -84,10 +74,8 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
// map function params -> HostTensor
unordered_map<descriptor::Tensor*, shared_ptr<HostTensor>> tensor_map;
size_t input_count = 0;
for (const auto& param : get_parameters())
{
for (size_t i = 0; i < param->get_output_size(); ++i)
{
for (const auto& param : get_parameters()) {
for (size_t i = 0; i < param->get_output_size(); ++i) {
descriptor::Tensor* tensor = &param->output(i).get_tensor();
tensor_map.insert({tensor, func_inputs[input_count++]});
}
@@ -95,40 +83,34 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
std::unordered_map<std::shared_ptr<ngraph::Node>, size_t> results_map;
// map function outputs -> HostTensor
for (size_t output_count = 0; output_count < get_results().size(); ++output_count)
{
for (size_t output_count = 0; output_count < get_results().size(); ++output_count) {
auto output = get_results()[output_count];
results_map[output] = output_count;
}
// for each ordered op in the graph
for (const auto& op : m_nodes)
{
if (dynamic_pointer_cast<op::Parameter>(op) != nullptr)
{
for (const auto& op : m_nodes) {
if (dynamic_pointer_cast<op::Parameter>(op) != nullptr) {
continue;
}
// get op inputs from map
vector<shared_ptr<HostTensor>> op_inputs;
for (auto input : op->inputs())
{
for (auto input : op->inputs()) {
descriptor::Tensor* tensor = &input.get_tensor();
op_inputs.push_back(tensor_map.at(tensor));
}
TemporaryOverrideOutputs overrider(op, op_inputs);
OutputVector outputs;
for (size_t i = 0; i < op->inputs().size(); ++i)
{
for (size_t i = 0; i < op->inputs().size(); ++i) {
outputs.push_back(op->get_input_source_output(i));
}
auto cloned_node = op->clone_with_new_inputs(outputs);
// get op outputs from map or create
vector<shared_ptr<HostTensor>> op_outputs;
for (size_t i = 0; i < op->get_output_size(); ++i)
{
for (size_t i = 0; i < op->get_output_size(); ++i) {
descriptor::Tensor* tensor = &op->output(i).get_tensor();
shared_ptr<HostTensor> host_tensor;
auto it = tensor_map.find(tensor);
@@ -138,9 +120,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
// Use cloned_node to create HostTensor with static dimensions
host_tensor = make_shared<HostTensor>(cloned_node->output(i));
tensor_map.insert({tensor, host_tensor});
}
else
{
} else {
host_tensor = it->second;
}
op_outputs.push_back(host_tensor);
@@ -148,39 +128,30 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
// get op type
element::Type type;
if (ov::is_type<op::Convert>(op) || ov::is_type<op::PriorBox>(op))
{
if (ov::is_type<op::Convert>(op) || ov::is_type<op::PriorBox>(op)) {
type = op->get_input_element_type(0);
}
else if (ov::is_type<op::v1::Equal>(op) || ov::is_type<op::v1::Greater>(op) ||
ov::is_type<op::v1::GreaterEqual>(op) || ov::is_type<op::v1::Less>(op) ||
ov::is_type<op::v1::LessEqual>(op) || ov::is_type<op::v1::NotEqual>(op))
{
} else if (ov::is_type<op::v1::Equal>(op) || ov::is_type<op::v1::Greater>(op) ||
ov::is_type<op::v1::GreaterEqual>(op) || ov::is_type<op::v1::Less>(op) ||
ov::is_type<op::v1::LessEqual>(op) || ov::is_type<op::v1::NotEqual>(op)) {
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type = op->get_input_element_type(1);
}
else
{
} else {
type = op->get_output_element_type(0);
}
if (m_performance_counters_enabled)
{
if (m_performance_counters_enabled) {
m_timer_map[op].start();
}
// Call evaluate for cloned_node with static shapes
if (!cloned_node->evaluate(op_outputs, op_inputs))
{
if (!cloned_node->evaluate(op_outputs, op_inputs)) {
evaluate_node(cloned_node, op_outputs, op_inputs);
}
if (m_performance_counters_enabled)
{
if (m_performance_counters_enabled) {
m_timer_map[op].stop();
}
if (m_nan_check_enabled)
{
if (m_nan_check_enabled) {
perform_nan_check(op_outputs, op.get());
}
}
@@ -188,58 +159,38 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
return true;
}
vector<runtime::PerformanceCounter>
runtime::interpreter::INTExecutable::get_performance_data() const
{
vector<runtime::PerformanceCounter> runtime::interpreter::INTExecutable::get_performance_data() const {
vector<runtime::PerformanceCounter> rc;
for (const pair<shared_ptr<const Node>, stopwatch> p : m_timer_map)
{
for (const pair<shared_ptr<const Node>, stopwatch> p : m_timer_map) {
rc.emplace_back(p.first, p.second.get_total_microseconds(), p.second.get_call_count());
}
return rc;
}
void runtime::interpreter::INTExecutable::perform_nan_check(
const vector<shared_ptr<HostTensor>>& tensors, const Node* op)
{
void runtime::interpreter::INTExecutable::perform_nan_check(const vector<shared_ptr<HostTensor>>& tensors,
const Node* op) {
size_t arg_number = 1;
for (const shared_ptr<HostTensor>& tensor : tensors)
{
for (const shared_ptr<HostTensor>& tensor : tensors) {
const element::Type& type = tensor->get_element_type();
if (type == element::f32)
{
if (type == element::f32) {
const float* data = tensor->get_data_ptr<float>();
for (size_t i = 0; i < tensor->get_element_count(); i++)
{
if (std::isnan(data[i]))
{
if (op)
{
for (size_t i = 0; i < tensor->get_element_count(); i++) {
if (std::isnan(data[i])) {
if (op) {
throw runtime_error("nan found in op '" + op->get_name() + "' output");
}
else
{
throw runtime_error("nan found in function's input tensor number " +
to_string(arg_number));
} else {
throw runtime_error("nan found in function's input tensor number " + to_string(arg_number));
}
}
}
}
else if (type == element::f64)
{
} else if (type == element::f64) {
const double* data = tensor->get_data_ptr<double>();
for (size_t i = 0; i < tensor->get_element_count(); i++)
{
if (std::isnan(data[i]))
{
if (op)
{
for (size_t i = 0; i < tensor->get_element_count(); i++) {
if (std::isnan(data[i])) {
if (op) {
throw runtime_error("nan found in op '" + op->get_name() + "' output");
}
else
{
throw runtime_error("nan found in function's input tensor number " +
to_string(arg_number));
} else {
throw runtime_error("nan found in function's input tensor number " + to_string(arg_number));
}
}
}
@@ -248,72 +199,56 @@ void runtime::interpreter::INTExecutable::perform_nan_check(
}
}
shared_ptr<ngraph::op::Parameter>
runtime::interpreter::INTExecutable::get_parameter(size_t index) const
{
shared_ptr<ngraph::op::Parameter> runtime::interpreter::INTExecutable::get_parameter(size_t index) const {
const ParameterVector& parameters = get_parameters();
NGRAPH_CHECK(index < parameters.size(), "create_tensor for input out of bounds");
return parameters[index];
}
shared_ptr<ngraph::op::Result> runtime::interpreter::INTExecutable::get_result(size_t index) const
{
shared_ptr<ngraph::op::Result> runtime::interpreter::INTExecutable::get_result(size_t index) const {
const ResultVector& results = get_results();
NGRAPH_CHECK(index < results.size(), "create_tensor for input out of bounds");
return results[index];
}
shared_ptr<runtime::Tensor>
runtime::interpreter::INTExecutable::create_input_tensor(size_t input_index)
{
shared_ptr<runtime::Tensor> runtime::interpreter::INTExecutable::create_input_tensor(size_t input_index) {
shared_ptr<op::Parameter> parameter = get_parameter(input_index);
return make_shared<runtime::HostTensor>(parameter->get_element_type(), parameter->get_shape());
}
shared_ptr<runtime::Tensor>
runtime::interpreter::INTExecutable::create_output_tensor(size_t output_index)
{
shared_ptr<runtime::Tensor> runtime::interpreter::INTExecutable::create_output_tensor(size_t output_index) {
shared_ptr<op::Result> result = get_result(output_index);
return make_shared<runtime::HostTensor>(result->get_element_type(), result->get_shape());
}
vector<shared_ptr<runtime::Tensor>>
runtime::interpreter::INTExecutable::create_input_tensor(size_t input_index,
size_t pipeline_depth)
{
vector<shared_ptr<runtime::Tensor>> runtime::interpreter::INTExecutable::create_input_tensor(size_t input_index,
size_t pipeline_depth) {
vector<shared_ptr<runtime::HostTensor>> tensors;
shared_ptr<op::Parameter> parameter = get_parameter(input_index);
for (size_t i = 0; i < pipeline_depth; i++)
{
for (size_t i = 0; i < pipeline_depth; i++) {
shared_ptr<runtime::HostTensor> tensor;
auto t =
make_shared<runtime::HostTensor>(parameter->get_element_type(), parameter->get_shape());
auto t = make_shared<runtime::HostTensor>(parameter->get_element_type(), parameter->get_shape());
tensor = static_pointer_cast<runtime::HostTensor>(t);
tensors.push_back(tensor);
}
vector<shared_ptr<runtime::Tensor>> result_tensors;
for (const shared_ptr<runtime::HostTensor>& tensor : tensors)
{
for (const shared_ptr<runtime::HostTensor>& tensor : tensors) {
result_tensors.push_back(tensor);
}
return result_tensors;
}
vector<shared_ptr<runtime::Tensor>>
runtime::interpreter::INTExecutable::create_output_tensor(size_t output_index,
size_t pipeline_depth)
{
vector<shared_ptr<runtime::Tensor>> runtime::interpreter::INTExecutable::create_output_tensor(size_t output_index,
size_t pipeline_depth) {
vector<shared_ptr<runtime::HostTensor>> tensors;
shared_ptr<op::Result> result = get_result(output_index);
for (size_t i = 0; i < pipeline_depth; i++)
{
for (size_t i = 0; i < pipeline_depth; i++) {
shared_ptr<runtime::HostTensor> tensor;
auto t = make_shared<runtime::HostTensor>(result->get_element_type(), result->get_shape());
tensor = static_pointer_cast<runtime::HostTensor>(t);
tensors.push_back(tensor);
}
vector<shared_ptr<runtime::Tensor>> result_tensors;
for (const shared_ptr<runtime::HostTensor>& tensor : tensors)
{
for (const shared_ptr<runtime::HostTensor>& tensor : tensors) {
result_tensors.push_back(tensor);
}
return result_tensors;
@@ -321,25 +256,19 @@ vector<shared_ptr<runtime::Tensor>>
bool runtime::interpreter::INTExecutable::evaluate_node(const std::shared_ptr<Node>& node,
const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
const HostTensorVector& inputs) const {
auto& map = runtime::interpreter::get_evaluators_map();
auto it = map.find(node->get_type_info());
bool res = false;
if (it != map.end())
{
if (it != map.end()) {
res = it->second(node, outputs, inputs);
if (!res)
{
throw ngraph_error(std::string("Running evaluate method for OP ") +
node->get_type_info().name + std::string(" failed!"));
if (!res) {
throw ngraph_error(std::string("Running evaluate method for OP ") + node->get_type_info().name +
std::string(" failed!"));
}
}
else
{
throw unsupported_op(
std::string("Interpreter backend doesn't implement evaluate method for OP ") +
node->get_type_info().name);
} else {
throw unsupported_op(std::string("Interpreter backend doesn't implement evaluate method for OP ") +
node->get_type_info().name);
}
return res;
}

View File

@@ -3,8 +3,8 @@
//
#ifndef NGRAPH_OP
#warning "NGRAPH_OP not defined"
#define NGRAPH_OP(x, y)
# warning "NGRAPH_OP not defined"
# define NGRAPH_OP(x, y)
#endif
NGRAPH_OP(Abs, op::v0)

View File

@@ -8,44 +8,41 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace pass
{
/// \brief The DynElimination pass finds dynamic operations in a graph whose
/// shape relevant inputs have already been resolved to static values, and
/// replaces those dynamic operations with the equivalent operations using
/// static inputs and attributes.
/// \details This pass should be executed after the ConstantFolding pass.
///
/// The ConstantFolding and DynElimination passes are used together to transform
/// dynamic operations in a computation graph to static operations when the
/// graph is executed with input data.
///
/// In the example shown below, the original graph is constructed with dynamic
/// broadcast operation. When the graph is executed with input data, the input
/// shapes become available, by applying the ConstantFolding and DynElimination
/// pass, the graph is updated with dynamic broadcast being replaced by a static
/// broadcast operation.
/// <table>
/// <tr>
/// <th>Original</th>
/// <th>After %ConstantFolding</th>
/// <th>After %DynElimination</th>
/// </tr>
/// <tr>
/// <td> \image html dyn_broadcast_pre_constfld.svg </td>
/// <td> \image html dyn_broadcast_post_constfld.svg </td>
/// <td> \image html dyn_broadcast_post_dyneliminate.svg </td>
/// </tr>
/// </table>
class BACKEND_API DynElimination : public GraphRewrite
{
public:
DynElimination();
namespace ngraph {
namespace pass {
/// \brief The DynElimination pass finds dynamic operations in a graph whose
/// shape relevant inputs have already been resolved to static values, and
/// replaces those dynamic operations with the equivalent operations using
/// static inputs and attributes.
/// \details This pass should be executed after the ConstantFolding pass.
///
/// The ConstantFolding and DynElimination passes are used together to transform
/// dynamic operations in a computation graph to static operations when the
/// graph is executed with input data.
///
/// In the example shown below, the original graph is constructed with dynamic
/// broadcast operation. When the graph is executed with input data, the input
/// shapes become available, by applying the ConstantFolding and DynElimination
/// pass, the graph is updated with dynamic broadcast being replaced by a static
/// broadcast operation.
/// <table>
/// <tr>
/// <th>Original</th>
/// <th>After %ConstantFolding</th>
/// <th>After %DynElimination</th>
/// </tr>
/// <tr>
/// <td> \image html dyn_broadcast_pre_constfld.svg </td>
/// <td> \image html dyn_broadcast_post_constfld.svg </td>
/// <td> \image html dyn_broadcast_post_dyneliminate.svg </td>
/// </tr>
/// </table>
class BACKEND_API DynElimination : public GraphRewrite {
public:
DynElimination();
private:
void construct_range();
};
}
}
private:
void construct_range();
};
} // namespace pass
} // namespace ngraph

View File

@@ -3,6 +3,7 @@
//
#include "pass/shape_relevance.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
@@ -32,8 +33,7 @@ using namespace ngraph;
// Neither N0 nor N1 will be flagged as shape-relevant. (N1 does feed into the "shape" input of N3,
// but only via the value-irrelevant input of ShapeOf.)
//
bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
{
bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f) {
// TODO(amprocte): We are probably reinventing the wheel with the graph traversal here; the
// reason is that we need to cut the traversal short in cases where input values are
// irrelevant. See if there is a way to reduce this duplication.
@@ -43,14 +43,10 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
// Step 1: Find root nodes (these are nodes with an output connected to a shape-relevant
// input).
for (auto& n : f->get_ops())
{
for (auto& output : n->outputs())
{
for (auto& input : output.get_target_inputs())
{
if (input.get_is_relevant_to_shapes())
{
for (auto& n : f->get_ops()) {
for (auto& output : n->outputs()) {
for (auto& input : output.get_target_inputs()) {
if (input.get_is_relevant_to_shapes()) {
shape_determinants.insert(n.get());
break;
}
@@ -67,38 +63,31 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
std::list<Node*> to_visit{shape_determinants.begin(), shape_determinants.end()};
std::set<Node*> already_visited;
while (!to_visit.empty())
{
while (!to_visit.empty()) {
auto node = to_visit.front();
to_visit.pop_front();
if (already_visited.count(node) > 0)
{
if (already_visited.count(node) > 0) {
continue;
}
shape_determinants.insert(node);
already_visited.insert(node);
if (op::is_parameter(node))
{
if (op::is_parameter(node)) {
auto node_as_param = static_cast<op::Parameter*>(node);
if (!node_as_param->is_relevant_to_shapes())
{
if (!node_as_param->is_relevant_to_shapes()) {
node_as_param->set_is_relevant_to_shapes(true);
changes_made = true;
}
}
for (size_t i = 0; i < node->get_input_size(); i++)
{
if (!node->input(i).get_is_relevant_to_values())
{
for (size_t i = 0; i < node->get_input_size(); i++) {
if (!node->input(i).get_is_relevant_to_values()) {
continue;
}
auto source_node = node->get_input_node_ptr(i);
if (already_visited.count(source_node) == 0)
{
if (already_visited.count(source_node) == 0) {
to_visit.push_front(source_node);
}
}

View File

@@ -7,18 +7,12 @@
#include "backend_visibility.hpp"
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class BACKEND_API ShapeRelevance : public FunctionPass
{
public:
ShapeRelevance()
: FunctionPass()
{
}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
}
}
namespace ngraph {
namespace pass {
class BACKEND_API ShapeRelevance : public FunctionPass {
public:
ShapeRelevance() : FunctionPass() {}
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
} // namespace pass
} // namespace ngraph

View File

@@ -10,29 +10,29 @@
#include "backend_visibility.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace runtime
{
class BACKEND_API PerformanceCounter
{
public:
PerformanceCounter(const std::shared_ptr<const Node>& n, size_t us, size_t calls)
: m_node(n)
, m_total_microseconds(us)
, m_call_count(calls)
{
}
std::shared_ptr<const Node> get_node() const { return m_node; }
size_t total_microseconds() const { return m_total_microseconds; }
size_t microseconds() const
{
return m_call_count == 0 ? 0 : m_total_microseconds / m_call_count;
}
size_t call_count() const { return m_call_count; }
std::shared_ptr<const Node> m_node;
size_t m_total_microseconds;
size_t m_call_count;
};
namespace ngraph {
namespace runtime {
class BACKEND_API PerformanceCounter {
public:
PerformanceCounter(const std::shared_ptr<const Node>& n, size_t us, size_t calls)
: m_node(n),
m_total_microseconds(us),
m_call_count(calls) {}
std::shared_ptr<const Node> get_node() const {
return m_node;
}
}
size_t total_microseconds() const {
return m_total_microseconds;
}
size_t microseconds() const {
return m_call_count == 0 ? 0 : m_total_microseconds / m_call_count;
}
size_t call_count() const {
return m_call_count;
}
std::shared_ptr<const Node> m_node;
size_t m_total_microseconds;
size_t m_call_count;
};
} // namespace runtime
} // namespace ngraph

View File

@@ -32,3 +32,6 @@ target_link_libraries(ngraph_test_util PUBLIC ngraph ngraph_backend gtest gmock)
# Link the ONNX common helpers only when the ONNX frontend is part of the build.
if (NGRAPH_ONNX_FRONTEND_ENABLE)
target_link_libraries(ngraph_test_util PRIVATE onnx_common)
endif()
# Collect every test-util source/header and register a clang-format check
# target for them (this commit enables code style for the ngraph test utils).
file(GLOB_RECURSE util_src "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/*.hpp")
add_clang_format_target(ngraph_test_util_clang FOR_SOURCES ${util_src})

View File

@@ -14,162 +14,107 @@
#include "random.hpp"
#include "test_tools.hpp"
namespace ngraph
{
namespace test
{
/// \brief Same as numpy.allclose
/// \param a First tensor to compare
/// \param b Second tensor to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, ::testing::AssertionResult>::type
all_close(const std::vector<T>& a,
const std::vector<T>& b,
T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8))
{
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size())
{
throw std::invalid_argument("all_close: Argument vectors' sizes do not match");
}
size_t count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]) || !std::isfinite(a[i]) ||
!std::isfinite(b[i]))
{
if (count < 5)
{
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< a[i] << " is not close to " << b[i] << " at index " << i
<< std::endl;
}
count++;
rc = false;
}
}
ar_fail << "diff count: " << count << " out of " << a.size() << std::endl;
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
/// \brief Same as numpy.allclose
/// \param a First tensor to compare
/// \param b Second tensor to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type
all_close(const std::vector<T>& a,
const std::vector<T>& b,
T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8))
{
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size())
{
throw std::invalid_argument("all_close: Argument vectors' sizes do not match");
}
for (size_t i = 0; i < a.size(); ++i)
{
T abs_diff = (a[i] > b[i]) ? (a[i] - b[i]) : (b[i] - a[i]);
if (abs_diff > atol + rtol * b[i])
{
// use unary + operator to force integral values to be displayed as numbers
ar_fail << +a[i] << " is not close to " << +b[i] << " at index " << i
<< std::endl;
rc = false;
}
}
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
/// \brief Same as numpy.allclose
/// \param a First tensor to compare
/// \param b Second tensor to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
::testing::AssertionResult all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a,
const std::shared_ptr<ngraph::runtime::Tensor>& b,
T rtol = 1e-5f,
T atol = 1e-8f)
{
if (a->get_shape() != b->get_shape())
{
return ::testing::AssertionFailure()
<< "Cannot compare tensors with different shapes";
}
return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
}
/// \brief Same as numpy.allclose
/// \param as First tensors to compare
/// \param bs Second tensors to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
::testing::AssertionResult
all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs,
T rtol,
T atol)
{
if (as.size() != bs.size())
{
return ::testing::AssertionFailure()
<< "Cannot compare tensors with different sizes";
}
for (size_t i = 0; i < as.size(); ++i)
{
auto ar = all_close(as[i], bs[i], rtol, atol);
if (!ar)
{
return ar;
}
}
return ::testing::AssertionSuccess();
}
} // namespace test
} // namespace ngraph
// apply pass, execute and compare with INTERPRETER using random data
template <typename T, typename TIN, typename TOUT = TIN>
bool compare_pass_int(std::shared_ptr<ngraph::Function>& baseline_f,
std::shared_ptr<ngraph::Function>& optimized_f,
std::vector<std::vector<TIN>> args = std::vector<std::vector<TIN>>{})
{
ngraph::pass::Manager pass_manager;
pass_manager.register_pass<ngraph::pass::Validate>();
pass_manager.register_pass<T>();
pass_manager.run_passes(optimized_f);
if (args.size() == 0)
{
for (auto& p : baseline_f->get_parameters())
{
args.emplace_back(shape_size(p->get_shape()), 0);
if (std::is_integral<TIN>())
{
std::generate(args.back().begin(), args.back().end(), rand);
}
else
{
static ngraph::test::Uniform<float> rng{0, 1, 0};
rng.initialize(args.back());
namespace ngraph {
namespace test {
/// \brief Same as numpy.allclose
/// \param a First tensor to compare
/// \param b Second tensor to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, ::testing::AssertionResult>::type all_close(
const std::vector<T>& a,
const std::vector<T>& b,
T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8)) {
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size()) {
throw std::invalid_argument("all_close: Argument vectors' sizes do not match");
}
size_t count = 0;
for (size_t i = 0; i < a.size(); ++i) {
if (std::abs(a[i] - b[i]) > atol + rtol * std::abs(b[i]) || !std::isfinite(a[i]) || !std::isfinite(b[i])) {
if (count < 5) {
ar_fail << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
<< " is not close to " << b[i] << " at index " << i << std::endl;
}
count++;
rc = false;
}
}
auto baseline_results = execute<TIN, TOUT>(baseline_f, args, "INTERPRETER");
auto optimized_results = execute<TIN, TOUT>(optimized_f, args, "INTERPRETER");
return ngraph::test::all_close(baseline_results.at(0), optimized_results.at(0));
ar_fail << "diff count: " << count << " out of " << a.size() << std::endl;
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
/// \brief Same as numpy.allclose
/// \param a First tensor to compare
/// \param b Second tensor to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// \returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, ::testing::AssertionResult>::type all_close(
const std::vector<T>& a,
const std::vector<T>& b,
T rtol = static_cast<T>(1e-5),
T atol = static_cast<T>(1e-8)) {
bool rc = true;
::testing::AssertionResult ar_fail = ::testing::AssertionFailure();
if (a.size() != b.size()) {
throw std::invalid_argument("all_close: Argument vectors' sizes do not match");
}
for (size_t i = 0; i < a.size(); ++i) {
T abs_diff = (a[i] > b[i]) ? (a[i] - b[i]) : (b[i] - a[i]);
if (abs_diff > atol + rtol * b[i]) {
// use unary + operator to force integral values to be displayed as numbers
ar_fail << +a[i] << " is not close to " << +b[i] << " at index " << i << std::endl;
rc = false;
}
}
return rc ? ::testing::AssertionSuccess() : ar_fail;
}
/// \brief Same as numpy.allclose
/// \param a First tensor to compare
/// \param b Second tensor to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
::testing::AssertionResult all_close(const std::shared_ptr<ngraph::runtime::Tensor>& a,
const std::shared_ptr<ngraph::runtime::Tensor>& b,
T rtol = 1e-5f,
T atol = 1e-8f) {
if (a->get_shape() != b->get_shape()) {
return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
}
return all_close(read_vector<T>(a), read_vector<T>(b), rtol, atol);
}
/// \brief Same as numpy.allclose
/// \param as First tensors to compare
/// \param bs Second tensors to compare
/// \param rtol Relative tolerance
/// \param atol Absolute tolerance
/// Returns true if shapes match and for all elements, |a_i-b_i| <= atol + rtol*|b_i|.
template <typename T>
::testing::AssertionResult all_close(const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& as,
const std::vector<std::shared_ptr<ngraph::runtime::Tensor>>& bs,
T rtol,
T atol) {
if (as.size() != bs.size()) {
return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
}
for (size_t i = 0; i < as.size(); ++i) {
auto ar = all_close(as[i], bs[i], rtol, atol);
if (!ar) {
return ar;
}
}
return ::testing::AssertionSuccess();
}
} // namespace test
} // namespace ngraph

View File

@@ -2,12 +2,13 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "util/all_close_f.hpp"
#include <climits>
#include <cmath>
#include "ngraph/env_util.hpp"
#include "ngraph/util.hpp"
#include "util/all_close_f.hpp"
using namespace std;
using namespace ngraph;
@@ -27,20 +28,13 @@ constexpr uint32_t FLOAT_MAX_DIFF = UINT_MAX - 1;
constexpr uint64_t DOUBLE_BELOW_MIN_SIGNAL = ULLONG_MAX;
constexpr uint64_t DOUBLE_MAX_DIFF = ULLONG_MAX - 1;
uint32_t test::float_distance(float a, float b, float min_signal)
{
if (std::isnan(a) && std::isnan(b))
{
uint32_t float_distance(float a, float b, float min_signal) {
if (std::isnan(a) && std::isnan(b)) {
return 0;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
} else if (std::isinf(a) && std::isinf(b)) {
if (a > 0 && b > 0) {
return 0;
}
else if (a < 0 && b < 0)
{
} else if (a < 0 && b < 0) {
return 0;
}
return FLOAT_MAX_DIFF;
@@ -64,17 +58,13 @@ uint32_t test::float_distance(float a, float b, float min_signal)
uint32_t a_uint_abs = (abs_value_bits_mask & a_fu.i);
uint32_t b_uint_abs = (abs_value_bits_mask & b_fu.i);
uint32_t min_signal_uint_abs = (abs_value_bits_mask & min_signal_fu.i);
if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
{
if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs)) {
// Both a & b below minimum signal
distance = FLOAT_BELOW_MIN_SIGNAL;
}
else
{
} else {
distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
// We've reserved UINT_MAX to mean FLOAT_BELOW_MIN_SIGNAL
if (distance == UINT_MAX)
{
if (distance == UINT_MAX) {
distance = FLOAT_MAX_DIFF;
}
}
@@ -82,20 +72,13 @@ uint32_t test::float_distance(float a, float b, float min_signal)
return distance;
}
uint64_t test::float_distance(double a, double b, double min_signal)
{
if (std::isnan(a) && std::isnan(b))
{
uint64_t float_distance(double a, double b, double min_signal) {
if (std::isnan(a) && std::isnan(b)) {
return 0;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
} else if (std::isinf(a) && std::isinf(b)) {
if (a > 0 && b > 0) {
return 0;
}
else if (a < 0 && b < 0)
{
} else if (a < 0 && b < 0) {
return 0;
}
return DOUBLE_MAX_DIFF;
@@ -119,17 +102,13 @@ uint64_t test::float_distance(double a, double b, double min_signal)
uint64_t a_uint_abs = (abs_value_bits_mask & a_du.i);
uint64_t b_uint_abs = (abs_value_bits_mask & b_du.i);
uint64_t min_signal_uint_abs = (abs_value_bits_mask & min_signal_du.i);
if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs))
{
if ((a_uint_abs < min_signal_uint_abs) && (b_uint_abs < min_signal_uint_abs)) {
// Both a & b below minimum signal
distance = DOUBLE_BELOW_MIN_SIGNAL;
}
else
{
} else {
distance = (a_uint >= b_uint) ? (a_uint - b_uint) : (b_uint - a_uint);
// We've reserved ULLONG_MAX to mean DOUBLE_BELOW_MIN_SIGNAL
if (distance == ULLONG_MAX)
{
if (distance == ULLONG_MAX) {
distance = DOUBLE_MAX_DIFF;
}
}
@@ -137,20 +116,13 @@ uint64_t test::float_distance(double a, double b, double min_signal)
return distance;
}
bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
{
if (std::isnan(a) && std::isnan(b))
{
bool test::close_f(float a, float b, int tolerance_bits, float min_signal) {
if (std::isnan(a) && std::isnan(b)) {
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
} else if (std::isinf(a) && std::isinf(b)) {
if (a > 0 && b > 0) {
return true;
}
else if (a < 0 && b < 0)
{
} else if (a < 0 && b < 0) {
return true;
}
return false;
@@ -167,20 +139,13 @@ bool test::close_f(float a, float b, int tolerance_bits, float min_signal)
return (distance <= tolerance) || (distance == FLOAT_BELOW_MIN_SIGNAL);
}
bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
{
if (std::isnan(a) && std::isnan(b))
{
bool test::close_f(double a, double b, int tolerance_bits, double min_signal) {
if (std::isnan(a) && std::isnan(b)) {
return true;
}
else if (std::isinf(a) && std::isinf(b))
{
if (a > 0 && b > 0)
{
} else if (std::isinf(a) && std::isinf(b)) {
if (a > 0 && b > 0) {
return true;
}
else if (a < 0 && b < 0)
{
} else if (a < 0 && b < 0) {
return true;
}
return false;
@@ -197,49 +162,38 @@ bool test::close_f(double a, double b, int tolerance_bits, double min_signal)
return (distance <= tolerance) || (distance == DOUBLE_BELOW_MIN_SIGNAL);
}
vector<uint32_t>
test::float_distances(const vector<float>& a, const vector<float>& b, float min_signal)
{
if (a.size() != b.size())
{
vector<uint32_t> float_distances(const vector<float>& a, const vector<float>& b, float min_signal) {
if (a.size() != b.size()) {
throw ngraph_error("a.size() != b.size() for float_distances comparison.");
}
vector<uint32_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
for (size_t i = 0; i < a.size(); ++i) {
distances[i] = float_distance(a[i], b[i], min_signal);
}
return distances;
}
vector<uint64_t>
test::float_distances(const vector<double>& a, const vector<double>& b, double min_signal)
{
if (a.size() != b.size())
{
vector<uint64_t> float_distances(const vector<double>& a, const vector<double>& b, double min_signal) {
if (a.size() != b.size()) {
throw ngraph_error("a.size() != b.size() for float_distances comparison.");
}
vector<uint64_t> distances(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
for (size_t i = 0; i < a.size(); ++i) {
distances[i] = float_distance(a[i], b[i], min_signal);
}
return distances;
}
uint32_t test::matching_mantissa_bits(uint32_t distance)
{
uint32_t matching_mantissa_bits(uint32_t distance) {
uint32_t tolerance_bit_shift = 0;
uint32_t num_bits_on = 0;
// Do some bit probing to find the most significant bit that's on,
// as well as how many bits are on.
for (uint32_t check_bit = 0; check_bit < 32; ++check_bit)
{
if (distance & (1 << check_bit))
{
for (uint32_t check_bit = 0; check_bit < 32; ++check_bit) {
if (distance & (1 << check_bit)) {
tolerance_bit_shift = check_bit;
++num_bits_on;
}
@@ -247,8 +201,7 @@ uint32_t test::matching_mantissa_bits(uint32_t distance)
// all_close_f is <= test for tolerance (where tolerance is uint32_t with single bit on)
// So if more than one bit is on we need the next higher tolerance
if (num_bits_on > 1)
{
if (num_bits_on > 1) {
++tolerance_bit_shift;
}
@@ -263,22 +216,18 @@ uint32_t test::matching_mantissa_bits(uint32_t distance)
// tolerance_bit_shift = 32 - (1 + 8 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 32 - (1 + 8 + (tolerance_bit_shift - 1 ) )
// clang-format on
uint32_t matching_matissa_bits =
tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
uint32_t matching_matissa_bits = tolerance_bit_shift < 24 ? (32 - (1 + 8 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
uint32_t test::matching_mantissa_bits(uint64_t distance)
{
uint32_t matching_mantissa_bits(uint64_t distance) {
uint32_t tolerance_bit_shift = 0;
uint32_t num_bits_on = 0;
// Do some bit probing to find the most significant bit that's on,
// as well as how many bits are on.
for (uint32_t check_bit = 0; check_bit < 64; ++check_bit)
{
if (distance & (1ull << check_bit))
{
for (uint32_t check_bit = 0; check_bit < 64; ++check_bit) {
if (distance & (1ull << check_bit)) {
tolerance_bit_shift = check_bit;
++num_bits_on;
}
@@ -286,8 +235,7 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
// all_close_f is <= test for tolerance (where tolerance is uint64_t with single bit on)
// So if more than one bit is on we need the next higher tolerance
if (num_bits_on > 1)
{
if (num_bits_on > 1) {
++tolerance_bit_shift;
}
@@ -302,33 +250,27 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
// tolerance_bit_shift = 64 - (1 + 11 + (matching_matissa_bits - 1 ) )
// matching_matissa_bits = 64 - (1 + 11 + (tolerance_bit_shift - 1 ) )
// clang-format on
uint32_t matching_matissa_bits =
tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
uint32_t matching_matissa_bits = tolerance_bit_shift < 53 ? (64 - (1 + 11 + (tolerance_bit_shift - 1))) : 0;
return matching_matissa_bits;
}
::testing::AssertionResult test::all_close_f(const vector<float>& a,
const vector<float>& b,
int tolerance_bits,
float min_signal)
{
if (tolerance_bits < MIN_FLOAT_TOLERANCE_BITS)
{
float min_signal) {
if (tolerance_bits < MIN_FLOAT_TOLERANCE_BITS) {
tolerance_bits = MIN_FLOAT_TOLERANCE_BITS;
}
if (tolerance_bits >= FLOAT_MANTISSA_BITS)
{
if (tolerance_bits >= FLOAT_MANTISSA_BITS) {
tolerance_bits = FLOAT_MANTISSA_BITS - 1;
}
bool rc = true;
stringstream msg;
if (a.size() != b.size())
{
if (a.size() != b.size()) {
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
}
if (a.size() == 0)
{
if (a.size() == 0) {
return ::testing::AssertionSuccess() << "No elements to compare";
}
vector<uint32_t> distances = float_distances(a, b, min_signal);
@@ -344,97 +286,79 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
size_t min_distance_index = 0;
size_t diff_count = 0;
size_t below_min_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (distances[i] == FLOAT_BELOW_MIN_SIGNAL)
{
for (size_t i = 0; i < a.size(); ++i) {
if (distances[i] == FLOAT_BELOW_MIN_SIGNAL) {
// Special value that indicates both values were below min_signal
below_min_count++;
continue;
}
if (distances[i] > max_distance)
{
if (distances[i] > max_distance) {
max_distance = distances[i];
max_distance_index = i;
}
if (distances[i] < min_distance)
{
if (distances[i] < min_distance) {
min_distance = distances[i];
min_distance_index = i;
}
bool is_close_f = distances[i] <= tolerance;
if (!is_close_f)
{
if (diff_count < 5)
{
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i]
<< " is not close to " << b[i] << " at index " << i << std::endl;
if (!is_close_f) {
if (diff_count < 5) {
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1) << a[i] << " is not close to "
<< b[i] << " at index " << i << std::endl;
}
rc = false;
diff_count++;
}
}
if (!rc)
{
if (!rc) {
msg << "diff count: " << diff_count << " out of " << a.size() << std::endl;
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
uint32_t median_distance = distances[middle];
if (distances.size() % 2 == 0)
{
if (distances.size() % 2 == 0) {
// Find middle-1 value
uint64_t median_sum = static_cast<uint64_t>(median_distance) +
*max_element(distances.begin(), distances.begin() + middle);
uint64_t median_sum =
static_cast<uint64_t>(median_distance) + *max_element(distances.begin(), distances.begin() + middle);
median_distance = median_sum / 2;
}
bool all_below_min_signal = below_min_count == distances.size();
if (rc && (getenv_bool("NGRAPH_GTEST_INFO")))
{
if (rc && (getenv_bool("NGRAPH_GTEST_INFO"))) {
// Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of <= " << (FLOAT_MANTISSA_BITS - tolerance_bits)
<< " mantissa bits (" << FLOAT_MANTISSA_BITS << " bits precision - "
<< tolerance_bits << " tolerance). ";
if (all_below_min_signal)
{
<< " mantissa bits (" << FLOAT_MANTISSA_BITS << " bits precision - " << tolerance_bits
<< " tolerance). ";
if (all_below_min_signal) {
std::cout << "All values below min_signal: " << min_signal << std::endl;
}
else
{
std::cout << below_min_count << " value(s) below min_signal: " << min_signal
<< " Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n";
} else {
std::cout << below_min_count << " value(s) below min_signal: " << min_signal << " Loosest match found is "
<< matching_mantissa_bits(max_distance) << " mantissa bits.\n";
}
}
msg << "passing criteria - mismatch allowed @ mantissa bit: "
<< (FLOAT_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
<< " tolerance bits)\n";
if (all_below_min_signal)
{
msg << "passing criteria - mismatch allowed @ mantissa bit: " << (FLOAT_MANTISSA_BITS - tolerance_bits)
<< " or later (" << tolerance_bits << " tolerance bits)\n";
if (all_below_min_signal) {
msg << "All values below min_signal: " << min_signal << std::endl;
}
else
{
} else {
msg << below_min_count << " value(s) below min_signal: " << min_signal << std::endl;
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
<< " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
<< "tightest match - mismatch occurred @ mantissa bit: " << matching_mantissa_bits(min_distance)
<< " or next bit (" << a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< min_distance_index << "])\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
<< " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
msg << "median match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(median_distance) << " or next bit\n";
<< "loosest match - mismatch occurred @ mantissa bit: " << matching_mantissa_bits(max_distance)
<< " or next bit (" << a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n";
msg << "median match - mismatch occurred @ mantissa bit: " << matching_mantissa_bits(median_distance)
<< " or next bit\n";
}
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
::testing::AssertionResult res = rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
res << msg.str();
return res;
}
@@ -442,25 +366,20 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
::testing::AssertionResult test::all_close_f(const vector<double>& a,
const vector<double>& b,
int tolerance_bits,
double min_signal)
{
if (tolerance_bits < 0)
{
double min_signal) {
if (tolerance_bits < 0) {
tolerance_bits = 0;
}
if (tolerance_bits >= DOUBLE_MANTISSA_BITS)
{
if (tolerance_bits >= DOUBLE_MANTISSA_BITS) {
tolerance_bits = DOUBLE_MANTISSA_BITS - 1;
}
bool rc = true;
stringstream msg;
if (a.size() != b.size())
{
if (a.size() != b.size()) {
return ::testing::AssertionFailure() << "a.size() != b.size() for all_close_f comparison.";
}
if (a.size() == 0)
{
if (a.size() == 0) {
return ::testing::AssertionSuccess() << "No elements to compare";
}
vector<uint64_t> distances = float_distances(a, b, min_signal);
@@ -476,30 +395,24 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
size_t min_distance_index = 0;
size_t diff_count = 0;
size_t below_min_count = 0;
for (size_t i = 0; i < a.size(); ++i)
{
if (distances[i] == DOUBLE_BELOW_MIN_SIGNAL)
{
for (size_t i = 0; i < a.size(); ++i) {
if (distances[i] == DOUBLE_BELOW_MIN_SIGNAL) {
// Special value that indicates both values were below min_signal
below_min_count++;
continue;
}
if (distances[i] > max_distance)
{
if (distances[i] > max_distance) {
max_distance = distances[i];
max_distance_index = i;
}
if (distances[i] < min_distance)
{
if (distances[i] < min_distance) {
min_distance = distances[i];
min_distance_index = i;
}
bool is_close_f = distances[i] <= tolerance;
if (!is_close_f)
{
if (diff_count < 5)
{
if (!is_close_f) {
if (diff_count < 5) {
msg << a[i] << " is not close to " << b[i] << " at index " << i << std::endl;
}
@@ -507,67 +420,53 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
diff_count++;
}
}
if (!rc)
{
if (!rc) {
msg << "diff count: " << diff_count << " out of " << a.size() << std::endl;
}
// Find median value via partial sorting
size_t middle = distances.size() / 2;
std::nth_element(distances.begin(), distances.begin() + middle, distances.end());
uint64_t median_distance = distances[middle];
if (distances.size() % 2 == 0)
{
if (distances.size() % 2 == 0) {
uint64_t median_distance2 = *max_element(distances.begin(), distances.begin() + middle);
uint64_t remainder1 = median_distance % 2;
uint64_t remainder2 = median_distance2 % 2;
median_distance =
(median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
median_distance = (median_distance / 2) + (median_distance2 / 2) + ((remainder1 + remainder2) / 2);
}
bool all_below_min_signal = below_min_count == distances.size();
if (rc && (getenv_bool("NGRAPH_GTEST_INFO")))
{
if (rc && (getenv_bool("NGRAPH_GTEST_INFO"))) {
// Short unobtrusive message when passing
std::cout << "[ INFO ] Verifying match of >= "
<< (DOUBLE_MANTISSA_BITS - tolerance_bits) << " mantissa bits ("
<< DOUBLE_MANTISSA_BITS << " bits precision - " << tolerance_bits
std::cout << "[ INFO ] Verifying match of >= " << (DOUBLE_MANTISSA_BITS - tolerance_bits)
<< " mantissa bits (" << DOUBLE_MANTISSA_BITS << " bits precision - " << tolerance_bits
<< " tolerance). ";
if (all_below_min_signal)
{
if (all_below_min_signal) {
std::cout << "All values below min_signal: " << min_signal << std::endl;
}
else
{
std::cout << below_min_count << " value(s) below min_signal: " << min_signal
<< " Loosest match found is " << matching_mantissa_bits(max_distance)
<< " mantissa bits.\n";
} else {
std::cout << below_min_count << " value(s) below min_signal: " << min_signal << " Loosest match found is "
<< matching_mantissa_bits(max_distance) << " mantissa bits.\n";
}
}
msg << "passing criteria - mismatch allowed @ mantissa bit: "
<< (DOUBLE_MANTISSA_BITS - tolerance_bits) << " or later (" << tolerance_bits
<< " tolerance bits)\n";
if (all_below_min_signal)
{
msg << "passing criteria - mismatch allowed @ mantissa bit: " << (DOUBLE_MANTISSA_BITS - tolerance_bits)
<< " or later (" << tolerance_bits << " tolerance bits)\n";
if (all_below_min_signal) {
msg << "All values below min_signal: " << min_signal << std::endl;
}
else
{
} else {
msg << below_min_count << " value(s) below min_signal: " << min_signal << std::endl;
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "tightest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(min_distance) << " or next bit (" << a[min_distance_index]
<< " vs " << b[min_distance_index] << " at [" << min_distance_index << "])\n";
<< "tightest match - mismatch occurred @ mantissa bit: " << matching_mantissa_bits(min_distance)
<< " or next bit (" << a[min_distance_index] << " vs " << b[min_distance_index] << " at ["
<< min_distance_index << "])\n";
msg << std::setprecision(std::numeric_limits<long double>::digits10 + 1)
<< "loosest match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(max_distance) << " or next bit (" << a[max_distance_index]
<< " vs " << b[max_distance_index] << " at [" << max_distance_index << "])\n";
msg << "median match - mismatch occurred @ mantissa bit: "
<< matching_mantissa_bits(median_distance) << " or next bit\n";
<< "loosest match - mismatch occurred @ mantissa bit: " << matching_mantissa_bits(max_distance)
<< " or next bit (" << a[max_distance_index] << " vs " << b[max_distance_index] << " at ["
<< max_distance_index << "])\n";
msg << "median match - mismatch occurred @ mantissa bit: " << matching_mantissa_bits(median_distance)
<< " or next bit\n";
}
::testing::AssertionResult res =
rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
::testing::AssertionResult res = rc ? ::testing::AssertionSuccess() : ::testing::AssertionFailure();
res << msg.str();
return res;
}
@@ -575,33 +474,25 @@ uint32_t test::matching_mantissa_bits(uint64_t distance)
::testing::AssertionResult test::all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int tolerance_bits,
float min_signal)
{
float min_signal) {
// Check that the layouts are compatible
if (a->get_shape() != b->get_shape())
{
if (a->get_shape() != b->get_shape()) {
return ::testing::AssertionFailure() << "Cannot compare tensors with different shapes";
}
return test::all_close_f(
read_float_vector(a), read_float_vector(b), tolerance_bits, min_signal);
return test::all_close_f(read_float_vector(a), read_float_vector(b), tolerance_bits, min_signal);
}
::testing::AssertionResult
test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits,
float min_signal)
{
if (as.size() != bs.size())
{
::testing::AssertionResult test::all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits,
float min_signal) {
if (as.size() != bs.size()) {
return ::testing::AssertionFailure() << "Cannot compare tensors with different sizes";
}
for (size_t i = 0; i < as.size(); ++i)
{
for (size_t i = 0; i < as.size(); ++i) {
auto ar = test::all_close_f(as[i], bs[i], tolerance_bits, min_signal);
if (!ar)
{
if (!ar) {
return ar;
}
}

View File

@@ -16,12 +16,12 @@ static constexpr int DOUBLE_MANTISSA_BITS = 53;
// Maximum available float bits
#ifndef MAX_FLOAT_BITS
#define MAX_FLOAT_BITS FLOAT_MANTISSA_BITS
# define MAX_FLOAT_BITS FLOAT_MANTISSA_BITS
#endif
// Minimum float tolerance bits possible
#ifndef MIN_FLOAT_TOLERANCE_BITS
#define MIN_FLOAT_TOLERANCE_BITS (FLOAT_MANTISSA_BITS - MAX_FLOAT_BITS)
# define MIN_FLOAT_TOLERANCE_BITS (FLOAT_MANTISSA_BITS - MAX_FLOAT_BITS)
#endif
static_assert((MAX_FLOAT_BITS > 0) && (MAX_FLOAT_BITS <= FLOAT_MANTISSA_BITS),
@@ -31,203 +31,112 @@ static_assert((MIN_FLOAT_TOLERANCE_BITS >= 0) && (MIN_FLOAT_TOLERANCE_BITS < FLO
// Default float tolerance bits
#ifndef DEFAULT_FLOAT_TOLERANCE_BITS
#define DEFAULT_FLOAT_TOLERANCE_BITS (MIN_FLOAT_TOLERANCE_BITS + 2)
# define DEFAULT_FLOAT_TOLERANCE_BITS (MIN_FLOAT_TOLERANCE_BITS + 2)
#endif
// Default float tolerance bits
#ifndef DEFAULT_DOUBLE_TOLERANCE_BITS
#define DEFAULT_DOUBLE_TOLERANCE_BITS 2
# define DEFAULT_DOUBLE_TOLERANCE_BITS 2
#endif
static_assert((DEFAULT_FLOAT_TOLERANCE_BITS >= 0) &&
(DEFAULT_FLOAT_TOLERANCE_BITS < FLOAT_MANTISSA_BITS),
static_assert((DEFAULT_FLOAT_TOLERANCE_BITS >= 0) && (DEFAULT_FLOAT_TOLERANCE_BITS < FLOAT_MANTISSA_BITS),
"DEFAULT_FLOAT_TOLERANCE_BITS must be in range [0, 24)");
static_assert((DEFAULT_DOUBLE_TOLERANCE_BITS >= 0) &&
(DEFAULT_DOUBLE_TOLERANCE_BITS < DOUBLE_MANTISSA_BITS),
static_assert((DEFAULT_DOUBLE_TOLERANCE_BITS >= 0) && (DEFAULT_DOUBLE_TOLERANCE_BITS < DOUBLE_MANTISSA_BITS),
"DEFAULT_DOUBLE_TOLERANCE_BITS must be in range [0, 53)");
namespace ngraph
{
namespace test
{
// clang-format off
/// \brief Determine distance between two f32 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \param min_signal Minimum value for comparisons
/// \returns Distance
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------|
/// |----------------------------float----------------------------|
///
/// bfloat (s1, e8, m7) has 7 + 1 = 8 bits of mantissa or bit_precision
/// float (s1, e8, m23) has 23 + 1 = 24 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
// clang-format on
uint32_t float_distance(float a, float b, float min_signal = 0.0f);
namespace ngraph {
namespace test {
// clang-format off
/// \brief Determine distance between two f64 numbers
/// \param a First number to compare
/// \param b Second number to compare
/// \param min_signal Minimum value for comparisons
/// \returns Distance
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/google/googletest/blob/master/googletest/docs/AdvancedGuide.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
// clang-format on
uint64_t float_distance(double a, double b, double min_signal = 0.0);
// clang-format off
/// \brief Check if the two f32 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------|
/// |----------------------------float----------------------------|
///
/// bfloat (s1, e8, m7) has 7 + 1 = 8 bits of mantissa or bit_precision
/// float (s1, e8, m23) has 23 + 1 = 24 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
// clang-format on
bool close_f(float a, float b, int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS, float min_signal = 0.0f);
// clang-format off
/// \brief Check if the two f32 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m
/// |------------bfloat-----------|
/// |----------------------------float----------------------------|
///
/// bfloat (s1, e8, m7) has 7 + 1 = 8 bits of mantissa or bit_precision
/// float (s1, e8, m23) has 23 + 1 = 24 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 8 bit exponent_bits, so it's only valid for
/// bfloat and f32.
// clang-format on
bool close_f(float a,
float b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
// clang-format off
/// \brief Check if the two f64 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
// clang-format on
bool close_f(double a, double b, int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS, double min_signal = 0.0);
// clang-format off
/// \brief Check if the two f64 numbers are close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns True iff the distance between a and b is within 2 ^ tolerance_bits ULP
///
/// References:
/// - https://en.wikipedia.org/wiki/Unit_in_the_last_place
/// - https://randomascii.wordpress.com/2012/01/23/stupid-float-tricks-2
/// - https://github.com/abseil/googletest/blob/master/googletest/docs/advanced.md#floating-point-comparison
///
/// s e e e e e e e e e e e m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m m
/// |----------------------------double-------------------------------------------------------------------------------------------|
///
/// double (s1, e11, m52) has 52 + 1 = 53 bits of mantissa or bit_precision
///
/// This function uses hard-coded value of 11 bit exponent_bits, so it's only valid for f64.
// clang-format on
bool close_f(double a,
double b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Check if the two floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Determine distances between two vectors of f32 numbers
/// \param a Vector of floats to compare
/// \param b Vector of floats to compare
/// \param min_signal Minimum value for comparisons
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint32_t> float_distances(const std::vector<float>& a,
const std::vector<float>& b,
float min_signal = 0.0f);
/// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Determine distances between two vectors of f64 numbers
/// \param a Vector of doubles to compare
/// \param b Vector of doubles to compare
/// \param min_signal Minimum value for comparisons
/// \returns Vector of distances
///
/// See float_distance for limitations and assumptions.
std::vector<uint64_t> float_distances(const std::vector<double>& a,
const std::vector<double>& b,
double min_signal = 0.0);
/// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare
/// \param b Second Tensor to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint32_t matching_mantissa_bits(uint32_t distance);
/// \brief Determine number of matching mantissa bits given a distance
/// \param distance Distance calculated by float_distance
/// \returns Number of matching mantissa bits
///
/// See float_distance for limitations and assumptions.
uint32_t matching_mantissa_bits(uint64_t distance);
/// \brief Check if the two floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<float>& a,
const std::vector<float>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two double floating point vectors are all close
/// \param a First number to compare
/// \param b Second number to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// \returns ::testing::AssertionSuccess iff the two floating point vectors are close
::testing::AssertionResult all_close_f(const std::vector<double>& a,
const std::vector<double>& b,
int tolerance_bits = DEFAULT_DOUBLE_TOLERANCE_BITS,
double min_signal = 0.0);
/// \brief Check if the two TensorViews are all close in float
/// \param a First Tensor to compare
/// \param b Second Tensor to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult all_close_f(const std::shared_ptr<runtime::Tensor>& a,
const std::shared_ptr<runtime::Tensor>& b,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
/// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare
/// \param bs Second vector of Tensor to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult
all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
}
}
/// \brief Check if the two vectors of TensorViews are all close in float
/// \param as First vector of Tensor to compare
/// \param bs Second vector of Tensor to compare
/// \param tolerance_bits Bit tolerance error
/// \param min_signal Minimum value for comparisons
/// Returns true iff the two TensorViews are all close in float
::testing::AssertionResult all_close_f(const std::vector<std::shared_ptr<runtime::Tensor>>& as,
const std::vector<std::shared_ptr<runtime::Tensor>>& bs,
int tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS,
float min_signal = 0.0f);
} // namespace test
} // namespace ngraph

View File

@@ -7,53 +7,39 @@
#include "ngraph/function.hpp"
#include "util/engine/engine_traits.hpp"
namespace ngraph
{
namespace test
{
enum class TestCaseType
{
STATIC,
DYNAMIC
};
namespace ngraph {
namespace test {
enum class TestCaseType { STATIC, DYNAMIC };
namespace
{
/// A factory that can create engines supporting devices but not dynamic backends.
/// Currently: IE_CPU_Backend and IE_GPU_Backend
template <typename Engine>
typename std::enable_if<supports_devices<Engine>::value, Engine>::type
create_engine_impl(const std::shared_ptr<ngraph::Function> function,
const TestCaseType)
{
return Engine{function};
}
namespace {
/// A factory that can create engines supporting devices but not dynamic backends.
/// Currently: IE_CPU_Backend and IE_GPU_Backend
template <typename Engine>
typename std::enable_if<supports_devices<Engine>::value, Engine>::type create_engine_impl(
const std::shared_ptr<ngraph::Function> function,
const TestCaseType) {
return Engine{function};
}
/// A factory that can create engines which support dynamic backends
/// but do not support devices. Currently: INTERPRETER_Engine
template <typename Engine>
typename std::enable_if<supports_dynamic<Engine>::value, Engine>::type
create_engine_impl(const std::shared_ptr<ngraph::Function> function,
const TestCaseType tct)
{
if (tct == TestCaseType::DYNAMIC)
{
return Engine::dynamic(function);
}
else
{
return Engine{function};
}
}
}
/// A factory that is able to create all types of test Engines
/// in both static and dynamic mode
template <typename Engine>
Engine create_engine(const std::shared_ptr<ngraph::Function> function,
const TestCaseType tct)
{
return create_engine_impl<Engine>(function, tct);
};
/// A factory that can create engines which support dynamic backends
/// but do not support devices. Currently: INTERPRETER_Engine
template <typename Engine>
typename std::enable_if<supports_dynamic<Engine>::value, Engine>::type create_engine_impl(
const std::shared_ptr<ngraph::Function> function,
const TestCaseType tct) {
if (tct == TestCaseType::DYNAMIC) {
return Engine::dynamic(function);
} else {
return Engine{function};
}
}
} // namespace
/// A factory that is able to create all types of test Engines
/// in both static and dynamic mode
template <typename Engine>
Engine create_engine(const std::shared_ptr<ngraph::Function> function, const TestCaseType tct) {
return create_engine_impl<Engine>(function, tct);
};
} // namespace test
} // namespace ngraph

View File

@@ -4,27 +4,25 @@
#pragma once
namespace ngraph
{
namespace test
{
/// These templates should be specialized for each test engine and they should contain
/// a "static constexpr const bool value" member set to true or false.
/// These traits are used in engine_factory.hpp
namespace ngraph {
namespace test {
/// These templates should be specialized for each test engine and they should contain
/// a "static constexpr const bool value" member set to true or false.
/// These traits are used in engine_factory.hpp
/// Indicates that a given Engine can be constructed for different devices (IE engines)
template <typename Engine>
struct supports_devices;
/// Indicates that a given Engine can be constructed for different devices (IE engines)
template <typename Engine>
struct supports_devices;
/// Indicates that a given Engine supports dynamic shapes
template <typename Engine>
struct supports_dynamic;
/// Indicates that a given Engine supports dynamic shapes
template <typename Engine>
struct supports_dynamic;
/// Example:
///
// template <>
// struct supports_dynamic<EngineName> {
// static constexpr const bool value = true;
// };
}
}
/// Example:
///
// template <>
// struct supports_dynamic<EngineName> {
// static constexpr const bool value = true;
// };
} // namespace test
} // namespace ngraph

View File

@@ -12,257 +12,256 @@ using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START
namespace
{
/// Extracts the data from two blobs and returns them as a pair of vectors.
template <typename T>
std::pair<std::vector<T>, std::vector<T>>
extract_test_results(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected)
{
const auto computed_data = computed->rmap();
const auto expected_data = expected->rmap();
namespace {
/// Extracts the data from two blobs and returns them as a pair of vectors.
template <typename T>
std::pair<std::vector<T>, std::vector<T>> extract_test_results(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected) {
const auto computed_data = computed->rmap();
const auto expected_data = expected->rmap();
const auto* computed_data_buffer = computed_data.template as<const T*>();
std::vector<T> computed_values(computed_data_buffer,
computed_data_buffer + computed->size());
const auto* computed_data_buffer = computed_data.template as<const T*>();
std::vector<T> computed_values(computed_data_buffer, computed_data_buffer + computed->size());
switch (static_cast<InferenceEngine::Precision::ePrecision>(expected->getTensorDesc().getPrecision()))
{
case InferenceEngine::Precision::FP32: {
const auto* expected_data_buffer = expected_data.template as<const float *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::FP64: {
const auto *expected_data_buffer = expected_data.template as<const double *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I8: {
const auto *expected_data_buffer = expected_data.template as<const int8_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I16: {
const auto *expected_data_buffer = expected_data.template as<const int16_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I32: {
const auto *expected_data_buffer = expected_data.template as<const int32_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I64: {
const auto *expected_data_buffer = expected_data.template as<const int64_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U8: {
const auto *expected_data_buffer = expected_data.template as<const uint8_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U16: {
const auto *expected_data_buffer = expected_data.template as<const uint16_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U32: {
const auto *expected_data_buffer = expected_data.template as<const uint32_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U64: {
const auto *expected_data_buffer = expected_data.template as<const uint64_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::BOOL: {
const auto *expected_data_buffer = expected_data.template as<const uint8_t *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::BF16: {
const auto *expected_data_buffer = expected_data.template as<const bfloat16 *>();
std::vector<T> expected_values(expected_data_buffer,
expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
default: THROW_IE_EXCEPTION << "Not implemented yet";
}
switch (static_cast<InferenceEngine::Precision::ePrecision>(expected->getTensorDesc().getPrecision())) {
case InferenceEngine::Precision::FP32: {
const auto* expected_data_buffer = expected_data.template as<const float*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::FP64: {
const auto* expected_data_buffer = expected_data.template as<const double*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I8: {
const auto* expected_data_buffer = expected_data.template as<const int8_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I16: {
const auto* expected_data_buffer = expected_data.template as<const int16_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I32: {
const auto* expected_data_buffer = expected_data.template as<const int32_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::I64: {
const auto* expected_data_buffer = expected_data.template as<const int64_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U8: {
const auto* expected_data_buffer = expected_data.template as<const uint8_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U16: {
const auto* expected_data_buffer = expected_data.template as<const uint16_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U32: {
const auto* expected_data_buffer = expected_data.template as<const uint32_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::U64: {
const auto* expected_data_buffer = expected_data.template as<const uint64_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::BOOL: {
const auto* expected_data_buffer = expected_data.template as<const uint8_t*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
case InferenceEngine::Precision::BF16: {
const auto* expected_data_buffer = expected_data.template as<const bfloat16*>();
std::vector<T> expected_values(expected_data_buffer, expected_data_buffer + computed->size());
return std::make_pair(std::move(computed_values), std::move(expected_values));
break;
}
default:
THROW_IE_EXCEPTION << "Not implemented yet";
}
}
/// Compares two blobs containing floating point elements.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, testing::AssertionResult>::type compare_blobs(
InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t tolerance_bits) {
const auto test_results = extract_test_results<T>(computed, expected);
return ngraph::test::all_close_f(test_results.first, test_results.second, tolerance_bits);
}
/// Compares two blobs containing integer elements.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, testing::AssertionResult>::type
compare_blobs(InferenceEngine::MemoryBlob::CPtr computed, InferenceEngine::MemoryBlob::CPtr expected, const size_t) {
const auto test_results = extract_test_results<T>(computed, expected);
return ngraph::test::all_close<T>(test_results.first, test_results.second);
}
template <typename T>
typename std::enable_if<std::is_class<T>::value, testing::AssertionResult>::type compare_blobs(
InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t tolerance_bits) {
const auto test_results = extract_test_results<T>(computed, expected);
NGRAPH_CHECK(test_results.first.size() == test_results.second.size(),
"Number of expected and computed results don't match");
std::vector<double> expected_double(test_results.first.size());
std::vector<double> result_double(test_results.second.size());
for (size_t i = 0; i < test_results.first.size(); ++i) {
expected_double[i] = static_cast<double>(test_results.first[i]);
result_double[i] = static_cast<double>(test_results.second[i]);
}
/// Compares two blobs containing floating point elements.
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, testing::AssertionResult>::type
compare_blobs(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t tolerance_bits)
{
const auto test_results = extract_test_results<T>(computed, expected);
return ngraph::test::all_close_f(expected_double, result_double, tolerance_bits);
}
return ngraph::test::all_close_f(test_results.first, test_results.second, tolerance_bits);
/// Compares two blobs elementwise
inline testing::AssertionResult compare_blobs(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t tolerance_bits) {
const auto& computed_precision = computed->getTensorDesc().getPrecision();
switch (static_cast<InferenceEngine::Precision::ePrecision>(computed_precision)) {
case InferenceEngine::Precision::FP32:
return compare_blobs<float>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::FP64:
return compare_blobs<double>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I8:
return compare_blobs<int8_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I16:
return compare_blobs<int16_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I32:
return compare_blobs<int32_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I64:
return compare_blobs<int64_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U8:
return compare_blobs<uint8_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U16:
return compare_blobs<uint16_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U32:
return compare_blobs<uint32_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U64:
return compare_blobs<uint64_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::BOOL:
return compare_blobs<uint8_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::BF16:
return compare_blobs<bfloat16>(computed, expected, tolerance_bits);
break;
default:
THROW_IE_EXCEPTION << "Not implemented yet";
}
}
}; // namespace
/// Compares two blobs containing integer elements.
template <typename T>
typename std::enable_if<std::is_integral<T>::value, testing::AssertionResult>::type
compare_blobs(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t)
{
const auto test_results = extract_test_results<T>(computed, expected);
return ngraph::test::all_close<T>(test_results.first, test_results.second);
}
template <typename T>
typename std::enable_if<std::is_class<T>::value, testing::AssertionResult>::type
compare_blobs(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t tolerance_bits)
{
const auto test_results = extract_test_results<T>(computed, expected);
NGRAPH_CHECK(test_results.first.size() == test_results.second.size(),
"Number of expected and computed results don't match");
std::vector<double> expected_double(test_results.first.size());
std::vector<double> result_double(test_results.second.size());
for (size_t i = 0; i < test_results.first.size(); ++i)
{
expected_double[i] = static_cast<double>(test_results.first[i]);
result_double[i] = static_cast<double>(test_results.second[i]);
}
return ngraph::test::all_close_f(expected_double, result_double, tolerance_bits);
}
/// Compares two blobs elementwise
inline testing::AssertionResult compare_blobs(InferenceEngine::MemoryBlob::CPtr computed,
InferenceEngine::MemoryBlob::CPtr expected,
const size_t tolerance_bits)
{
const auto& computed_precision = computed->getTensorDesc().getPrecision();
switch (static_cast<InferenceEngine::Precision::ePrecision>(computed_precision))
{
case InferenceEngine::Precision::FP32:
return compare_blobs<float>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::FP64:
return compare_blobs<double>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I8:
return compare_blobs<int8_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I16:
return compare_blobs<int16_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I32:
return compare_blobs<int32_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::I64:
return compare_blobs<int64_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U8:
return compare_blobs<uint8_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U16:
return compare_blobs<uint16_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U32:
return compare_blobs<uint32_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::U64:
return compare_blobs<uint64_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::BOOL:
return compare_blobs<uint8_t>(computed, expected, tolerance_bits);
break;
case InferenceEngine::Precision::BF16:
return compare_blobs<bfloat16>(computed, expected, tolerance_bits);
break;
default: THROW_IE_EXCEPTION << "Not implemented yet";
}
}
}; // namespace
namespace
{
InferenceEngine::Precision ng_type_to_precission(const element::Type& target_type)
{
namespace {
InferenceEngine::Precision ng_type_to_precission(const element::Type& target_type) {
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (target_type)
{
case element::Type_t::boolean: return InferenceEngine::Precision::BOOL; break;
case element::Type_t::bf16: return InferenceEngine::Precision::BF16; break;
case element::Type_t::f16: return InferenceEngine::Precision::FP16; break;
case element::Type_t::f32: return InferenceEngine::Precision::FP32; break;
case element::Type_t::f64: return InferenceEngine::Precision::FP64; break;
case element::Type_t::i8: return InferenceEngine::Precision::I8; break;
case element::Type_t::i16: return InferenceEngine::Precision::I16; break;
case element::Type_t::i32: return InferenceEngine::Precision::I32; break;
case element::Type_t::i64: return InferenceEngine::Precision::I64; break;
case element::Type_t::u8: return InferenceEngine::Precision::U8; break;
case element::Type_t::u16: return InferenceEngine::Precision::U16; break;
case element::Type_t::u32: return InferenceEngine::Precision::U32; break;
case element::Type_t::u64: return InferenceEngine::Precision::U64; break;
case element::Type_t::u1: return InferenceEngine::Precision::BIN; break;
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic: throw std::runtime_error("unsupported type");
}
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
# pragma GCC diagnostic push
# pragma GCC diagnostic error "-Wswitch"
# pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (target_type) {
case element::Type_t::boolean:
return InferenceEngine::Precision::BOOL;
break;
case element::Type_t::bf16:
return InferenceEngine::Precision::BF16;
break;
case element::Type_t::f16:
return InferenceEngine::Precision::FP16;
break;
case element::Type_t::f32:
return InferenceEngine::Precision::FP32;
break;
case element::Type_t::f64:
return InferenceEngine::Precision::FP64;
break;
case element::Type_t::i8:
return InferenceEngine::Precision::I8;
break;
case element::Type_t::i16:
return InferenceEngine::Precision::I16;
break;
case element::Type_t::i32:
return InferenceEngine::Precision::I32;
break;
case element::Type_t::i64:
return InferenceEngine::Precision::I64;
break;
case element::Type_t::u8:
return InferenceEngine::Precision::U8;
break;
case element::Type_t::u16:
return InferenceEngine::Precision::U16;
break;
case element::Type_t::u32:
return InferenceEngine::Precision::U32;
break;
case element::Type_t::u64:
return InferenceEngine::Precision::U64;
break;
case element::Type_t::u1:
return InferenceEngine::Precision::BIN;
break;
case element::Type_t::i4:
case element::Type_t::u4:
case element::Type_t::undefined:
case element::Type_t::dynamic:
throw std::runtime_error("unsupported type");
}
} // namespace
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
# pragma GCC diagnostic pop
#endif
throw std::runtime_error("unsupported type");
}
} // namespace
test::IE_Engine::IE_Engine(const std::shared_ptr<Function> function, const char* device)
: m_function{function}
{
test::IE_Engine::IE_Engine(const std::shared_ptr<Function> function, const char* device) : m_function{function} {
const auto cnn_network = InferenceEngine::CNNNetwork(m_function);
m_network_inputs = cnn_network.getInputsInfo();
m_network_outputs = cnn_network.getOutputsInfo();
for (const auto& result : m_function->get_results())
{
for (const auto& result : m_function->get_results()) {
const auto& out_name = get_output_name(result);
m_network_outputs[out_name]->setPrecision(
ng_type_to_precission(result->get_element_type()));
m_network_outputs[out_name]->setPrecision(ng_type_to_precission(result->get_element_type()));
}
InferenceEngine::Core ie;
@@ -270,35 +269,27 @@ test::IE_Engine::IE_Engine(const std::shared_ptr<Function> function, const char*
m_inference_req = exe_network.CreateInferRequest();
}
void test::IE_Engine::infer()
{
if (m_network_inputs.size() != m_allocated_inputs)
{
IE_THROW() << "The tested graph has " << m_network_inputs.size() << " inputs, but "
<< m_allocated_inputs << " were passed.";
}
else
{
void test::IE_Engine::infer() {
if (m_network_inputs.size() != m_allocated_inputs) {
IE_THROW() << "The tested graph has " << m_network_inputs.size() << " inputs, but " << m_allocated_inputs
<< " were passed.";
} else {
m_inference_req.Infer();
}
}
testing::AssertionResult test::IE_Engine::compare_results(const size_t tolerance_bits)
{
testing::AssertionResult test::IE_Engine::compare_results(const size_t tolerance_bits) {
auto comparison_result = testing::AssertionSuccess();
for (const auto& output : m_network_outputs)
{
for (const auto& output : m_network_outputs) {
InferenceEngine::MemoryBlob::CPtr computed_output_blob =
InferenceEngine::as<InferenceEngine::MemoryBlob>(m_inference_req.GetBlob(output.first));
const auto& expected_output_blob = m_expected_outputs[output.first];
comparison_result =
compare_blobs(computed_output_blob, expected_output_blob, tolerance_bits);
comparison_result = compare_blobs(computed_output_blob, expected_output_blob, tolerance_bits);
if (comparison_result == testing::AssertionFailure())
{
if (comparison_result == testing::AssertionFailure()) {
break;
}
}
@@ -306,19 +297,14 @@ testing::AssertionResult test::IE_Engine::compare_results(const size_t tolerance
return comparison_result;
}
std::string test::IE_Engine::get_output_name(const std::shared_ptr<op::v0::Result>& ng_result)
{
if (m_function->get_results().size() == 1)
{
std::string test::IE_Engine::get_output_name(const std::shared_ptr<op::v0::Result>& ng_result) {
if (m_function->get_results().size() == 1) {
// ng_result argument is ignored
return m_network_outputs.begin()->first;
}
else
{
} else {
const auto& prev_layer = ng_result->input_value(0);
auto network_out_name = prev_layer.get_node_shared_ptr()->get_friendly_name();
if (prev_layer.get_node_shared_ptr()->get_output_size() != 1)
{
if (prev_layer.get_node_shared_ptr()->get_output_size() != 1) {
network_out_name += "." + std::to_string(prev_layer.get_index());
}
@@ -332,15 +318,11 @@ std::string test::IE_Engine::get_output_name(const std::shared_ptr<op::v0::Resul
}
}
testing::AssertionResult
test::IE_Engine::compare_results_with_tolerance_as_fp(const float tolerance)
{
testing::AssertionResult test::IE_Engine::compare_results_with_tolerance_as_fp(const float tolerance) {
auto comparison_result = testing::AssertionSuccess();
for (const auto& output : m_network_outputs)
{
if (comparison_result == testing::AssertionFailure())
{
for (const auto& output : m_network_outputs) {
if (comparison_result == testing::AssertionFailure()) {
break;
}
@@ -349,28 +331,22 @@ testing::AssertionResult
const auto& expected_output_blob = m_expected_outputs[output.first];
switch (expected_output_blob->getTensorDesc().getPrecision())
{
case InferenceEngine::Precision::FP32:
{
const auto test_results =
extract_test_results<float>(computed_output_blob, expected_output_blob);
comparison_result =
test::compare_with_tolerance(test_results.first, test_results.second, tolerance);
switch (expected_output_blob->getTensorDesc().getPrecision()) {
case InferenceEngine::Precision::FP32: {
const auto test_results = extract_test_results<float>(computed_output_blob, expected_output_blob);
comparison_result = test::compare_with_tolerance(test_results.first, test_results.second, tolerance);
break;
}
default:
comparison_result = testing::AssertionFailure()
<< "Unsupported data type encountered in "
"'compare_results_with_tolerance_as_fp' method";
comparison_result = testing::AssertionFailure() << "Unsupported data type encountered in "
"'compare_results_with_tolerance_as_fp' method";
}
}
return comparison_result;
}
std::set<NodeTypeInfo> test::IE_Engine::get_ie_ops() const
{
std::set<NodeTypeInfo> test::IE_Engine::get_ie_ops() const {
std::set<NodeTypeInfo> ie_ops = get_opset1().get_type_info_set();
const auto& opset2 = get_opset2().get_type_info_set();
ie_ops.insert(opset2.begin(), opset2.end());
@@ -389,23 +365,20 @@ std::set<NodeTypeInfo> test::IE_Engine::get_ie_ops() const
return ie_ops;
}
void test::IE_Engine::reset()
{
void test::IE_Engine::reset() {
m_allocated_inputs = 0;
m_allocated_expected_outputs = 0;
m_expected_outputs.clear();
}
namespace InferenceEngine
{
// Without this section the linker is not able to find destructors for missing TBlob
// specializations which are instantiated in the unit tests that use TestCase and this engine
template <typename T, typename U>
TBlob<T, U>::~TBlob()
{
free();
}
namespace InferenceEngine {
// Without this section the linker is not able to find destructors for missing TBlob
// specializations which are instantiated in the unit tests that use TestCase and this engine
template <typename T, typename U>
TBlob<T, U>::~TBlob() {
free();
}
template class TBlob<ngraph::bfloat16>;
template class TBlob<ngraph::float16>;
} // namespace InferenceEngine
template class TBlob<ngraph::bfloat16>;
template class TBlob<ngraph::float16>;
} // namespace InferenceEngine

View File

@@ -5,155 +5,134 @@
#pragma once
#include <ie_core.hpp>
#include "ngraph/function.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
#include "util/engine/engine_traits.hpp"
#include "util/engine/test_case_engine.hpp"
namespace ngraph
{
namespace test
{
/// A generic engine that uses OV objects natively
class IE_Engine : public TestCaseEngine
{
public:
IE_Engine() = delete;
namespace ngraph {
namespace test {
/// A generic engine that uses OV objects natively
class IE_Engine : public TestCaseEngine {
public:
IE_Engine() = delete;
/// Constructs an IE test engine for a given device (plugin)
IE_Engine(const std::shared_ptr<Function> function, const char* device);
/// Constructs an IE test engine for a given device (plugin)
IE_Engine(const std::shared_ptr<Function> function, const char* device);
void infer() override;
void infer() override;
testing::AssertionResult compare_results(
const size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS) override;
testing::AssertionResult compare_results(const size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS) override;
testing::AssertionResult
compare_results_with_tolerance_as_fp(const float tolerance = 1.0e-5f) override;
testing::AssertionResult compare_results_with_tolerance_as_fp(const float tolerance = 1.0e-5f) override;
void reset() override;
void reset() override;
template <typename T>
void add_input(const Shape& shape, const std::vector<T>& values)
{
// Retrieve the next function parameter which has not been set yet.
// The params are stored in a vector in the order of their creation.
const auto& function_params = m_function->get_parameters();
const auto& input_to_allocate = function_params[m_allocated_inputs];
template <typename T>
void add_input(const Shape& shape, const std::vector<T>& values) {
// Retrieve the next function parameter which has not been set yet.
// The params are stored in a vector in the order of their creation.
const auto& function_params = m_function->get_parameters();
const auto& input_to_allocate = function_params[m_allocated_inputs];
NGRAPH_CHECK(
m_network_inputs.count(input_to_allocate->get_friendly_name()) == 1,
"nGraph function's input number ",
m_allocated_inputs,
" was not found in the CNNNetwork built from it. Function's input name: ",
input_to_allocate->get_friendly_name());
NGRAPH_CHECK(m_network_inputs.count(input_to_allocate->get_friendly_name()) == 1,
"nGraph function's input number ",
m_allocated_inputs,
" was not found in the CNNNetwork built from it. Function's input name: ",
input_to_allocate->get_friendly_name());
// Retrieve the corresponding CNNNetwork input using param's friendly name.
// Here the inputs are stored in the map and are accessible by a string key.
const auto& input_info = m_network_inputs[input_to_allocate->get_friendly_name()];
// Retrieve the corresponding CNNNetwork input using param's friendly name.
// Here the inputs are stored in the map and are accessible by a string key.
const auto& input_info = m_network_inputs[input_to_allocate->get_friendly_name()];
auto blob =
std::make_shared<InferenceEngine::TBlob<T>>(input_info->getTensorDesc());
blob->allocate();
auto* blob_buffer = blob->wmap().template as<T*>();
auto blob = std::make_shared<InferenceEngine::TBlob<T>>(input_info->getTensorDesc());
blob->allocate();
auto* blob_buffer = blob->wmap().template as<T*>();
NGRAPH_CHECK(blob->size() == values.size(),
"The allocated blob for input '",
input_to_allocate->get_friendly_name(),
" ' expects ",
blob->size(),
" elements while ",
values.size(),
" were provided.");
NGRAPH_CHECK(blob->size() == values.size(),
"The allocated blob for input '",
input_to_allocate->get_friendly_name(),
" ' expects ",
blob->size(),
" elements while ",
values.size(),
" were provided.");
std::copy(values.begin(), values.end(), blob_buffer);
std::copy(values.begin(), values.end(), blob_buffer);
m_inference_req.SetBlob(input_to_allocate->get_friendly_name(), blob);
m_inference_req.SetBlob(input_to_allocate->get_friendly_name(), blob);
++m_allocated_inputs;
}
++m_allocated_inputs;
}
template <typename T>
void add_expected_output(const ngraph::Shape& expected_shape,
const std::vector<T>& values)
{
const auto& function_output =
m_function->get_results()[m_allocated_expected_outputs];
std::string network_out_name = get_output_name(function_output);
InferenceEngine::DataPtr network_output = m_network_outputs[network_out_name];
template <typename T>
void add_expected_output(const ngraph::Shape& expected_shape, const std::vector<T>& values) {
const auto& function_output = m_function->get_results()[m_allocated_expected_outputs];
std::string network_out_name = get_output_name(function_output);
InferenceEngine::DataPtr network_output = m_network_outputs[network_out_name];
auto blob =
std::make_shared<InferenceEngine::TBlob<T>>(network_output->getTensorDesc());
blob->allocate();
auto blob = std::make_shared<InferenceEngine::TBlob<T>>(network_output->getTensorDesc());
blob->allocate();
NGRAPH_CHECK(blob->size() == values.size(),
"The allocated blob for output '",
network_out_name,
" ' expects ",
blob->size(),
" elements while ",
values.size(),
" were provided.");
NGRAPH_CHECK(blob->size() == values.size(),
"The allocated blob for output '",
network_out_name,
" ' expects ",
blob->size(),
" elements while ",
values.size(),
" were provided.");
auto* blob_buffer = blob->wmap().template as<T*>();
std::copy(values.begin(), values.end(), blob_buffer);
auto* blob_buffer = blob->wmap().template as<T*>();
std::copy(values.begin(), values.end(), blob_buffer);
m_expected_outputs.emplace(network_out_name, blob);
m_expected_outputs.emplace(network_out_name, blob);
++m_allocated_expected_outputs;
}
++m_allocated_expected_outputs;
}
private:
const std::shared_ptr<Function> m_function;
InferenceEngine::InputsDataMap m_network_inputs;
InferenceEngine::OutputsDataMap m_network_outputs;
InferenceEngine::InferRequest m_inference_req;
std::map<std::string, InferenceEngine::MemoryBlob::Ptr> m_expected_outputs;
unsigned int m_allocated_inputs = 0;
unsigned int m_allocated_expected_outputs = 0;
private:
const std::shared_ptr<Function> m_function;
InferenceEngine::InputsDataMap m_network_inputs;
InferenceEngine::OutputsDataMap m_network_outputs;
InferenceEngine::InferRequest m_inference_req;
std::map<std::string, InferenceEngine::MemoryBlob::Ptr> m_expected_outputs;
unsigned int m_allocated_inputs = 0;
unsigned int m_allocated_expected_outputs = 0;
/// Retrieves a set of all ops IE can execute
std::set<NodeTypeInfo> get_ie_ops() const;
/// Retrieves a set of all ops IE can execute
std::set<NodeTypeInfo> get_ie_ops() const;
// Get IE blob which corresponds to result of nG Function
std::string get_output_name(const std::shared_ptr<op::v0::Result>& ng_result);
};
// Get IE blob which corresponds to result of nG Function
std::string get_output_name(const std::shared_ptr<op::v0::Result>& ng_result);
};
class IE_CPU_Engine final : public IE_Engine
{
public:
IE_CPU_Engine(const std::shared_ptr<Function> function)
: IE_Engine{function, m_device}
{
}
class IE_CPU_Engine final : public IE_Engine {
public:
IE_CPU_Engine(const std::shared_ptr<Function> function) : IE_Engine{function, m_device} {}
private:
static constexpr const char* m_device = "CPU";
};
private:
static constexpr const char* m_device = "CPU";
};
class IE_GPU_Engine final : public IE_Engine
{
public:
IE_GPU_Engine(const std::shared_ptr<Function> function)
: IE_Engine{function, m_device}
{
}
class IE_GPU_Engine final : public IE_Engine {
public:
IE_GPU_Engine(const std::shared_ptr<Function> function) : IE_Engine{function, m_device} {}
private:
static constexpr const char* m_device = "GPU";
};
private:
static constexpr const char* m_device = "GPU";
};
template <>
struct supports_devices<IE_CPU_Engine>
{
static constexpr bool value = true;
};
template <>
struct supports_devices<IE_CPU_Engine> {
static constexpr bool value = true;
};
template <>
struct supports_devices<IE_GPU_Engine>
{
static constexpr bool value = true;
};
} // namespace test
} // namespace ngraph
template <>
struct supports_devices<IE_GPU_Engine> {
static constexpr bool value = true;
};
} // namespace test
} // namespace ngraph

View File

@@ -2,113 +2,99 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "interpreter_engine.hpp"
#include <cmath>
#include <iomanip>
#include <sstream>
#include "interpreter_engine.hpp"
#include "shared_utils.hpp"
using namespace ngraph;
namespace
{
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const size_t tolerance_bits)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
namespace {
template <typename T>
typename std::enable_if<std::is_floating_point<T>::value, testing::AssertionResult>::type compare_values(
const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const size_t tolerance_bits) {
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close_f(expected, result, tolerance_bits);
return ngraph::test::all_close_f(expected, result, tolerance_bits);
}
testing::AssertionResult compare_with_fp_tolerance(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const float tolerance) {
auto comparison_result = testing::AssertionSuccess();
const auto expected = expected_results->get_vector<float>();
const auto result = read_vector<float>(results);
return ngraph::test::compare_with_tolerance(expected, result, tolerance);
}
template <typename T>
typename std::enable_if<std::is_integral<T>::value, testing::AssertionResult>::type compare_values(
const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const size_t) {
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close(expected, result);
}
// used for float16 and bfloat 16 comparisons
template <typename T>
typename std::enable_if<std::is_class<T>::value, testing::AssertionResult>::type compare_values(
const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const size_t tolerance_bits) {
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
// TODO: add testing infrastructure for float16 and bfloat16 to avoid cast to double
std::vector<double> expected_double(expected.size());
std::vector<double> result_double(result.size());
NGRAPH_CHECK(expected.size() == result.size(), "Number of expected and computed results don't match");
for (size_t i = 0; i < expected.size(); ++i) {
expected_double[i] = static_cast<double>(expected[i]);
result_double[i] = static_cast<double>(result[i]);
}
testing::AssertionResult
compare_with_fp_tolerance(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const float tolerance)
{
auto comparison_result = testing::AssertionSuccess();
return ngraph::test::all_close_f(expected_double, result_double, tolerance_bits);
}
}; // namespace
const auto expected = expected_results->get_vector<float>();
const auto result = read_vector<float>(results);
return ngraph::test::compare_with_tolerance(expected, result, tolerance);
}
template <typename T>
typename std::enable_if<std::is_integral<T>::value, testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const size_t)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
return ngraph::test::all_close(expected, result);
}
// used for float16 and bfloat 16 comparisons
template <typename T>
typename std::enable_if<std::is_class<T>::value, testing::AssertionResult>::type
compare_values(const std::shared_ptr<ngraph::op::Constant>& expected_results,
const std::shared_ptr<ngraph::runtime::Tensor>& results,
const size_t tolerance_bits)
{
const auto expected = expected_results->get_vector<T>();
const auto result = read_vector<T>(results);
// TODO: add testing infrastructure for float16 and bfloat16 to avoid cast to double
std::vector<double> expected_double(expected.size());
std::vector<double> result_double(result.size());
NGRAPH_CHECK(expected.size() == result.size(),
"Number of expected and computed results don't match");
for (size_t i = 0; i < expected.size(); ++i)
{
expected_double[i] = static_cast<double>(expected[i]);
result_double[i] = static_cast<double>(result[i]);
}
return ngraph::test::all_close_f(expected_double, result_double, tolerance_bits);
}
}; // namespace
test::INTERPRETER_Engine::INTERPRETER_Engine(const std::shared_ptr<Function> function)
: m_function{function}
{
m_backend = ngraph::runtime::Backend::create(NG_BACKEND_NAME, false); // static INT backend
test::INTERPRETER_Engine::INTERPRETER_Engine(const std::shared_ptr<Function> function) : m_function{function} {
m_backend = ngraph::runtime::Backend::create(NG_BACKEND_NAME, false); // static INT backend
m_executable = m_backend->compile(m_function);
for (size_t i = 0; i < m_function->get_output_size(); ++i)
{
m_result_tensors.push_back(m_backend->create_tensor(m_function->get_output_element_type(i),
m_function->get_output_shape(i)));
for (size_t i = 0; i < m_function->get_output_size(); ++i) {
m_result_tensors.push_back(
m_backend->create_tensor(m_function->get_output_element_type(i), m_function->get_output_shape(i)));
}
}
test::INTERPRETER_Engine::INTERPRETER_Engine(const std::shared_ptr<Function> function,
INTERPRETER_Engine::DynamicBackendTag)
: m_function{function}
{
m_backend = ngraph::runtime::Backend::create(NG_BACKEND_NAME, true); // dynamic INT backend
: m_function{function} {
m_backend = ngraph::runtime::Backend::create(NG_BACKEND_NAME, true); // dynamic INT backend
m_executable = m_backend->compile(m_function);
for (size_t i = 0; i < m_function->get_output_size(); ++i)
{
m_result_tensors.push_back(m_backend->create_dynamic_tensor(
m_function->get_output_element_type(i), m_function->get_output_partial_shape(i)));
for (size_t i = 0; i < m_function->get_output_size(); ++i) {
m_result_tensors.push_back(m_backend->create_dynamic_tensor(m_function->get_output_element_type(i),
m_function->get_output_partial_shape(i)));
}
}
test::INTERPRETER_Engine test::INTERPRETER_Engine::dynamic(const std::shared_ptr<Function> function)
{
test::INTERPRETER_Engine test::INTERPRETER_Engine::dynamic(const std::shared_ptr<Function> function) {
return INTERPRETER_Engine{function, DynamicBackendTag{}};
}
void test::INTERPRETER_Engine::infer()
{
void test::INTERPRETER_Engine::infer() {
const auto& function_results = m_function->get_results();
NGRAPH_CHECK(m_expected_outputs.size() == function_results.size(),
"Expected number of outputs is different from the function's number "
@@ -116,13 +102,10 @@ void test::INTERPRETER_Engine::infer()
m_executable->call_with_validate(m_result_tensors, m_input_tensors);
}
testing::AssertionResult
test::INTERPRETER_Engine::compare_results_with_tolerance_as_fp(const float tolerance)
{
testing::AssertionResult test::INTERPRETER_Engine::compare_results_with_tolerance_as_fp(const float tolerance) {
auto comparison_result = testing::AssertionSuccess();
for (size_t i = 0; i < m_expected_outputs.size(); ++i)
{
for (size_t i = 0; i < m_expected_outputs.size(); ++i) {
const auto& result_tensor = m_result_tensors.at(i);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
@@ -130,29 +113,23 @@ testing::AssertionResult
const auto& expected_shape = expected_result_constant->get_shape();
const auto& result_shape = result_tensor->get_shape();
if (expected_shape != result_shape)
{
if (expected_shape != result_shape) {
comparison_result = testing::AssertionFailure();
comparison_result << "Computed data shape(" << result_shape
<< ") does not match the expected shape(" << expected_shape
<< ") for output " << i << std::endl;
comparison_result << "Computed data shape(" << result_shape << ") does not match the expected shape("
<< expected_shape << ") for output " << i << std::endl;
break;
}
switch (element_type)
{
switch (element_type) {
case element::Type_t::f32:
comparison_result =
compare_with_fp_tolerance(expected_result_constant, result_tensor, tolerance);
comparison_result = compare_with_fp_tolerance(expected_result_constant, result_tensor, tolerance);
break;
default:
comparison_result = testing::AssertionFailure()
<< "Unsupported data type encountered in "
"'compare_results_with_tolerance_as_fp' method";
comparison_result = testing::AssertionFailure() << "Unsupported data type encountered in "
"'compare_results_with_tolerance_as_fp' method";
}
if (comparison_result == testing::AssertionFailure())
{
if (comparison_result == testing::AssertionFailure()) {
break;
}
}
@@ -160,12 +137,10 @@ testing::AssertionResult
return comparison_result;
}
testing::AssertionResult test::INTERPRETER_Engine::compare_results(const size_t tolerance_bits)
{
testing::AssertionResult test::INTERPRETER_Engine::compare_results(const size_t tolerance_bits) {
auto comparison_result = testing::AssertionSuccess();
for (size_t i = 0; i < m_expected_outputs.size(); ++i)
{
for (size_t i = 0; i < m_expected_outputs.size(); ++i) {
const auto& result_tensor = m_result_tensors.at(i);
const auto& expected_result_constant = m_expected_outputs.at(i);
const auto& element_type = result_tensor->get_element_type();
@@ -173,76 +148,61 @@ testing::AssertionResult test::INTERPRETER_Engine::compare_results(const size_t
const auto& expected_shape = expected_result_constant->get_shape();
const auto& result_shape = result_tensor->get_shape();
if (expected_shape != result_shape)
{
if (expected_shape != result_shape) {
comparison_result = testing::AssertionFailure();
comparison_result << "Computed data shape(" << result_shape
<< ") does not match the expected shape(" << expected_shape
<< ") for output " << i << std::endl;
comparison_result << "Computed data shape(" << result_shape << ") does not match the expected shape("
<< expected_shape << ") for output " << i << std::endl;
break;
}
switch (element_type)
{
switch (element_type) {
case element::Type_t::f16:
comparison_result = compare_values<ngraph::float16>(
expected_result_constant, result_tensor, tolerance_bits);
comparison_result =
compare_values<ngraph::float16>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::bf16:
comparison_result = compare_values<ngraph::bfloat16>(
expected_result_constant, result_tensor, tolerance_bits);
comparison_result =
compare_values<ngraph::bfloat16>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::f32:
comparison_result =
compare_values<float>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<float>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::f64:
comparison_result =
compare_values<double>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<double>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::i8:
comparison_result =
compare_values<int8_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<int8_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::i16:
comparison_result =
compare_values<int16_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<int16_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::i32:
comparison_result =
compare_values<int32_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<int32_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::i64:
comparison_result =
compare_values<int64_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<int64_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::u8:
comparison_result =
compare_values<uint8_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<uint8_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::u16:
comparison_result =
compare_values<uint16_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<uint16_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::u32:
comparison_result =
compare_values<uint32_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<uint32_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::u64:
comparison_result =
compare_values<uint64_t>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<uint64_t>(expected_result_constant, result_tensor, tolerance_bits);
break;
case element::Type_t::boolean:
comparison_result =
compare_values<char>(expected_result_constant, result_tensor, tolerance_bits);
comparison_result = compare_values<char>(expected_result_constant, result_tensor, tolerance_bits);
break;
default:
comparison_result = testing::AssertionFailure()
<< "Unsupported data type encountered in 'compare_results' method";
}
if (comparison_result == testing::AssertionFailure())
{
if (comparison_result == testing::AssertionFailure()) {
break;
}
}
@@ -250,8 +210,7 @@ testing::AssertionResult test::INTERPRETER_Engine::compare_results(const size_t
return comparison_result;
}
void test::INTERPRETER_Engine::reset()
{
void test::INTERPRETER_Engine::reset() {
m_input_index = 0;
m_output_index = 0;
m_expected_outputs.clear();

View File

@@ -11,78 +11,66 @@
#include "util/engine/engine_traits.hpp"
#include "util/engine/test_case_engine.hpp"
namespace ngraph
{
namespace test
{
class INTERPRETER_Engine : public TestCaseEngine
{
public:
INTERPRETER_Engine(const std::shared_ptr<Function> function);
namespace ngraph {
namespace test {
class INTERPRETER_Engine : public TestCaseEngine {
public:
INTERPRETER_Engine(const std::shared_ptr<Function> function);
static INTERPRETER_Engine dynamic(const std::shared_ptr<Function> function);
static INTERPRETER_Engine dynamic(const std::shared_ptr<Function> function);
void infer() override;
void infer() override;
testing::AssertionResult compare_results(
const size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS) override;
testing::AssertionResult compare_results(const size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS) override;
testing::AssertionResult
compare_results_with_tolerance_as_fp(const float tolerance = 1.0e-5f) override;
testing::AssertionResult compare_results_with_tolerance_as_fp(const float tolerance = 1.0e-5f) override;
void reset() override;
void reset() override;
template <typename T>
void add_input(const Shape& shape, const std::vector<T>& values)
{
const auto params = m_function->get_parameters();
auto tensor =
m_backend->create_tensor(params.at(m_input_index)->get_element_type(), shape);
template <typename T>
void add_input(const Shape& shape, const std::vector<T>& values) {
const auto params = m_function->get_parameters();
auto tensor = m_backend->create_tensor(params.at(m_input_index)->get_element_type(), shape);
copy_data(tensor, values);
copy_data(tensor, values);
m_input_tensors.push_back(tensor);
m_input_tensors.push_back(tensor);
++m_input_index;
}
template <typename T>
void add_expected_output(const ngraph::Shape& expected_shape,
const std::vector<T>& values)
{
const auto results = m_function->get_results();
const auto function_output_type = results.at(m_output_index)->get_element_type();
m_expected_outputs.emplace_back(std::make_shared<ngraph::op::Constant>(
function_output_type, expected_shape, values));
++m_output_index;
}
private:
struct DynamicBackendTag
{
};
/// A private constructor that should only be used from the dynamic() member function
INTERPRETER_Engine(const std::shared_ptr<Function> function, DynamicBackendTag);
static constexpr const char* NG_BACKEND_NAME = "INTERPRETER";
const std::shared_ptr<Function> m_function;
std::shared_ptr<runtime::Backend> m_backend;
std::shared_ptr<ngraph::runtime::Executable> m_executable;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
size_t m_input_index = 0;
size_t m_output_index = 0;
};
template <>
struct supports_dynamic<INTERPRETER_Engine>
{
static constexpr bool value = true;
};
++m_input_index;
}
}
template <typename T>
void add_expected_output(const ngraph::Shape& expected_shape, const std::vector<T>& values) {
const auto results = m_function->get_results();
const auto function_output_type = results.at(m_output_index)->get_element_type();
m_expected_outputs.emplace_back(
std::make_shared<ngraph::op::Constant>(function_output_type, expected_shape, values));
++m_output_index;
}
private:
struct DynamicBackendTag {};
/// A private constructor that should only be used from the dynamic() member function
INTERPRETER_Engine(const std::shared_ptr<Function> function, DynamicBackendTag);
static constexpr const char* NG_BACKEND_NAME = "INTERPRETER";
const std::shared_ptr<Function> m_function;
std::shared_ptr<runtime::Backend> m_backend;
std::shared_ptr<ngraph::runtime::Executable> m_executable;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_input_tensors;
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> m_result_tensors;
std::vector<std::shared_ptr<ngraph::op::Constant>> m_expected_outputs;
size_t m_input_index = 0;
size_t m_output_index = 0;
};
template <>
struct supports_dynamic<INTERPRETER_Engine> {
static constexpr bool value = true;
};
} // namespace test
} // namespace ngraph

View File

@@ -2,15 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "shared_utils.hpp"
#include <cmath>
#include <sstream>
#include "shared_utils.hpp"
testing::AssertionResult ngraph::test::compare_with_tolerance(const std::vector<float>& expected,
const std::vector<float>& results,
const float tolerance)
{
const float tolerance) {
auto comparison_result = testing::AssertionSuccess();
std::stringstream msg;
@@ -18,18 +17,15 @@ testing::AssertionResult ngraph::test::compare_with_tolerance(const std::vector<
bool rc = true;
for (std::size_t j = 0; j < expected.size(); ++j)
{
for (std::size_t j = 0; j < expected.size(); ++j) {
float diff = std::fabs(results[j] - expected[j]);
if (diff > tolerance)
{
if (diff > tolerance) {
msg << expected[j] << " is not close to " << results[j] << " at index " << j << "\n";
rc = false;
}
}
if (!rc)
{
if (!rc) {
comparison_result = testing::AssertionFailure();
comparison_result << msg.str();
}

View File

@@ -5,14 +5,13 @@
#pragma once
#include <gtest/gtest.h>
#include <vector>
namespace ngraph
{
namespace test
{
testing::AssertionResult compare_with_tolerance(const std::vector<float>& expected_results,
const std::vector<float>& results,
const float tolerance);
}
} // namespace ngraph
namespace ngraph {
namespace test {
testing::AssertionResult compare_with_tolerance(const std::vector<float>& expected_results,
const std::vector<float>& results,
const float tolerance);
}
} // namespace ngraph

View File

@@ -6,47 +6,43 @@
#include <gtest/gtest.h>
namespace ngraph
{
namespace test
{
/// An interface that each test case engine needs to implement. This interface wraps
/// a couple of generic methods which are required by the TestCase class to execute
/// a unit test for a given ngraph::Function.
/// The interface operates on C++ types while internally it can use implementation-specific
/// types, containers and structures.
class TestCaseEngine
{
public:
virtual ~TestCaseEngine() noexcept = default;
namespace ngraph {
namespace test {
/// An interface that each test case engine needs to implement. This interface wraps
/// a couple of generic methods which are required by the TestCase class to execute
/// a unit test for a given ngraph::Function.
/// The interface operates on C++ types while internally it can use implementation-specific
/// types, containers and structures.
class TestCaseEngine {
public:
virtual ~TestCaseEngine() noexcept = default;
/// Performs the inference using data stored as internal state
virtual void infer() = 0;
/// Performs the inference using data stored as internal state
virtual void infer() = 0;
/// Resets the internal state so that the test can be executed again
virtual void reset() = 0;
/// Resets the internal state so that the test can be executed again
virtual void reset() = 0;
/// Compares computed and expected results, returns AssertionSuccess or AssertionFailure
virtual testing::AssertionResult compare_results(const size_t tolerance_bits) = 0;
/// Compares computed and expected results, returns AssertionSuccess or AssertionFailure
virtual testing::AssertionResult compare_results(const size_t tolerance_bits) = 0;
/// Compares computed and expected results, returns AssertionSuccess or AssertionFailure
virtual testing::AssertionResult
compare_results_with_tolerance_as_fp(const float tolerance) = 0;
/// Compares computed and expected results, returns AssertionSuccess or AssertionFailure
virtual testing::AssertionResult compare_results_with_tolerance_as_fp(const float tolerance) = 0;
/// Additionally the interface implementing class needs to define
/// the following 2 methods. They are called from the TestCase class
/// but they can't be a part of interface since they need to be declared as templates
/// Additionally the interface implementing class needs to define
/// the following 2 methods. They are called from the TestCase class
/// but they can't be a part of interface since they need to be declared as templates
/// Passes data (along with its shape) to the next available input.
/// The data should be stored as internal state, not necessarily as vectors
// template <typename T>
// void add_input(const Shape& shape, const std::vector<T>& values)
/// Passes data (along with its shape) to the next available input.
/// The data should be stored as internal state, not necessarily as vectors
// template <typename T>
// void add_input(const Shape& shape, const std::vector<T>& values)
/// Sets the expected data (along with its shape) for the next available output
/// The data should be stored as internal state, not necessarily as vectors
// template <typename T>
// void add_expected_output(const ngraph::Shape& expected_shape,
// const std::vector<T>& values)
};
}
}
/// Sets the expected data (along with its shape) for the next available output
/// The data should be stored as internal state, not necessarily as vectors
// template <typename T>
// void add_expected_output(const ngraph::Shape& expected_shape,
// const std::vector<T>& values)
};
} // namespace test
} // namespace ngraph

View File

@@ -4,8 +4,7 @@
#include "util/float_util.hpp"
std::string ngraph::test::bfloat16_to_bits(bfloat16 f)
{
std::string ngraph::test::bfloat16_to_bits(bfloat16 f) {
std::stringstream ss;
ss << std::bitset<16>(f.to_bits());
std::string unformatted = ss.str();
@@ -19,16 +18,14 @@ std::string ngraph::test::bfloat16_to_bits(bfloat16 f)
formatted.append(" ");
// Mantissa
formatted.append(unformatted, 9, 3);
for (int i = 12; i < 16; i += 4)
{
for (int i = 12; i < 16; i += 4) {
formatted.push_back(' ');
formatted.append(unformatted, i, 4);
}
return formatted;
}
std::string ngraph::test::float16_to_bits(float16 f)
{
std::string ngraph::test::float16_to_bits(float16 f) {
std::stringstream ss;
ss << std::bitset<16>(f.to_bits());
std::string unformatted = ss.str();
@@ -42,16 +39,14 @@ std::string ngraph::test::float16_to_bits(float16 f)
formatted.append(" ");
// Mantissa
formatted.append(unformatted, 6, 2);
for (int i = 8; i < 16; i += 4)
{
for (int i = 8; i < 16; i += 4) {
formatted.push_back(' ');
formatted.append(unformatted, i, 4);
}
return formatted;
}
std::string ngraph::test::float_to_bits(float f)
{
std::string ngraph::test::float_to_bits(float f) {
FloatUnion fu{f};
std::stringstream ss;
ss << std::bitset<32>(fu.i);
@@ -66,16 +61,14 @@ std::string ngraph::test::float_to_bits(float f)
formatted.append(" ");
// Mantissa
formatted.append(unformatted, 9, 3);
for (int i = 12; i < 32; i += 4)
{
for (int i = 12; i < 32; i += 4) {
formatted.push_back(' ');
formatted.append(unformatted, i, 4);
}
return formatted;
}
std::string ngraph::test::double_to_bits(double d)
{
std::string ngraph::test::double_to_bits(double d) {
DoubleUnion du{d};
std::stringstream ss;
ss << std::bitset<64>(du.i);
@@ -89,50 +82,40 @@ std::string ngraph::test::double_to_bits(double d)
formatted.append(unformatted, 1, 11);
formatted.push_back(' ');
// Mantissa
for (int i = 12; i < 64; i += 4)
{
for (int i = 12; i < 64; i += 4) {
formatted.push_back(' ');
formatted.append(unformatted, i, 4);
}
return formatted;
}
ngraph::bfloat16 ngraph::test::bits_to_bfloat16(const std::string& s)
{
ngraph::bfloat16 ngraph::test::bits_to_bfloat16(const std::string& s) {
std::string unformatted = s;
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace),
unformatted.end());
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace), unformatted.end());
if (unformatted.size() != 16)
{
if (unformatted.size() != 16) {
throw ngraph_error("Input length must be 16");
}
std::bitset<16> bs(unformatted);
return bfloat16::from_bits(static_cast<uint16_t>(bs.to_ulong()));
}
ngraph::float16 ngraph::test::bits_to_float16(const std::string& s)
{
ngraph::float16 ngraph::test::bits_to_float16(const std::string& s) {
std::string unformatted = s;
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace),
unformatted.end());
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace), unformatted.end());
if (unformatted.size() != 16)
{
if (unformatted.size() != 16) {
throw ngraph_error("Input length must be 16");
}
std::bitset<16> bs(unformatted);
return float16::from_bits(static_cast<uint16_t>(bs.to_ulong()));
}
float ngraph::test::bits_to_float(const std::string& s)
{
float ngraph::test::bits_to_float(const std::string& s) {
std::string unformatted = s;
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace),
unformatted.end());
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace), unformatted.end());
if (unformatted.size() != 32)
{
if (unformatted.size() != 32) {
throw ngraph_error("Input length must be 32");
}
std::bitset<32> bs(unformatted);
@@ -141,14 +124,11 @@ float ngraph::test::bits_to_float(const std::string& s)
return fu.f;
}
double ngraph::test::bits_to_double(const std::string& s)
{
double ngraph::test::bits_to_double(const std::string& s) {
std::string unformatted = s;
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace),
unformatted.end());
unformatted.erase(remove_if(unformatted.begin(), unformatted.end(), ::isspace), unformatted.end());
if (unformatted.size() != 64)
{
if (unformatted.size() != 64) {
throw ngraph_error("Input length must be 64");
}
std::bitset<64> bs(unformatted);

View File

@@ -10,44 +10,51 @@
#include "ngraph/ngraph.hpp"
namespace ngraph
{
namespace test
{
union FloatUnion {
FloatUnion() { i = 0; }
FloatUnion(float val) { f = val; }
FloatUnion(uint32_t val) { i = val; }
FloatUnion(uint32_t s, uint32_t e, uint32_t f)
: FloatUnion(s << 31 | e << 23 | f)
{
}
float f;
uint32_t i;
};
union DoubleUnion {
DoubleUnion() { i = 0; }
DoubleUnion(double val) { d = val; }
DoubleUnion(uint64_t val) { i = val; }
double d;
uint64_t i;
};
std::string bfloat16_to_bits(bfloat16 f);
std::string float16_to_bits(float16 f);
std::string float_to_bits(float f);
std::string double_to_bits(double d);
bfloat16 bits_to_bfloat16(const std::string& s);
float bits_to_float(const std::string& s);
double bits_to_double(const std::string& s);
float16 bits_to_float16(const std::string& s);
namespace ngraph {
namespace test {
union FloatUnion {
FloatUnion() {
i = 0;
}
}
FloatUnion(float val) {
f = val;
}
FloatUnion(uint32_t val) {
i = val;
}
FloatUnion(uint32_t s, uint32_t e, uint32_t f) : FloatUnion(s << 31 | e << 23 | f) {}
float f;
uint32_t i;
};
union DoubleUnion {
DoubleUnion() {
i = 0;
}
DoubleUnion(double val) {
d = val;
}
DoubleUnion(uint64_t val) {
i = val;
}
double d;
uint64_t i;
};
std::string bfloat16_to_bits(bfloat16 f);
std::string float16_to_bits(float16 f);
std::string float_to_bits(float f);
std::string double_to_bits(double d);
bfloat16 bits_to_bfloat16(const std::string& s);
float bits_to_float(const std::string& s);
double bits_to_double(const std::string& s);
float16 bits_to_float16(const std::string& s);
} // namespace test
} // namespace ngraph

View File

@@ -1,22 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include "ngraph/ngraph.hpp"
static const std::vector<ngraph::element::Type> s_known_element_types = {
ngraph::element::from<float>(),
ngraph::element::from<double>(),
ngraph::element::from<int8_t>(),
ngraph::element::from<int16_t>(),
ngraph::element::from<int32_t>(),
ngraph::element::from<int64_t>(),
ngraph::element::from<uint8_t>(),
ngraph::element::from<uint16_t>(),
ngraph::element::from<uint32_t>(),
ngraph::element::from<uint64_t>(),
};

View File

@@ -3,20 +3,16 @@
//
// this is for more nuanced testing
class TestMatcher : public ngraph::pattern::Matcher
{
class TestMatcher : public ngraph::pattern::Matcher {
using ngraph::pattern::Matcher::Matcher;
public:
TestMatcher() {}
bool virtual match_value(const ngraph::Output<ngraph::Node>& pattern_value,
const ngraph::Output<ngraph::Node>& graph_value) override
{
if (ngraph::is_type<::ngraph::op::Parameter>(pattern_value.get_node_shared_ptr()))
{
const ngraph::Output<ngraph::Node>& graph_value) override {
if (ngraph::is_type<::ngraph::op::Parameter>(pattern_value.get_node_shared_ptr())) {
bool result = pattern_value == graph_value;
if (result)
{
if (result) {
m_matched_list.push_back(graph_value.get_node_shared_ptr());
}
return result;
@@ -26,11 +22,9 @@ public:
}
public:
bool match(const std::shared_ptr<ngraph::Node>& pattern_node,
const std::shared_ptr<ngraph::Node>& graph_node)
{
NGRAPH_CHECK(pattern_node && graph_node); // the same condition throws an exception in the
// non-test version of `match`
bool match(const std::shared_ptr<ngraph::Node>& pattern_node, const std::shared_ptr<ngraph::Node>& graph_node) {
NGRAPH_CHECK(pattern_node && graph_node); // the same condition throws an exception in the
// non-test version of `match`
NGRAPH_DEBUG << "Starting match pattern = " << pattern_node->get_name()
<< " , graph_node = " << graph_node->get_name();

View File

@@ -19,197 +19,180 @@
#include "ngraph/log.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace test
{
namespace init
{
// Recursively define types for N-deep initializer lists
template <typename T, size_t N>
struct NestedInitializerListWrapper
{
using type =
std::initializer_list<typename NestedInitializerListWrapper<T, N - 1>::type>;
};
namespace ngraph {
namespace test {
namespace init {
// Recursively define types for N-deep initializer lists
template <typename T, size_t N>
struct NestedInitializerListWrapper {
using type = std::initializer_list<typename NestedInitializerListWrapper<T, N - 1>::type>;
};
// 1-deep is a plain initializer_list
template <typename T>
struct NestedInitializerListWrapper<T, 1>
{
using type = std::initializer_list<T>;
};
// 1-deep is a plain initializer_list
template <typename T>
struct NestedInitializerListWrapper<T, 1> {
using type = std::initializer_list<T>;
};
// Scalar case is just the element type
template <typename T>
struct NestedInitializerListWrapper<T, 0>
{
using type = T;
};
// Scalar case is just the element type
template <typename T>
struct NestedInitializerListWrapper<T, 0> {
using type = T;
};
// Convenience type name for N-deep initializer lists of Ts
template <typename T, size_t N>
using NestedInitializerList = typename NestedInitializerListWrapper<T, N>::type;
// Convenience type name for N-deep initializer lists of Ts
template <typename T, size_t N>
using NestedInitializerList = typename NestedInitializerListWrapper<T, N>::type;
// Fill in a shape from a nested initializer list
// For a scalar, nothing to do.
template <typename T, size_t N>
typename std::enable_if<(N == 0), void>::type
fill_shape(Shape& /* shape */, const NestedInitializerList<T, N>& /* inits */)
{
}
// Fill in a shape from a nested initializer list
// For a scalar, nothing to do.
template <typename T, size_t N>
typename std::enable_if<(N == 0), void>::type fill_shape(Shape& /* shape */,
const NestedInitializerList<T, N>& /* inits */) {}
// Check that the inits match the shape
template <typename T, size_t N>
typename std::enable_if<(N == 0), void>::type
check_shape(const Shape& shape, const NestedInitializerList<T, N>& /* inits */)
{
if (shape.size() != 0)
{
throw std::invalid_argument("Initializers do not match shape");
}
}
// For a plain initializer list, the shape is the length of the list.
template <typename T, size_t N>
typename std::enable_if<(N == 1)>::type
fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
{
shape.push_back(inits.size());
}
template <typename T, size_t N>
typename std::enable_if<(N == 1)>::type
check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
{
if (shape.at(shape.size() - N) != inits.size())
{
throw std::invalid_argument("Initializers do not match shape");
}
}
// In the general case, we append our level's length and recurse.
template <typename T, size_t N>
typename std::enable_if<(N > 1), void>::type
fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits)
{
shape.push_back(inits.size());
fill_shape<T, N - 1>(shape, *inits.begin());
}
template <typename T, size_t N>
typename std::enable_if<(N > 1), void>::type
check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits)
{
if (shape.at(shape.size() - N) != inits.size())
{
throw std::invalid_argument("Initializers do not match shape");
}
for (auto it : inits)
{
check_shape<T, N - 1>(shape, it);
}
}
// Get the shape of inits.
template <typename T, size_t N>
Shape get_shape(const NestedInitializerList<T, N>& inits)
{
Shape shape;
fill_shape<T, N>(shape, inits);
check_shape<T, N>(shape, inits);
return shape;
}
template <typename IT, typename T, size_t N>
typename std::enable_if<(N == 1), IT>::type
flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
{
if (inits.size() != shape.at(shape.size() - N))
{
throw std::invalid_argument("Initializers do not match shape");
}
for (auto it1 : inits)
{
*(it++) = it1;
}
return it;
}
template <typename IT, typename T, size_t N>
typename std::enable_if<(N > 1), IT>::type
flatten(IT it, const Shape& shape, const NestedInitializerList<T, N>& inits)
{
if (inits.size() != shape.at(shape.size() - N))
{
throw std::invalid_argument("Initializers do not match shape");
}
for (auto it1 : inits)
{
it = flatten<IT, T, N - 1>(it, shape, it1);
}
return it;
}
template <typename IT, typename T, size_t N>
typename std::enable_if<(N == 0), IT>::type
flatten(IT it, const Shape& shape, const NestedInitializerList<T, 0>& init)
{
if (shape.size() != 0)
{
throw std::invalid_argument("Initializers do not match shape");
}
*(it++) = init;
return it;
}
}
template <typename T>
class NDArrayBase
{
using vtype = std::vector<T>;
public:
using type = T;
using iterator = typename vtype::iterator;
using const_iterator = typename vtype::const_iterator;
NDArrayBase(const Shape& shape)
: m_shape(shape)
, m_elements(shape_size(m_shape))
{
}
const Shape& get_shape() const { return m_shape; }
const_iterator begin() const { return m_elements.begin(); }
const_iterator end() const { return m_elements.end(); }
vtype get_vector() { return m_elements; }
const vtype get_vector() const { return m_elements; }
operator const vtype() const { return m_elements; }
operator vtype() { return m_elements; }
void* data() { return m_elements.data(); }
const void* data() const { return m_elements.data(); }
bool operator==(const NDArrayBase<T>& other) const
{
return m_shape == other.m_shape && m_elements == other.m_elements;
}
protected:
Shape m_shape;
vtype m_elements;
};
/// An N dimensional array of elements of type T
template <typename T, size_t N>
class NDArray : public NDArrayBase<T>
{
public:
NDArray(const init::NestedInitializerList<T, N>& initial_value)
: NDArrayBase<T>(init::get_shape<T, N>(initial_value))
{
init::flatten<typename std::vector<T>::iterator, T, N>(
NDArrayBase<T>::m_elements.begin(), NDArrayBase<T>::m_shape, initial_value);
}
};
// Check that the inits match the shape
template <typename T, size_t N>
typename std::enable_if<(N == 0), void>::type check_shape(const Shape& shape,
const NestedInitializerList<T, N>& /* inits */) {
if (shape.size() != 0) {
throw std::invalid_argument("Initializers do not match shape");
}
}
// For a plain initializer list, the shape is the length of the list.
template <typename T, size_t N>
typename std::enable_if<(N == 1)>::type fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits) {
shape.push_back(inits.size());
}
template <typename T, size_t N>
typename std::enable_if<(N == 1)>::type check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits) {
if (shape.at(shape.size() - N) != inits.size()) {
throw std::invalid_argument("Initializers do not match shape");
}
}
// In the general case, we append our level's length and recurse.
template <typename T, size_t N>
typename std::enable_if<(N > 1), void>::type fill_shape(Shape& shape, const NestedInitializerList<T, N>& inits) {
shape.push_back(inits.size());
fill_shape<T, N - 1>(shape, *inits.begin());
}
template <typename T, size_t N>
typename std::enable_if<(N > 1), void>::type check_shape(const Shape& shape, const NestedInitializerList<T, N>& inits) {
if (shape.at(shape.size() - N) != inits.size()) {
throw std::invalid_argument("Initializers do not match shape");
}
for (auto it : inits) {
check_shape<T, N - 1>(shape, it);
}
}
// Get the shape of inits.
template <typename T, size_t N>
Shape get_shape(const NestedInitializerList<T, N>& inits) {
Shape shape;
fill_shape<T, N>(shape, inits);
check_shape<T, N>(shape, inits);
return shape;
}
template <typename IT, typename T, size_t N>
typename std::enable_if<(N == 1), IT>::type flatten(IT it,
const Shape& shape,
const NestedInitializerList<T, N>& inits) {
if (inits.size() != shape.at(shape.size() - N)) {
throw std::invalid_argument("Initializers do not match shape");
}
for (auto it1 : inits) {
*(it++) = it1;
}
return it;
}
template <typename IT, typename T, size_t N>
typename std::enable_if<(N > 1), IT>::type flatten(IT it,
const Shape& shape,
const NestedInitializerList<T, N>& inits) {
if (inits.size() != shape.at(shape.size() - N)) {
throw std::invalid_argument("Initializers do not match shape");
}
for (auto it1 : inits) {
it = flatten<IT, T, N - 1>(it, shape, it1);
}
return it;
}
template <typename IT, typename T, size_t N>
typename std::enable_if<(N == 0), IT>::type flatten(IT it,
const Shape& shape,
const NestedInitializerList<T, 0>& init) {
if (shape.size() != 0) {
throw std::invalid_argument("Initializers do not match shape");
}
*(it++) = init;
return it;
}
} // namespace init
template <typename T>
class NDArrayBase {
using vtype = std::vector<T>;
public:
using type = T;
using iterator = typename vtype::iterator;
using const_iterator = typename vtype::const_iterator;
NDArrayBase(const Shape& shape) : m_shape(shape), m_elements(shape_size(m_shape)) {}
const Shape& get_shape() const {
return m_shape;
}
const_iterator begin() const {
return m_elements.begin();
}
const_iterator end() const {
return m_elements.end();
}
vtype get_vector() {
return m_elements;
}
const vtype get_vector() const {
return m_elements;
}
operator const vtype() const {
return m_elements;
}
operator vtype() {
return m_elements;
}
void* data() {
return m_elements.data();
}
const void* data() const {
return m_elements.data();
}
bool operator==(const NDArrayBase<T>& other) const {
return m_shape == other.m_shape && m_elements == other.m_elements;
}
protected:
Shape m_shape;
vtype m_elements;
};
/// An N dimensional array of elements of type T
template <typename T, size_t N>
class NDArray : public NDArrayBase<T> {
public:
NDArray(const init::NestedInitializerList<T, N>& initial_value)
: NDArrayBase<T>(init::get_shape<T, N>(initial_value)) {
init::flatten<typename std::vector<T>::iterator, T, N>(NDArrayBase<T>::m_elements.begin(),
NDArrayBase<T>::m_shape,
initial_value);
}
};
} // namespace test
} // namespace ngraph

View File

@@ -2,57 +2,163 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "onnx_test_util.hpp"
#include <onnx/onnx_pb.h>
#include <exception>
#include <fstream>
#include <onnx/onnx_pb.h>
#include <sstream>
#include "onnx_common/parser.hpp"
#include "onnx_test_util.hpp"
using namespace ngraph;
using namespace ngraph::test;
namespace
{
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.node_size() != ref_graph.node_size())
{
return ComparisonResult::fail("The number of nodes in compared models doesn't match");
namespace {
ComparisonResult compare_nodes(const ONNX_NAMESPACE::GraphProto& graph, const ONNX_NAMESPACE::GraphProto& ref_graph) {
if (graph.node_size() != ref_graph.node_size()) {
return ComparisonResult::fail("The number of nodes in compared models doesn't match");
} else {
for (int i = 0; i < graph.node_size(); ++i) {
const auto& lhs = graph.node(i);
const auto& rhs = ref_graph.node(i);
if (lhs.op_type() != rhs.op_type()) {
return ComparisonResult::fail("Operation types are different at index " + std::to_string(i) + ": " +
lhs.op_type() + " vs " + rhs.op_type());
}
for (int j = 0; j < lhs.input_size(); ++j) {
if (lhs.input(j) != rhs.input(j)) {
return ComparisonResult::fail("Input names don't match for nodes at index " + std::to_string(i) +
": " + lhs.input(j) + " vs " + rhs.input(j));
}
}
for (int j = 0; j < lhs.output_size(); ++j) {
if (lhs.output(j) != rhs.output(j)) {
return ComparisonResult::fail("Output names don't match for nodes at index " + std::to_string(i) +
": " + lhs.output(j) + " vs " + rhs.output(j));
}
}
}
else
{
for (int i = 0; i < graph.node_size(); ++i)
{
const auto& lhs = graph.node(i);
const auto& rhs = ref_graph.node(i);
}
if (lhs.op_type() != rhs.op_type())
{
return ComparisonResult::fail("Operation types are different at index " +
std::to_string(i) + ": " + lhs.op_type() +
" vs " + rhs.op_type());
}
return ComparisonResult::pass();
}
for (int j = 0; j < lhs.input_size(); ++j)
{
if (lhs.input(j) != rhs.input(j))
{
return ComparisonResult::fail(
"Input names don't match for nodes at index " + std::to_string(i) +
": " + lhs.input(j) + " vs " + rhs.input(j));
}
}
ComparisonResult compare_value_info(const ONNX_NAMESPACE::ValueInfoProto& lhs,
const ONNX_NAMESPACE::ValueInfoProto& rhs,
const std::string& item_type) {
if (lhs.name() != rhs.name()) {
return ComparisonResult::fail(item_type + " names in the graph don't match: " + lhs.name() + " vs " +
rhs.name());
}
for (int j = 0; j < lhs.output_size(); ++j)
{
if (lhs.output(j) != rhs.output(j))
{
return ComparisonResult::fail(
"Output names don't match for nodes at index " + std::to_string(i) +
": " + lhs.output(j) + " vs " + rhs.output(j));
const auto& lhs_tensor = lhs.type().tensor_type();
const auto& rhs_tensor = rhs.type().tensor_type();
if (lhs_tensor.elem_type() != rhs_tensor.elem_type()) {
return ComparisonResult::fail("Element types don't match for " + item_type + " " + lhs.name() + ": " +
std::to_string(lhs_tensor.elem_type()) + " vs " +
std::to_string(rhs_tensor.elem_type()));
}
const auto& lhs_shape = lhs_tensor.shape();
const auto& rhs_shape = rhs_tensor.shape();
if (lhs_shape.dim_size() != rhs_shape.dim_size()) {
return ComparisonResult::fail("Tensor ranks don't match for " + item_type + " " + lhs.name() + ": " +
std::to_string(lhs_shape.dim_size()) + " vs " +
std::to_string(rhs_shape.dim_size()));
} else {
for (int j = 0; j < lhs_shape.dim_size(); ++j) {
const auto& lhs_dim = lhs_shape.dim(j);
const auto& rhs_dim = rhs_shape.dim(j);
if ((lhs_dim.has_dim_value() && rhs_dim.has_dim_param()) ||
(rhs_dim.has_dim_value() && lhs_dim.has_dim_param())) {
return ComparisonResult::fail("Dynamic vs static dimension mismatch for " + item_type + " " +
lhs.name() + " at index: " + std::to_string(j));
} else if (lhs_dim.has_dim_value() && lhs_dim.dim_value() != rhs_dim.dim_value()) {
return ComparisonResult::fail("Shape dimensions don't match for " + item_type + " " + lhs.name() +
" at index: " + std::to_string(j) + ". " +
std::to_string(lhs_dim.dim_value()) + " vs " +
std::to_string(rhs_dim.dim_value()));
}
}
}
return ComparisonResult::pass();
}
ComparisonResult compare_inputs(const ONNX_NAMESPACE::GraphProto& graph, const ONNX_NAMESPACE::GraphProto& ref_graph) {
if (graph.input_size() != ref_graph.input_size()) {
return ComparisonResult::fail(
"The number of inputs in compared models doesn't match: " + std::to_string(graph.input_size()) + " vs " +
std::to_string(ref_graph.input_size()));
} else {
for (int i = 0; i < graph.input_size(); ++i) {
const auto& lhs = graph.input(i);
const auto& rhs = ref_graph.input(i);
const auto res = compare_value_info(lhs, rhs, "input");
if (!res.is_ok) {
return res;
}
}
return ComparisonResult::pass();
}
}
ComparisonResult compare_outputs(const ONNX_NAMESPACE::GraphProto& graph, const ONNX_NAMESPACE::GraphProto& ref_graph) {
if (graph.output_size() != ref_graph.output_size()) {
return ComparisonResult::fail("The number of outputs in compared models doesn't match" +
std::to_string(graph.output_size()) + " vs " +
std::to_string(ref_graph.output_size()));
} else {
for (int i = 0; i < graph.output_size(); ++i) {
const auto& lhs = graph.output(i);
const auto& rhs = ref_graph.output(i);
const auto res = compare_value_info(lhs, rhs, "output");
if (!res.is_ok) {
return res;
}
}
return ComparisonResult::pass();
}
}
ComparisonResult compare_initializers(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph) {
if (graph.initializer_size() != ref_graph.initializer_size()) {
return ComparisonResult::fail("The number of initializers in compared models doesn't match" +
std::to_string(graph.initializer_size()) + " vs " +
std::to_string(ref_graph.initializer_size()));
} else {
for (int i = 0; i < graph.initializer_size(); ++i) {
const auto& lhs = graph.initializer(i);
const auto& rhs = ref_graph.initializer(i);
if (lhs.name() != rhs.name()) {
return ComparisonResult::fail("Initializer names in the graph don't match: " + lhs.name() + " vs " +
rhs.name());
} else if (lhs.data_type() != rhs.data_type()) {
return ComparisonResult::fail(
"Initializer data types in the graph don't match: " + std::to_string(lhs.data_type()) + " vs " +
std::to_string(rhs.data_type()));
} else if (lhs.dims_size() != rhs.dims_size()) {
return ComparisonResult::fail(
"Initializer ranks in the graph don't match: " + std::to_string(lhs.dims_size()) + " vs " +
std::to_string(rhs.dims_size()));
} else {
for (int j = 0; j < lhs.dims_size(); ++j) {
if (lhs.dims(j) != rhs.dims(j)) {
return ComparisonResult::fail("Shape dimensions don't match for initializer " + lhs.name() +
" at index: " + std::to_string(j) + ". " +
std::to_string(lhs.dims(j)) + " vs " +
std::to_string(rhs.dims(j)));
}
}
}
@@ -60,205 +166,35 @@ namespace
return ComparisonResult::pass();
}
}
ComparisonResult compare_value_info(const ONNX_NAMESPACE::ValueInfoProto& lhs,
const ONNX_NAMESPACE::ValueInfoProto& rhs,
const std::string& item_type)
{
if (lhs.name() != rhs.name())
{
return ComparisonResult::fail(
item_type + " names in the graph don't match: " + lhs.name() + " vs " + rhs.name());
}
const auto& lhs_tensor = lhs.type().tensor_type();
const auto& rhs_tensor = rhs.type().tensor_type();
if (lhs_tensor.elem_type() != rhs_tensor.elem_type())
{
return ComparisonResult::fail("Element types don't match for " + item_type + " " +
lhs.name() + ": " +
std::to_string(lhs_tensor.elem_type()) + " vs " +
std::to_string(rhs_tensor.elem_type()));
}
const auto& lhs_shape = lhs_tensor.shape();
const auto& rhs_shape = rhs_tensor.shape();
if (lhs_shape.dim_size() != rhs_shape.dim_size())
{
return ComparisonResult::fail("Tensor ranks don't match for " + item_type + " " +
lhs.name() + ": " + std::to_string(lhs_shape.dim_size()) +
" vs " + std::to_string(rhs_shape.dim_size()));
}
else
{
for (int j = 0; j < lhs_shape.dim_size(); ++j)
{
const auto& lhs_dim = lhs_shape.dim(j);
const auto& rhs_dim = rhs_shape.dim(j);
if ((lhs_dim.has_dim_value() && rhs_dim.has_dim_param()) ||
(rhs_dim.has_dim_value() && lhs_dim.has_dim_param()))
{
return ComparisonResult::fail("Dynamic vs static dimension mismatch for " +
item_type + " " + lhs.name() +
" at index: " + std::to_string(j));
}
else if (lhs_dim.has_dim_value() && lhs_dim.dim_value() != rhs_dim.dim_value())
{
return ComparisonResult::fail("Shape dimensions don't match for " + item_type +
" " + lhs.name() +
" at index: " + std::to_string(j) + ". " +
std::to_string(lhs_dim.dim_value()) + " vs " +
std::to_string(rhs_dim.dim_value()));
}
}
}
return ComparisonResult::pass();
ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph) {
ComparisonResult comparison = compare_inputs(graph, ref_graph);
if (!comparison.is_ok) {
return comparison;
}
ComparisonResult compare_inputs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.input_size() != ref_graph.input_size())
{
return ComparisonResult::fail(
"The number of inputs in compared models doesn't match: " +
std::to_string(graph.input_size()) + " vs " +
std::to_string(ref_graph.input_size()));
}
else
{
for (int i = 0; i < graph.input_size(); ++i)
{
const auto& lhs = graph.input(i);
const auto& rhs = ref_graph.input(i);
const auto res = compare_value_info(lhs, rhs, "input");
if (!res.is_ok)
{
return res;
}
}
return ComparisonResult::pass();
}
comparison = compare_outputs(graph, ref_graph);
if (!comparison.is_ok) {
return comparison;
}
ComparisonResult compare_outputs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.output_size() != ref_graph.output_size())
{
return ComparisonResult::fail("The number of outputs in compared models doesn't match" +
std::to_string(graph.output_size()) + " vs " +
std::to_string(ref_graph.output_size()));
}
else
{
for (int i = 0; i < graph.output_size(); ++i)
{
const auto& lhs = graph.output(i);
const auto& rhs = ref_graph.output(i);
const auto res = compare_value_info(lhs, rhs, "output");
if (!res.is_ok)
{
return res;
}
}
return ComparisonResult::pass();
}
comparison = compare_initializers(graph, ref_graph);
if (!comparison.is_ok) {
return comparison;
}
ComparisonResult compare_initializers(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
if (graph.initializer_size() != ref_graph.initializer_size())
{
return ComparisonResult::fail(
"The number of initializers in compared models doesn't match" +
std::to_string(graph.initializer_size()) + " vs " +
std::to_string(ref_graph.initializer_size()));
}
else
{
for (int i = 0; i < graph.initializer_size(); ++i)
{
const auto& lhs = graph.initializer(i);
const auto& rhs = ref_graph.initializer(i);
if (lhs.name() != rhs.name())
{
return ComparisonResult::fail("Initializer names in the graph don't match: " +
lhs.name() + " vs " + rhs.name());
}
else if (lhs.data_type() != rhs.data_type())
{
return ComparisonResult::fail(
"Initializer data types in the graph don't match: " +
std::to_string(lhs.data_type()) + " vs " + std::to_string(rhs.data_type()));
}
else if (lhs.dims_size() != rhs.dims_size())
{
return ComparisonResult::fail("Initializer ranks in the graph don't match: " +
std::to_string(lhs.dims_size()) + " vs " +
std::to_string(rhs.dims_size()));
}
else
{
for (int j = 0; j < lhs.dims_size(); ++j)
{
if (lhs.dims(j) != rhs.dims(j))
{
return ComparisonResult::fail(
"Shape dimensions don't match for initializer " + lhs.name() +
" at index: " + std::to_string(j) + ". " +
std::to_string(lhs.dims(j)) + " vs " + std::to_string(rhs.dims(j)));
}
}
}
}
return ComparisonResult::pass();
}
}
ComparisonResult compare_onnx_graphs(const ONNX_NAMESPACE::GraphProto& graph,
const ONNX_NAMESPACE::GraphProto& ref_graph)
{
ComparisonResult comparison = compare_inputs(graph, ref_graph);
if (!comparison.is_ok)
{
return comparison;
}
comparison = compare_outputs(graph, ref_graph);
if (!comparison.is_ok)
{
return comparison;
}
comparison = compare_initializers(graph, ref_graph);
if (!comparison.is_ok)
{
return comparison;
}
return compare_nodes(graph, ref_graph);
}
} // namespace
namespace ngraph
{
namespace test
{
ComparisonResult compare_onnx_models(const std::string& model,
const std::string& reference_model_path)
{
std::stringstream model_stream{model};
const auto model_proto = onnx_common::parse_from_istream(model_stream);
const auto ref_model = onnx_common::parse_from_file(reference_model_path);
return compare_onnx_graphs(model_proto.graph(), ref_model.graph());
}
} // namespace test
} // namespace ngraph
return compare_nodes(graph, ref_graph);
}
} // namespace
namespace ngraph {
namespace test {
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path) {
std::stringstream model_stream{model};
const auto model_proto = onnx_common::parse_from_istream(model_stream);
const auto ref_model = onnx_common::parse_from_file(reference_model_path);
return compare_onnx_graphs(model_proto.graph(), ref_model.graph());
}
} // namespace test
} // namespace ngraph

View File

@@ -6,35 +6,28 @@
#include <string>
namespace ngraph
{
namespace test
{
struct ComparisonResult
{
ComparisonResult() = default;
ComparisonResult(std::string error)
: is_ok{false}
, error_message{std::move(error)}
{
}
ComparisonResult(ComparisonResult&&) = default;
ComparisonResult(const ComparisonResult&) = default;
ComparisonResult& operator=(ComparisonResult&&) = default;
ComparisonResult& operator=(const ComparisonResult&) = default;
namespace ngraph {
namespace test {
struct ComparisonResult {
ComparisonResult() = default;
ComparisonResult(std::string error) : is_ok{false}, error_message{std::move(error)} {}
ComparisonResult(ComparisonResult&&) = default;
ComparisonResult(const ComparisonResult&) = default;
ComparisonResult& operator=(ComparisonResult&&) = default;
ComparisonResult& operator=(const ComparisonResult&) = default;
bool is_ok = true;
std::string error_message;
bool is_ok = true;
std::string error_message;
static ComparisonResult pass() { return {}; }
static ComparisonResult fail(std::string error)
{
return ComparisonResult{std::move(error)};
}
};
static ComparisonResult pass() {
return {};
}
static ComparisonResult fail(std::string error) {
return ComparisonResult{std::move(error)};
}
};
ComparisonResult compare_onnx_models(const std::string& model,
const std::string& reference_model_path);
ComparisonResult compare_onnx_models(const std::string& model, const std::string& reference_model_path);
} // namespace test
} // namespace ngraph
} // namespace test
} // namespace ngraph

View File

@@ -4,26 +4,24 @@
#include "ngraph/provenance.hpp"
namespace ngraph
{
namespace test
{
/// \brief Enable provenance for the duration of a unit test.
///
/// During creation this object activates provenance support, when it's destroyed
/// it returns the provenance support to previous state.
class ProvenanceEnabler
{
public:
ProvenanceEnabler()
{
saved_enable_state = get_provenance_enabled();
set_provenance_enabled(true);
}
~ProvenanceEnabler() { set_provenance_enabled(saved_enable_state); }
private:
bool saved_enable_state;
};
namespace ngraph {
namespace test {
/// \brief Enable provenance for the duration of a unit test.
///
/// During creation this object activates provenance support, when it's destroyed
/// it returns the provenance support to previous state.
class ProvenanceEnabler {
public:
ProvenanceEnabler() {
saved_enable_state = get_provenance_enabled();
set_provenance_enabled(true);
}
}
~ProvenanceEnabler() {
set_provenance_enabled(saved_enable_state);
}
private:
bool saved_enable_state;
};
} // namespace test
} // namespace ngraph

View File

@@ -10,47 +10,38 @@
#include "ngraph/type/element_type.hpp"
#include "test_tools.hpp"
namespace ngraph
{
namespace test
{
/// \brief A predictable pseudo-random number generator
/// The seed is initialized so that we get repeatable pseudo-random numbers for tests
template <typename T>
class Uniform
{
public:
Uniform(T min, T max, T seed = 0)
: m_engine(seed)
, m_distribution(min, max)
, m_r(std::bind(m_distribution, m_engine))
{
}
namespace ngraph {
namespace test {
/// \brief A predictable pseudo-random number generator
/// The seed is initialized so that we get repeatable pseudo-random numbers for tests
template <typename T>
class Uniform {
public:
Uniform(T min, T max, T seed = 0)
: m_engine(seed),
m_distribution(min, max),
m_r(std::bind(m_distribution, m_engine)) {}
/// \brief Randomly initialize a tensor
/// \param ptv The tensor to initialize
const std::shared_ptr<runtime::Tensor>
initialize(const std::shared_ptr<runtime::Tensor>& ptv)
{
std::vector<T> vec = read_vector<T>(ptv);
initialize(vec);
write_vector(ptv, vec);
return ptv;
}
/// \brief Randomly initialize a vector
/// \param vec The tensor to initialize
void initialize(std::vector<T>& vec)
{
for (T& elt : vec)
{
elt = m_r();
}
}
protected:
std::default_random_engine m_engine;
std::uniform_real_distribution<T> m_distribution;
std::function<T()> m_r;
};
/// \brief Randomly initialize a tensor
/// \param ptv The tensor to initialize
const std::shared_ptr<runtime::Tensor> initialize(const std::shared_ptr<runtime::Tensor>& ptv) {
std::vector<T> vec = read_vector<T>(ptv);
initialize(vec);
write_vector(ptv, vec);
return ptv;
}
}
/// \brief Randomly initialize a vector
/// \param vec The tensor to initialize
void initialize(std::vector<T>& vec) {
for (T& elt : vec) {
elt = m_r();
}
}
protected:
std::default_random_engine m_engine;
std::uniform_real_distribution<T> m_distribution;
std::function<T()> m_r;
};
} // namespace test
} // namespace ngraph

View File

@@ -2,20 +2,16 @@
// SPDX-License-Identifier: Apache-2.0
//
#include <ie_core.hpp>
#include "test_case.hpp"
namespace ngraph
{
namespace test
{
std::shared_ptr<Function> function_from_ir(const std::string& xml_path,
const std::string& bin_path)
{
InferenceEngine::Core c;
auto network = c.ReadNetwork(xml_path, bin_path);
return network.getFunction();
}
}
#include <ie_core.hpp>
namespace ngraph {
namespace test {
std::shared_ptr<Function> function_from_ir(const std::string& xml_path, const std::string& bin_path) {
InferenceEngine::Core c;
auto network = c.ReadNetwork(xml_path, bin_path);
return network.getFunction();
}
} // namespace test
} // namespace ngraph

View File

@@ -13,194 +13,165 @@
#include "test_tools.hpp"
#include "util/engine/engine_factory.hpp"
namespace ngraph
{
namespace test
{
std::shared_ptr<Function> function_from_ir(const std::string& xml_path,
const std::string& bin_path = {});
namespace ngraph {
namespace test {
std::shared_ptr<Function> function_from_ir(const std::string& xml_path, const std::string& bin_path = {});
template <typename Engine, TestCaseType tct = TestCaseType::STATIC>
class TestCase
{
public:
TestCase(const std::shared_ptr<Function>& function)
: m_engine{create_engine<Engine>(function, tct)}
, m_function{function}
{
}
template <typename Engine, TestCaseType tct = TestCaseType::STATIC>
class TestCase {
public:
TestCase(const std::shared_ptr<Function>& function)
: m_engine{create_engine<Engine>(function, tct)},
m_function{function} {}
template <typename T>
void add_input(const Shape& shape, const std::vector<T>& values)
{
const auto params = m_function->get_parameters();
NGRAPH_CHECK(m_input_index < params.size(),
"All function parameters already have inputs.");
template <typename T>
void add_input(const Shape& shape, const std::vector<T>& values) {
const auto params = m_function->get_parameters();
NGRAPH_CHECK(m_input_index < params.size(), "All function parameters already have inputs.");
const auto& input_pshape = params.at(m_input_index)->get_partial_shape();
NGRAPH_CHECK(input_pshape.compatible(shape),
"Provided input shape ",
shape,
" is not compatible with nGraph function's expected input shape ",
input_pshape,
" for input ",
m_input_index);
const auto& input_pshape = params.at(m_input_index)->get_partial_shape();
NGRAPH_CHECK(input_pshape.compatible(shape),
"Provided input shape ",
shape,
" is not compatible with nGraph function's expected input shape ",
input_pshape,
" for input ",
m_input_index);
m_engine.template add_input<T>(shape, values);
m_engine.template add_input<T>(shape, values);
++m_input_index;
}
template <typename T>
void add_input(const std::vector<T>& values)
{
const auto& input_pshape =
m_function->get_parameters().at(m_input_index)->get_partial_shape();
NGRAPH_CHECK(input_pshape.is_static(),
"Input number ",
m_input_index,
" in the tested graph has dynamic shape. You need to provide ",
"shape information when setting values for this input.");
add_input<T>(input_pshape.to_shape(), values);
}
template <typename T>
void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values)
{
for (const auto& value : vector_of_values)
{
add_input<T>(value);
}
}
template <typename T>
void add_input_from_file(const Shape& shape,
const std::string& basepath,
const std::string& filename)
{
NGRAPH_SUPPRESS_DEPRECATED_START
const auto filepath = ngraph::file_util::path_join(basepath, filename);
add_input_from_file<T>(shape, filepath);
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void add_input_from_file(const std::string& basepath, const std::string& filename)
{
NGRAPH_SUPPRESS_DEPRECATED_START
const auto filepath = ngraph::file_util::path_join(basepath, filename);
add_input_from_file<T>(filepath);
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void add_input_from_file(const Shape& shape, const std::string& filepath)
{
const auto value = read_binary_file<T>(filepath);
add_input<T>(shape, value);
}
template <typename T>
void add_input_from_file(const std::string& filepath)
{
const auto value = read_binary_file<T>(filepath);
add_input<T>(value);
}
template <typename T>
void add_expected_output(const Shape& expected_shape, const std::vector<T>& values)
{
const auto results = m_function->get_results();
NGRAPH_CHECK(m_output_index < results.size(),
"All function results already have expected outputs.");
const auto& output_pshape = results.at(m_output_index)->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.compatible(expected_shape),
"Provided expected output shape ",
expected_shape,
" is not compatible with nGraph function's output shape ",
output_pshape,
" for output ",
m_output_index);
m_engine.template add_expected_output<T>(expected_shape, values);
++m_output_index;
}
template <typename T>
void add_expected_output(const std::vector<T>& values)
{
const auto results = m_function->get_results();
NGRAPH_CHECK(m_output_index < results.size(),
"All function results already have expected outputs.");
const auto shape = results.at(m_output_index)->get_shape();
add_expected_output<T>(shape, values);
}
template <typename T>
void add_expected_output_from_file(const ngraph::Shape& expected_shape,
const std::string& basepath,
const std::string& filename)
{
NGRAPH_SUPPRESS_DEPRECATED_START
const auto filepath = ngraph::file_util::path_join(basepath, filename);
add_expected_output_from_file<T>(expected_shape, filepath);
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void add_expected_output_from_file(const ngraph::Shape& expected_shape,
const std::string& filepath)
{
const auto values = read_binary_file<T>(filepath);
add_expected_output<T>(expected_shape, values);
}
void run(const size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS)
{
m_engine.infer();
const auto res = m_engine.compare_results(tolerance_bits);
if (res != testing::AssertionSuccess())
{
std::cout << res.message() << std::endl;
}
m_input_index = 0;
m_output_index = 0;
m_engine.reset();
EXPECT_TRUE(res);
}
void run_with_tolerance_as_fp(const float tolerance = 1.0e-5f)
{
m_engine.infer();
const auto res = m_engine.compare_results_with_tolerance_as_fp(tolerance);
if (res != testing::AssertionSuccess())
{
std::cout << res.message() << std::endl;
}
m_input_index = 0;
m_output_index = 0;
m_engine.reset();
EXPECT_TRUE(res);
}
private:
Engine m_engine;
std::shared_ptr<Function> m_function;
size_t m_input_index = 0;
size_t m_output_index = 0;
};
++m_input_index;
}
}
template <typename T>
void add_input(const std::vector<T>& values) {
const auto& input_pshape = m_function->get_parameters().at(m_input_index)->get_partial_shape();
NGRAPH_CHECK(input_pshape.is_static(),
"Input number ",
m_input_index,
" in the tested graph has dynamic shape. You need to provide ",
"shape information when setting values for this input.");
add_input<T>(input_pshape.to_shape(), values);
}
template <typename T>
void add_multiple_inputs(const std::vector<std::vector<T>>& vector_of_values) {
for (const auto& value : vector_of_values) {
add_input<T>(value);
}
}
template <typename T>
void add_input_from_file(const Shape& shape, const std::string& basepath, const std::string& filename) {
NGRAPH_SUPPRESS_DEPRECATED_START
const auto filepath = ngraph::file_util::path_join(basepath, filename);
add_input_from_file<T>(shape, filepath);
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void add_input_from_file(const std::string& basepath, const std::string& filename) {
NGRAPH_SUPPRESS_DEPRECATED_START
const auto filepath = ngraph::file_util::path_join(basepath, filename);
add_input_from_file<T>(filepath);
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void add_input_from_file(const Shape& shape, const std::string& filepath) {
const auto value = read_binary_file<T>(filepath);
add_input<T>(shape, value);
}
template <typename T>
void add_input_from_file(const std::string& filepath) {
const auto value = read_binary_file<T>(filepath);
add_input<T>(value);
}
template <typename T>
void add_expected_output(const Shape& expected_shape, const std::vector<T>& values) {
const auto results = m_function->get_results();
NGRAPH_CHECK(m_output_index < results.size(), "All function results already have expected outputs.");
const auto& output_pshape = results.at(m_output_index)->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.compatible(expected_shape),
"Provided expected output shape ",
expected_shape,
" is not compatible with nGraph function's output shape ",
output_pshape,
" for output ",
m_output_index);
m_engine.template add_expected_output<T>(expected_shape, values);
++m_output_index;
}
template <typename T>
void add_expected_output(const std::vector<T>& values) {
const auto results = m_function->get_results();
NGRAPH_CHECK(m_output_index < results.size(), "All function results already have expected outputs.");
const auto shape = results.at(m_output_index)->get_shape();
add_expected_output<T>(shape, values);
}
template <typename T>
void add_expected_output_from_file(const ngraph::Shape& expected_shape,
const std::string& basepath,
const std::string& filename) {
NGRAPH_SUPPRESS_DEPRECATED_START
const auto filepath = ngraph::file_util::path_join(basepath, filename);
add_expected_output_from_file<T>(expected_shape, filepath);
NGRAPH_SUPPRESS_DEPRECATED_END
}
template <typename T>
void add_expected_output_from_file(const ngraph::Shape& expected_shape, const std::string& filepath) {
const auto values = read_binary_file<T>(filepath);
add_expected_output<T>(expected_shape, values);
}
void run(const size_t tolerance_bits = DEFAULT_FLOAT_TOLERANCE_BITS) {
m_engine.infer();
const auto res = m_engine.compare_results(tolerance_bits);
if (res != testing::AssertionSuccess()) {
std::cout << res.message() << std::endl;
}
m_input_index = 0;
m_output_index = 0;
m_engine.reset();
EXPECT_TRUE(res);
}
void run_with_tolerance_as_fp(const float tolerance = 1.0e-5f) {
m_engine.infer();
const auto res = m_engine.compare_results_with_tolerance_as_fp(tolerance);
if (res != testing::AssertionSuccess()) {
std::cout << res.message() << std::endl;
}
m_input_index = 0;
m_output_index = 0;
m_engine.reset();
EXPECT_TRUE(res);
}
private:
Engine m_engine;
std::shared_ptr<Function> m_function;
size_t m_input_index = 0;
size_t m_output_index = 0;
};
} // namespace test
} // namespace ngraph

View File

@@ -2,62 +2,49 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "util/test_control.hpp"
#include <fstream>
#include <unordered_map>
#include <unordered_set>
#include "ngraph/log.hpp"
#include "ngraph/util.hpp"
#include "util/test_control.hpp"
using namespace std;
using namespace ngraph;
static unordered_set<string>& get_blacklist(const string& backend)
{
static unordered_set<string>& get_blacklist(const string& backend) {
static unordered_map<string, unordered_set<string>> s_blacklists;
return s_blacklists[backend];
}
string ngraph::prepend_disabled(const string& backend_name,
const string& test_name,
const string& manifest)
{
string ngraph::prepend_disabled(const string& backend_name, const string& test_name, const string& manifest) {
string rc = test_name;
unordered_set<string>& blacklist = get_blacklist(backend_name);
if (blacklist.empty() && !manifest.empty())
{
if (blacklist.empty() && !manifest.empty()) {
ifstream f(manifest);
string line;
while (getline(f, line))
{
while (getline(f, line)) {
size_t pound_pos = line.find('#');
line = (pound_pos > line.size()) ? line : line.substr(0, pound_pos);
line = trim(line);
if (line.size() > 1)
{
if (line.size() > 1) {
blacklist.insert(line);
}
}
}
string compound_test_name = backend_name + "." + test_name;
if (blacklist.find(test_name) != blacklist.end() ||
blacklist.find(compound_test_name) != blacklist.end())
{
if (blacklist.find(test_name) != blacklist.end() || blacklist.find(compound_test_name) != blacklist.end()) {
rc = "DISABLED_" + test_name;
}
return rc;
}
string ngraph::combine_test_backend_and_case(const string& backend_name,
const string& test_casename)
{
if (backend_name == test_casename)
{
string ngraph::combine_test_backend_and_case(const string& backend_name, const string& test_casename) {
if (backend_name == test_casename) {
return backend_name;
}
else
{
} else {
return backend_name + "/" + test_casename;
}
}

View File

@@ -8,55 +8,44 @@
// Copied from gtest
namespace ngraph
{
std::string prepend_disabled(const std::string& backend_name,
const std::string& test_name,
const std::string& manifest);
namespace ngraph {
std::string prepend_disabled(const std::string& backend_name,
const std::string& test_name,
const std::string& manifest);
std::string combine_test_backend_and_case(const std::string& backend_name,
const std::string& test_casename);
} // namespace ngraph
std::string combine_test_backend_and_case(const std::string& backend_name, const std::string& test_casename);
} // namespace ngraph
#define NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name) \
#define NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name) \
backend_name##_##test_case_name##_##test_name##_Test
#define NGRAPH_GTEST_TEST_(backend_name, test_case_name, test_name, parent_class, parent_id) \
class NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name) \
: public parent_class \
{ \
public: \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)() {} \
\
private: \
void TestBody() override; \
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, \
test_case_name, \
test_name)); \
}; \
\
::testing::TestInfo* const NGRAPH_GTEST_TEST_CLASS_NAME_( \
backend_name, test_case_name, test_name)::test_info_ = \
::testing::internal::MakeAndRegisterTestInfo( \
::ngraph::combine_test_backend_and_case(#backend_name, #test_case_name).c_str(), \
::ngraph::prepend_disabled(#backend_name, #test_name, s_manifest).c_str(), \
nullptr, \
nullptr, \
::testing::internal::CodeLocation(__FILE__, __LINE__), \
(parent_id), \
parent_class::SetUpTestCase, \
parent_class::TearDownTestCase, \
new ::testing::internal::TestFactoryImpl<NGRAPH_GTEST_TEST_CLASS_NAME_( \
backend_name, test_case_name, test_name)>); \
#define NGRAPH_GTEST_TEST_(backend_name, test_case_name, test_name, parent_class, parent_id) \
class NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name) : public parent_class { \
public: \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)() {} \
\
private: \
void TestBody() override; \
static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)); \
}; \
\
::testing::TestInfo* const NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)::test_info_ = \
::testing::internal::MakeAndRegisterTestInfo( \
::ngraph::combine_test_backend_and_case(#backend_name, #test_case_name).c_str(), \
::ngraph::prepend_disabled(#backend_name, #test_name, s_manifest).c_str(), \
nullptr, \
nullptr, \
::testing::internal::CodeLocation(__FILE__, __LINE__), \
(parent_id), \
parent_class::SetUpTestCase, \
parent_class::TearDownTestCase, \
new ::testing::internal::TestFactoryImpl< \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)>); \
void NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)::TestBody()
#define NGRAPH_TEST(test_case_name, test_name) \
NGRAPH_GTEST_TEST_(test_case_name, \
test_case_name, \
test_name, \
::testing::Test, \
::testing::internal::GetTestTypeId())
#define NGRAPH_TEST(test_case_name, test_name) \
NGRAPH_GTEST_TEST_(test_case_name, test_case_name, test_name, ::testing::Test, ::testing::internal::GetTestTypeId())
// NGRAPH_TEST_F facilitates the use of the same configuration parameters for multiple
// unit tests similar to the original TEST_F, but with the introduction of a new 0th
@@ -76,11 +65,11 @@ namespace ngraph
// should be:
// --gtest_filter=BACKENDNAME*.*
// (rather than the BACKENDNAME.* that worked before the use of NGRAPH_TEST_F)
#define NGRAPH_TEST_F(backend_name, test_fixture, test_name) \
NGRAPH_GTEST_TEST_(backend_name, \
test_fixture, \
test_name, \
test_fixture, \
#define NGRAPH_TEST_F(backend_name, test_fixture, test_name) \
NGRAPH_GTEST_TEST_(backend_name, \
test_fixture, \
test_name, \
test_fixture, \
::testing::internal::GetTypeId<test_fixture>())
// NGRAPH_TEST_P combined with NGRAPH_INSTANTIATE_TEST_SUITE_P facilate the generation
@@ -91,39 +80,30 @@ namespace ngraph
// Start by defining a class derived from ::testing::TestWithParam<T>, which you'll pass
// for the test_case_name parameter.
// Then use NGRAPH_INSTANTIATE_TEST_SUITE_P to define each generation of test cases (see below).
#define NGRAPH_TEST_P(backend_name, test_case_name, test_name) \
class NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name) \
: public test_case_name \
{ \
public: \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)() {} \
void TestBody() override; \
\
private: \
static int AddToRegistry() \
{ \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_case_name>( \
#backend_name "/" #test_case_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestPattern( \
#backend_name "/" #test_case_name, \
::ngraph::prepend_disabled( \
#backend_name "/" #test_case_name, #test_name, s_manifest) \
.c_str(), \
new ::testing::internal::TestMetaFactory<NGRAPH_GTEST_TEST_CLASS_NAME_( \
backend_name, test_case_name, test_name)>()); \
return 0; \
} \
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, \
test_case_name, \
test_name)); \
}; \
int NGRAPH_GTEST_TEST_CLASS_NAME_( \
backend_name, test_case_name, test_name)::gtest_registering_dummy_ = \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)::AddToRegistry(); \
#define NGRAPH_TEST_P(backend_name, test_case_name, test_name) \
class NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name) : public test_case_name { \
public: \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)() {} \
void TestBody() override; \
\
private: \
static int AddToRegistry() { \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_case_name>(#backend_name "/" #test_case_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestPattern( \
#backend_name "/" #test_case_name, \
::ngraph::prepend_disabled(#backend_name "/" #test_case_name, #test_name, s_manifest).c_str(), \
new ::testing::internal::TestMetaFactory< \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)>()); \
return 0; \
} \
static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \
GTEST_DISALLOW_COPY_AND_ASSIGN_(NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)); \
}; \
int NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)::gtest_registering_dummy_ = \
NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)::AddToRegistry(); \
void NGRAPH_GTEST_TEST_CLASS_NAME_(backend_name, test_case_name, test_name)::TestBody()
// Use NGRAPH_INSTANTIATE_TEST_SUITE_P to create a generated set of test case variations.
@@ -159,26 +139,22 @@ namespace ngraph
// the filter to run all the tests for a given backend should be:
// --gtest_filter=BACKENDNAME*.*
// (rather than the BACKENDNAME.* that worked before the use of NGRAPH_TEST_P)
#define NGRAPH_INSTANTIATE_TEST_SUITE_P(backend_name, prefix, test_suite_name, generator) \
static ::testing::internal::ParamGenerator<test_suite_name::ParamType> \
gtest_##prefix##backend_name##test_suite_name##_EvalGenerator_() \
{ \
return generator; \
} \
static ::std::string gtest_##prefix##backend_name##test_suite_name##_EvalGenerateName_( \
const ::testing::TestParamInfo<test_suite_name::ParamType>& info) \
{ \
return ::testing::internal::DefaultParamName<test_suite_name::ParamType>(info); \
} \
static int gtest_##prefix##backend_name##test_suite_name##_dummy_ GTEST_ATTRIBUTE_UNUSED_ = \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_suite_name>( \
#backend_name "/" #test_suite_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestSuiteInstantiation( \
#prefix[0] != '\0' ? #backend_name "/" #prefix : "", \
&gtest_##prefix##backend_name##test_suite_name##_EvalGenerator_, \
&gtest_##prefix##backend_name##test_suite_name##_EvalGenerateName_, \
__FILE__, \
__LINE__)
#define NGRAPH_INSTANTIATE_TEST_SUITE_P(backend_name, prefix, test_suite_name, generator) \
static ::testing::internal::ParamGenerator<test_suite_name::ParamType> \
gtest_##prefix##backend_name##test_suite_name##_EvalGenerator_() { \
return generator; \
} \
static ::std::string gtest_##prefix##backend_name##test_suite_name##_EvalGenerateName_( \
const ::testing::TestParamInfo<test_suite_name::ParamType>& info) { \
return ::testing::internal::DefaultParamName<test_suite_name::ParamType>(info); \
} \
static int gtest_##prefix##backend_name##test_suite_name##_dummy_ GTEST_ATTRIBUTE_UNUSED_ = \
::testing::UnitTest::GetInstance() \
->parameterized_test_registry() \
.GetTestCasePatternHolder<test_suite_name>(#backend_name "/" #test_suite_name, \
::testing::internal::CodeLocation(__FILE__, __LINE__)) \
->AddTestSuiteInstantiation(#prefix[0] != '\0' ? #backend_name "/" #prefix : "", \
&gtest_##prefix##backend_name##test_suite_name##_EvalGenerator_, \
&gtest_##prefix##backend_name##test_suite_name##_EvalGenerateName_, \
__FILE__, \
__LINE__)

View File

@@ -2,11 +2,12 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "test_tools.hpp"
#include <algorithm>
#include "ngraph/ngraph.hpp"
#include "ngraph/util.hpp"
#include "test_tools.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
@@ -15,41 +16,34 @@ using namespace ngraph;
// This function traverses the vector of ops and verifies that each op's dependencies (its inputs)
// is located earlier in the vector. That is enough to be valid
bool validate_list(const vector<shared_ptr<Node>>& nodes)
{
bool validate_list(const vector<shared_ptr<Node>>& nodes) {
bool rc = true;
for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
{
for (auto it = nodes.rbegin(); it != nodes.rend(); it++) {
auto node_tmp = *it;
NodeVector dependencies_tmp;
for (auto& val : node_tmp->input_values())
dependencies_tmp.emplace_back(val.get_node_shared_ptr());
vector<Node*> dependencies;
for (shared_ptr<Node> n : dependencies_tmp)
{
for (shared_ptr<Node> n : dependencies_tmp) {
dependencies.push_back(n.get());
}
auto tmp = it;
for (tmp++; tmp != nodes.rend(); tmp++)
{
for (tmp++; tmp != nodes.rend(); tmp++) {
auto dep_tmp = *tmp;
auto found = find(dependencies.begin(), dependencies.end(), dep_tmp.get());
if (found != dependencies.end())
{
if (found != dependencies.end()) {
dependencies.erase(found);
}
}
if (dependencies.size() > 0)
{
if (dependencies.size() > 0) {
rc = false;
}
}
return rc;
}
shared_ptr<Function> make_test_graph()
{
shared_ptr<Function> make_test_graph() {
auto arg_0 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto arg_1 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
auto arg_2 = make_shared<op::Parameter>(element::f32, Shape{2, 2});
@@ -72,169 +66,113 @@ shared_ptr<Function> make_test_graph()
}
template <>
void copy_data<bool>(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<bool>& data)
{
void copy_data<bool>(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<bool>& data) {
std::vector<char> data_char(data.begin(), data.end());
copy_data(tv, data_char);
}
template <>
void init_int_tv<char>(ngraph::runtime::Tensor* tv,
std::default_random_engine& engine,
char min,
char max)
{
void init_int_tv<char>(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, char min, char max) {
size_t size = tv->get_element_count();
std::uniform_int_distribution<int16_t> dist(static_cast<short>(min), static_cast<short>(max));
std::vector<char> vec(size);
for (char& element : vec)
{
for (char& element : vec) {
element = static_cast<char>(dist(engine));
}
tv->write(vec.data(), vec.size() * sizeof(char));
}
template <>
void init_int_tv<int8_t>(ngraph::runtime::Tensor* tv,
std::default_random_engine& engine,
int8_t min,
int8_t max)
{
void init_int_tv<int8_t>(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, int8_t min, int8_t max) {
size_t size = tv->get_element_count();
std::uniform_int_distribution<int16_t> dist(static_cast<short>(min), static_cast<short>(max));
std::vector<int8_t> vec(size);
for (int8_t& element : vec)
{
for (int8_t& element : vec) {
element = static_cast<int8_t>(dist(engine));
}
tv->write(vec.data(), vec.size() * sizeof(int8_t));
}
template <>
void init_int_tv<uint8_t>(ngraph::runtime::Tensor* tv,
std::default_random_engine& engine,
uint8_t min,
uint8_t max)
{
void init_int_tv<uint8_t>(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, uint8_t min, uint8_t max) {
size_t size = tv->get_element_count();
std::uniform_int_distribution<int16_t> dist(static_cast<short>(min), static_cast<short>(max));
std::vector<uint8_t> vec(size);
for (uint8_t& element : vec)
{
for (uint8_t& element : vec) {
element = static_cast<uint8_t>(dist(engine));
}
tv->write(vec.data(), vec.size() * sizeof(uint8_t));
}
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine)
{
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine) {
element::Type et = tv->get_element_type();
if (et == element::boolean)
{
if (et == element::boolean) {
init_int_tv<char>(tv, engine, 0, 1);
}
else if (et == element::f32)
{
} else if (et == element::f32) {
init_real_tv<float>(tv, engine, numeric_limits<float>::min(), 1.0f);
}
else if (et == element::f64)
{
} else if (et == element::f64) {
init_real_tv<double>(tv, engine, numeric_limits<double>::min(), 1.0);
}
else if (et == element::i8)
{
} else if (et == element::i8) {
init_int_tv<int8_t>(tv, engine, -1, 1);
}
else if (et == element::i16)
{
} else if (et == element::i16) {
init_int_tv<int16_t>(tv, engine, -1, 1);
}
else if (et == element::i32)
{
} else if (et == element::i32) {
init_int_tv<int32_t>(tv, engine, 0, 1);
}
else if (et == element::i64)
{
} else if (et == element::i64) {
init_int_tv<int64_t>(tv, engine, 0, 1);
}
else if (et == element::u8)
{
} else if (et == element::u8) {
init_int_tv<uint8_t>(tv, engine, 0, 1);
}
else if (et == element::u16)
{
} else if (et == element::u16) {
init_int_tv<uint16_t>(tv, engine, 0, 1);
}
else if (et == element::u32)
{
} else if (et == element::u32) {
init_int_tv<uint32_t>(tv, engine, 0, 1);
}
else if (et == element::u64)
{
} else if (et == element::u64) {
init_int_tv<uint64_t>(tv, engine, 0, 1);
}
else
{
} else {
throw runtime_error("unsupported type");
}
}
template <>
string get_results_str(const std::vector<char>& ref_data,
const std::vector<char>& actual_data,
size_t max_results)
{
string get_results_str(const std::vector<char>& ref_data, const std::vector<char>& actual_data, size_t max_results) {
stringstream ss;
size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
ss << "First " << num_results << " results";
for (size_t i = 0; i < num_results; ++i)
{
for (size_t i = 0; i < num_results; ++i) {
ss << std::endl
<< std::setw(4) << i << " ref: " << std::setw(16) << std::left
<< static_cast<int>(ref_data[i]) << " actual: " << std::setw(16) << std::left
<< static_cast<int>(actual_data[i]);
<< std::setw(4) << i << " ref: " << std::setw(16) << std::left << static_cast<int>(ref_data[i])
<< " actual: " << std::setw(16) << std::left << static_cast<int>(actual_data[i]);
}
ss << std::endl;
return ss.str();
}
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f, const NodeVector& required_ops)
{
::testing::AssertionResult test_ordered_ops(shared_ptr<Function> f, const NodeVector& required_ops) {
unordered_set<Node*> seen;
for (auto& node_ptr : f->get_ordered_ops())
{
for (auto& node_ptr : f->get_ordered_ops()) {
Node* node = node_ptr.get();
if (seen.count(node) > 0)
{
if (seen.count(node) > 0) {
return ::testing::AssertionFailure() << "Duplication in ordered ops";
}
size_t arg_count = node->get_input_size();
for (size_t i = 0; i < arg_count; ++i)
{
for (size_t i = 0; i < arg_count; ++i) {
Node* dep = node->get_input_node_ptr(i);
if (seen.count(dep) == 0)
{
return ::testing::AssertionFailure()
<< "Argument " << *dep << " does not occur before op" << *node;
if (seen.count(dep) == 0) {
return ::testing::AssertionFailure() << "Argument " << *dep << " does not occur before op" << *node;
}
}
for (auto& dep_ptr : node->get_control_dependencies())
{
if (seen.count(dep_ptr.get()) == 0)
{
for (auto& dep_ptr : node->get_control_dependencies()) {
if (seen.count(dep_ptr.get()) == 0) {
return ::testing::AssertionFailure()
<< "Control dependency " << *dep_ptr << " does not occur before op" << *node;
}
}
seen.insert(node);
}
for (auto& node_ptr : required_ops)
{
if (seen.count(node_ptr.get()) == 0)
{
return ::testing::AssertionFailure()
<< "Required op " << *node_ptr << "does not occur in ordered ops";
for (auto& node_ptr : required_ops) {
if (seen.count(node_ptr.get()) == 0) {
return ::testing::AssertionFailure() << "Required op " << *node_ptr << "does not occur in ordered ops";
}
}
return ::testing::AssertionSuccess();
@@ -242,9 +180,7 @@ string get_results_str(const std::vector<char>& ref_data,
constexpr NodeTypeInfo ngraph::TestOpMultiOut::type_info;
bool ngraph::TestOpMultiOut::evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const
{
bool ngraph::TestOpMultiOut::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
inputs[0]->read(outputs[0]->get_data_ptr(), inputs[0]->get_size_in_bytes());
inputs[1]->read(outputs[1]->get_data_ptr(), inputs[1]->get_size_in_bytes());
return true;

View File

@@ -16,62 +16,52 @@
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "ngraph/type/element_type_traits.hpp"
#include "ngraph/node.hpp"
#include "runtime/backend.hpp"
namespace ngraph
{
class TestOpMultiOut : public op::Op
{
public:
static constexpr NodeTypeInfo type_info{"TestOpMultiOut", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
TestOpMultiOut() = default;
namespace ngraph {
class TestOpMultiOut : public op::Op {
public:
static constexpr NodeTypeInfo type_info{"TestOpMultiOut", 0};
const NodeTypeInfo& get_type_info() const override {
return type_info;
}
TestOpMultiOut() = default;
TestOpMultiOut(const Output<Node>& output_1, const Output<Node>& output_2)
: Op({output_1, output_2})
{
validate_and_infer_types();
}
void validate_and_infer_types() override
{
set_output_size(2);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
set_output_type(1, get_input_element_type(1), get_input_partial_shape(1));
}
TestOpMultiOut(const Output<Node>& output_1, const Output<Node>& output_2) : Op({output_1, output_2}) {
validate_and_infer_types();
}
void validate_and_infer_types() override {
set_output_size(2);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
set_output_type(1, get_input_element_type(1), get_input_partial_shape(1));
}
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override
{
return std::make_shared<TestOpMultiOut>(new_args.at(0), new_args.at(1));
}
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) const override;
};
} // namespace ngraph
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override {
return std::make_shared<TestOpMultiOut>(new_args.at(0), new_args.at(1));
}
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
};
} // namespace ngraph
bool validate_list(const std::vector<std::shared_ptr<ngraph::Node>>& nodes);
std::shared_ptr<ngraph::Function> make_test_graph();
template <typename T>
void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data)
{
void copy_data(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& data) {
size_t data_size = data.size() * sizeof(T);
if (data_size > 0)
{
if (data_size > 0) {
tv->write(data.data(), data_size);
}
}
template <ngraph::element::Type_t ET>
ngraph::HostTensorPtr
make_host_tensor(const ngraph::Shape& shape,
const std::vector<typename ngraph::element_type_traits<ET>::value_type>& data)
{
ngraph::HostTensorPtr make_host_tensor(const ngraph::Shape& shape,
const std::vector<typename ngraph::element_type_traits<ET>::value_type>& data) {
NGRAPH_CHECK(shape_size(shape) == data.size(), "Incorrect number of initialization elements");
auto host_tensor = std::make_shared<ngraph::HostTensor>(ET, shape);
copy_data(host_tensor, data);
@@ -82,19 +72,15 @@ template <>
void copy_data<bool>(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<bool>& data);
template <typename T>
void write_vector(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& values)
{
void write_vector(std::shared_ptr<ngraph::runtime::Tensor> tv, const std::vector<T>& values) {
tv->write(values.data(), values.size() * sizeof(T));
}
template <typename T>
std::vector<std::shared_ptr<T>> get_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
std::vector<std::shared_ptr<T>> get_ops_of_type(std::shared_ptr<ngraph::Function> f) {
std::vector<std::shared_ptr<T>> ops;
for (auto op : f->get_ops())
{
if (auto cop = ngraph::as_type_ptr<T>(op))
{
for (auto op : f->get_ops()) {
if (auto cop = ngraph::as_type_ptr<T>(op)) {
ops.push_back(cop);
}
}
@@ -103,13 +89,10 @@ std::vector<std::shared_ptr<T>> get_ops_of_type(std::shared_ptr<ngraph::Function
}
template <typename T>
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
{
size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f) {
size_t count = 0;
for (auto op : f->get_ops())
{
if (ngraph::is_type<T>(op))
{
for (auto op : f->get_ops()) {
if (ngraph::is_type<T>(op)) {
count++;
}
}
@@ -118,26 +101,22 @@ size_t count_ops_of_type(std::shared_ptr<ngraph::Function> f)
}
template <typename T>
void init_int_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, T min, T max)
{
void init_int_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, T min, T max) {
size_t size = tv->get_element_count();
std::uniform_int_distribution<T> dist(min, max);
std::vector<T> vec(size);
for (T& element : vec)
{
for (T& element : vec) {
element = dist(engine);
}
tv->write(vec.data(), vec.size() * sizeof(T));
}
template <typename T>
void init_real_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, T min, T max)
{
void init_real_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engine, T min, T max) {
size_t size = tv->get_element_count();
std::uniform_real_distribution<T> dist(min, max);
std::vector<T> vec(size);
for (T& element : vec)
{
for (T& element : vec) {
element = dist(engine);
}
tv->write(vec.data(), vec.size() * sizeof(T));
@@ -146,27 +125,22 @@ void init_real_tv(ngraph::runtime::Tensor* tv, std::default_random_engine& engin
void random_init(ngraph::runtime::Tensor* tv, std::default_random_engine& engine);
template <typename T1, typename T2>
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
prepare_and_run(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T1>> t1args,
std::vector<std::vector<T2>> t2args,
const std::string& backend_id)
{
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> prepare_and_run(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T1>> t1args,
std::vector<std::vector<T2>> t2args,
const std::string& backend_id) {
auto backend = ngraph::runtime::Backend::create(backend_id);
auto parms = function->get_parameters();
if (parms.size() != t1args.size() + t2args.size())
{
if (parms.size() != t1args.size() + t2args.size()) {
throw ngraph::ngraph_error("number of parameters and arguments don't match");
}
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> arg_tensors(t1args.size() +
t2args.size());
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> arg_tensors(t1args.size() + t2args.size());
size_t total_arg_count = 0;
for (size_t i = 0; i < t1args.size(); i++)
{
for (size_t i = 0; i < t1args.size(); i++) {
auto t = backend->create_tensor(parms.at(total_arg_count)->get_element_type(),
parms.at(total_arg_count)->get_shape());
auto x = t1args.at(i);
@@ -175,8 +149,7 @@ std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
total_arg_count++;
}
for (size_t i = 0; i < t2args.size(); i++)
{
for (size_t i = 0; i < t2args.size(); i++) {
auto t = backend->create_tensor(parms.at(total_arg_count)->get_element_type(),
parms.at(total_arg_count)->get_shape());
copy_data(t, t2args.at(i));
@@ -187,10 +160,8 @@ std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
auto results = function->get_results();
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors(results.size());
for (size_t i = 0; i < results.size(); i++)
{
result_tensors.at(i) =
backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape());
for (size_t i = 0; i < results.size(); i++) {
result_tensors.at(i) = backend->create_tensor(results.at(i)->get_element_type(), results.at(i)->get_shape());
}
auto handle = backend->compile(function);
@@ -200,11 +171,9 @@ std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
}
template <typename T>
std::vector<std::shared_ptr<ngraph::runtime::Tensor>>
prepare_and_run(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id)
{
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> prepare_and_run(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<T>> args,
const std::string& backend_id) {
std::vector<std::vector<T>> emptyargs;
return prepare_and_run<T, T>(function, args, emptyargs, backend_id);
}
@@ -213,14 +182,12 @@ template <typename TIN1, typename TIN2, typename TOUT>
std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<TIN1>> t1args,
std::vector<std::vector<TIN2>> t2args,
const std::string& backend_id)
{
const std::string& backend_id) {
std::vector<std::shared_ptr<ngraph::runtime::Tensor>> result_tensors =
prepare_and_run(function, t1args, t2args, backend_id);
std::vector<std::vector<TOUT>> result_vectors;
for (auto rt : result_tensors)
{
for (auto rt : result_tensors) {
result_vectors.push_back(read_vector<TOUT>(rt));
}
return result_vectors;
@@ -229,8 +196,7 @@ std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>&
template <typename TIN, typename TOUT = TIN>
std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>& function,
std::vector<std::vector<TIN>> args,
const std::string& backend_id)
{
const std::string& backend_id) {
std::vector<std::vector<TIN>> emptyargs;
return execute<TIN, TIN, TOUT>(function, args, emptyargs, backend_id);
}
@@ -238,13 +204,11 @@ std::vector<std::vector<TOUT>> execute(const std::shared_ptr<ngraph::Function>&
template <typename T>
std::string get_results_str(const std::vector<T>& ref_data,
const std::vector<T>& actual_data,
size_t max_results = 16)
{
size_t max_results = 16) {
std::stringstream ss;
size_t num_results = std::min(static_cast<size_t>(max_results), ref_data.size());
ss << "First " << num_results << " results";
for (size_t i = 0; i < num_results; ++i)
{
for (size_t i = 0; i < num_results; ++i) {
ss << std::endl
// use unary + operator to force integral values to be displayed as numbers
<< std::setw(4) << i << " ref: " << std::setw(16) << std::left << +ref_data[i]
@@ -269,35 +233,29 @@ std::string get_results_str(const std::vector<char>& ref_data,
/// \return Return vector of data read from input binary file.
///
template <typename T>
std::vector<T> read_binary_file(const std::string& path)
{
std::vector<T> read_binary_file(const std::string& path) {
std::vector<T> file_content;
std::ifstream inputs_fs{path, std::ios::in | std::ios::binary};
if (!inputs_fs)
{
if (!inputs_fs) {
throw std::runtime_error("Failed to open the file: " + path);
}
inputs_fs.seekg(0, std::ios::end);
auto size = inputs_fs.tellg();
inputs_fs.seekg(0, std::ios::beg);
if (size % sizeof(T) != 0)
{
throw std::runtime_error(
"Error reading binary file content: Input file size (in bytes) "
"is not a multiple of requested data type size.");
if (size % sizeof(T) != 0) {
throw std::runtime_error("Error reading binary file content: Input file size (in bytes) "
"is not a multiple of requested data type size.");
}
file_content.resize(size / sizeof(T));
inputs_fs.read(reinterpret_cast<char*>(file_content.data()), size);
return file_content;
}
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f,
const ngraph::NodeVector& required_ops);
testing::AssertionResult test_ordered_ops(std::shared_ptr<ngraph::Function> f, const ngraph::NodeVector& required_ops);
template <ngraph::element::Type_t ET>
ngraph::HostTensorPtr make_host_tensor(const ngraph::Shape& shape)
{
ngraph::HostTensorPtr make_host_tensor(const ngraph::Shape& shape) {
auto host_tensor = std::make_shared<ngraph::HostTensor>(ET, shape);
static std::default_random_engine engine(2112);
random_init(host_tensor.get(), engine);

View File

@@ -6,14 +6,11 @@
#include "gtest/gtest.h"
#define EXPECT_HAS_SUBSTRING(haystack, needle) \
EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
#define EXPECT_HAS_SUBSTRING(haystack, needle) EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
struct PrintToDummyParamName
{
struct PrintToDummyParamName {
template <class ParamType>
std::string operator()(const ::testing::TestParamInfo<ParamType>& info) const
{
std::string operator()(const ::testing::TestParamInfo<ParamType>& info) const {
return "dummy" + std::to_string(info.index);
}
};

View File

@@ -14,398 +14,347 @@
#include "ngraph/ops.hpp"
#include "ngraph/runtime/host_tensor.hpp"
namespace ngraph
{
namespace test
{
class ValueHolder
{
template <typename T>
T& invalid()
{
NGRAPH_CHECK(false, "Invalid type access");
}
namespace ngraph {
namespace test {
class ValueHolder {
template <typename T>
T& invalid() {
NGRAPH_CHECK(false, "Invalid type access");
}
public:
virtual ~ValueHolder() {}
virtual operator bool&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator float&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator double&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::string&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int8_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int16_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int32_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator int64_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint8_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint16_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint32_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator uint64_t&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<std::string>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<float>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<double>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int8_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int16_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int32_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<int64_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<uint8_t>&() { NGRAPH_CHECK(false, "Invalid type access"); }
virtual operator std::vector<uint16_t>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint32_t>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint64_t>&()
{
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator HostTensorPtr&() { NGRAPH_CHECK(false, "Invalid type access"); }
uint64_t get_index() { return m_index; }
public:
virtual ~ValueHolder() {}
virtual operator bool&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator float&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator double&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::string&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator int8_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator int16_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator int32_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator int64_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator uint8_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator uint16_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator uint32_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator uint64_t&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<std::string>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<float>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<double>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<int8_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<int16_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<int32_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<int64_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint8_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint16_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint32_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator std::vector<uint64_t>&() {
NGRAPH_CHECK(false, "Invalid type access");
}
virtual operator HostTensorPtr&() {
NGRAPH_CHECK(false, "Invalid type access");
}
uint64_t get_index() {
return m_index;
}
protected:
uint64_t m_index{0};
};
protected:
uint64_t m_index{0};
};
template <typename T>
class ValueHolderImp : public ValueHolder
{
public:
ValueHolderImp(const T& value, uint64_t index)
: m_value(value)
{
m_index = index;
}
operator T&() override { return m_value; }
template <typename T>
class ValueHolderImp : public ValueHolder {
public:
ValueHolderImp(const T& value, uint64_t index) : m_value(value) {
m_index = index;
}
operator T&() override {
return m_value;
}
protected:
T m_value;
};
protected:
T m_value;
};
class ValueMap
{
using map_type = std::unordered_map<std::string, std::shared_ptr<ValueHolder>>;
class ValueMap {
using map_type = std::unordered_map<std::string, std::shared_ptr<ValueHolder>>;
public:
/// \brief Set to print serialization information
void set_print(bool value) { m_print = value; }
template <typename T>
void insert(const std::string& name, const T& value)
{
std::pair<map_type::iterator, bool> result = m_values.insert(map_type::value_type(
name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
}
template <typename T>
void insert_scalar(const std::string& name, const T& value)
{
std::pair<map_type::iterator, bool> result = m_values.insert(map_type::value_type(
name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
if (m_print)
{
std::cerr << "SER: " << name << " = " << value << std::endl;
}
}
template <typename T>
void insert_vector(const std::string& name, const T& value)
{
std::pair<map_type::iterator, bool> result = m_values.insert(map_type::value_type(
name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
if (m_print)
{
std::cerr << "SER: " << name << " = [";
std::string comma = "";
for (auto val : value)
{
std::cerr << comma << val;
comma = ", ";
}
std::cerr << "]" << std::endl;
}
public:
/// \brief Set to print serialization information
void set_print(bool value) {
m_print = value;
}
template <typename T>
void insert(const std::string& name, const T& value) {
std::pair<map_type::iterator, bool> result =
m_values.insert(map_type::value_type(name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
}
template <typename T>
void insert_scalar(const std::string& name, const T& value) {
std::pair<map_type::iterator, bool> result =
m_values.insert(map_type::value_type(name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
if (m_print) {
std::cerr << "SER: " << name << " = " << value << std::endl;
}
}
template <typename T>
void insert_vector(const std::string& name, const T& value) {
std::pair<map_type::iterator, bool> result =
m_values.insert(map_type::value_type(name, std::make_shared<ValueHolderImp<T>>(value, m_write_count++)));
NGRAPH_CHECK(result.second, name, " is already in use");
if (m_print) {
std::cerr << "SER: " << name << " = [";
std::string comma = "";
for (auto val : value) {
std::cerr << comma << val;
comma = ", ";
}
std::cerr << "]" << std::endl;
}
}
std::size_t get_value_map_size() const
{
return m_values.size();
}
std::size_t get_value_map_size() const {
return m_values.size();
}
template <typename T>
T& get(const std::string& name)
{
auto& value_holder = *m_values.at(name);
NGRAPH_CHECK(m_read_count++ == value_holder.get_index());
return static_cast<T&>(*m_values.at(name));
}
template <typename T>
T& get(const std::string& name) {
auto& value_holder = *m_values.at(name);
NGRAPH_CHECK(m_read_count++ == value_holder.get_index());
return static_cast<T&>(*m_values.at(name));
}
protected:
map_type m_values;
uint64_t m_write_count{0};
uint64_t m_read_count{0};
bool m_print{false};
};
protected:
map_type m_values;
uint64_t m_write_count{0};
uint64_t m_read_count{0};
bool m_print{false};
};
/// Restores node attributes from a ValueMap: each on_adapter overload looks the
/// attribute up by name in the map and writes the stored value into the adapter.
/// (This block previously contained two interleaved copies of the class — an
/// old-format and a clang-formatted one — which made it uncompilable; it is
/// collapsed here into a single definition.)
class DeserializeAttributeVisitor : public AttributeVisitor {
public:
    DeserializeAttributeVisitor(ValueMap& value_map) : m_values(value_map) {}

    /// Fallback used for attribute types without a dedicated overload. Only
    /// AlignedBuffer attributes (stored in the map as HostTensors) are handled;
    /// anything else is a hard failure.
    void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override {
        if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
                &adapter)) {
            auto& data = m_values.get<HostTensorPtr>(name);
            data->read(a->get()->get_ptr(), a->get()->size());
        } else {
            NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be unmarshalled");
        }
    }
    // The remaining adapter methods fall back on the void adapter if not implemented
    void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override {
        adapter.set(m_values.get<std::string>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<bool>& adapter) override {
        adapter.set(m_values.get<bool>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) override {
        adapter.set(m_values.get<int64_t>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<double>& adapter) override {
        adapter.set(m_values.get<double>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<int8_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<int8_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<int16_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<int16_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<int32_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<int32_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<int64_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<int64_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<uint8_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<uint8_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<uint16_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<uint16_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<uint32_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<uint32_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<uint64_t>>& adapter) override {
        adapter.set(m_values.get<std::vector<uint64_t>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<std::string>>& adapter) override {
        adapter.set(m_values.get<std::vector<std::string>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter) override {
        adapter.set(m_values.get<std::vector<float>>(name));
    }
    void on_adapter(const std::string& name, ValueAccessor<std::vector<double>>& adapter) override {
        adapter.set(m_values.get<std::vector<double>>(name));
    }
    /// Raw-buffer attributes: copy the stored HostTensor's bytes into the
    /// adapter's destination buffer.
    void on_adapter(const std::string& name, ValueAccessor<void*>& adapter) override {
        HostTensorPtr& data = m_values.get<HostTensorPtr>(name);
        data->read(adapter.get_ptr(), adapter.size());
    }

protected:
    ValueMap& m_values;  // backing attribute store; referenced, not owned
};
class SerializeAttributeVisitor : public AttributeVisitor
{
public:
SerializeAttributeVisitor(ValueMap& value_map)
: m_values(value_map)
{
}
class SerializeAttributeVisitor : public AttributeVisitor {
public:
SerializeAttributeVisitor(ValueMap& value_map) : m_values(value_map) {}
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override
{
if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<
std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(&adapter))
{
HostTensorPtr data =
std::make_shared<HostTensor>(element::u8, Shape{a->get()->size()});
data->write(a->get()->get_ptr(), a->get()->size());
m_values.insert(name, data);
}
else
{
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
}
}
// The remaining adapter methods fall back on the void adapter if not implemented
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
};
void on_adapter(const std::string& name, ValueAccessor<bool>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
};
void on_adapter(const std::string& name, ValueAccessor<void>& adapter) override {
if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::shared_ptr<ngraph::runtime::AlignedBuffer>>>(
&adapter)) {
HostTensorPtr data = std::make_shared<HostTensor>(element::u8, Shape{a->get()->size()});
data->write(a->get()->get_ptr(), a->get()->size());
m_values.insert(name, data);
} else {
NGRAPH_CHECK(false, "Attribute \"", name, "\" cannot be marshalled");
}
}
// The remaining adapter methods fall back on the void adapter if not implemented
void on_adapter(const std::string& name, ValueAccessor<std::string>& adapter) override {
m_values.insert_scalar(name, adapter.get());
};
void on_adapter(const std::string& name, ValueAccessor<bool>& adapter) override {
m_values.insert_scalar(name, adapter.get());
};
void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<double>& adapter) override
{
m_values.insert_scalar(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<std::string>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<float>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<double>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int8_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int16_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int32_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<int64_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint8_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint16_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint32_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name,
ValueAccessor<std::vector<uint64_t>>& adapter) override
{
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<void*>& adapter) override
{
HostTensorPtr data =
std::make_shared<HostTensor>(element::u8, Shape{adapter.size()});
data->write(adapter.get_ptr(), adapter.size());
m_values.insert(name, data);
}
void on_adapter(const std::string& name, ValueAccessor<int64_t>& adapter) override {
m_values.insert_scalar(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<double>& adapter) override {
m_values.insert_scalar(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<std::string>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<float>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<double>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<int8_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<int16_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<int32_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<int64_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<uint8_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<uint16_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<uint32_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<std::vector<uint64_t>>& adapter) override {
m_values.insert_vector(name, adapter.get());
}
void on_adapter(const std::string& name, ValueAccessor<void*>& adapter) override {
HostTensorPtr data = std::make_shared<HostTensor>(element::u8, Shape{adapter.size()});
data->write(adapter.get_ptr(), adapter.size());
m_values.insert(name, data);
}
protected:
ValueMap& m_values;
};
protected:
ValueMap& m_values;
};
/// Test helper that round-trips a node's attributes: serialize them into the
/// embedded ValueMap (save_node) and re-create an equivalent node from them
/// (create). Inherits ValueMap as the shared attribute store and acts as its
/// own deserializing visitor.
/// (This block previously contained two interleaved copies of the class — an
/// old-format and a clang-formatted one — which made it uncompilable; it is
/// collapsed here into a single definition.)
class NodeBuilder : public ValueMap, public DeserializeAttributeVisitor {
public:
    NodeBuilder() : DeserializeAttributeVisitor(static_cast<ValueMap&>(*this)), m_serializer(*this) {}
    /// Convenience constructor: immediately captures the given node's attributes.
    NodeBuilder(const std::shared_ptr<Node>& node)
        : DeserializeAttributeVisitor(static_cast<ValueMap&>(*this)),
          m_serializer(*this) {
        save_node(node);
    }
    /// Records the node's type and serializes its attributes into this builder.
    void save_node(std::shared_ptr<Node> node) {
        m_node_type_info = node->get_type_info();
        node->visit_attributes(m_serializer);
    }
    // Does not validate, since inputs aren't set
    std::shared_ptr<Node> create() {
        std::shared_ptr<Node> node(get_ops().create(m_node_type_info));
        node->visit_attributes(*this);
        return node;
    }
    AttributeVisitor& get_node_saver() {
        return m_serializer;
    }
    AttributeVisitor& get_node_loader() {
        return *this;
    }
    /// Lazily-built registry of every op factory, populated once from the
    /// versioned op table via the NGRAPH_OP macro expansion.
    static FactoryRegistry<Node>& get_ops() {
        static FactoryRegistry<Node> registry = [] {
            FactoryRegistry<Node> registry;
#define NGRAPH_OP(NAME, NAMESPACE, VERSION) registry.register_factory<NAMESPACE::NAME>();
#include "op_version_tbl.hpp"
#undef NGRAPH_OP
            return registry;
        }();
        return registry;
    }

protected:
    Node::type_info_t m_node_type_info;      // type of the node captured by save_node
    SerializeAttributeVisitor m_serializer;  // writes attributes into this ValueMap
};
} // namespace test
} // namespace ngraph