Python API For compare_functions (#3938)

* Added python API for compare_functions

* Fixed compare_funcion constant comparision, graph traversal

* Add tests for python API functions

* Move CompareNetworks to separate python module

* Update python API tests

* Added dev package support

* ENABLE_TESTS

* Update constant comparator

* Fix merge conflict
This commit is contained in:
Gleb Kazantaev 2021-01-22 23:37:50 +03:00 committed by GitHub
parent 2d39555191
commit 94b2cc1dad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 181 additions and 27 deletions

View File

@ -55,6 +55,10 @@ set (PYTHON_BRIDGE_SRC_ROOT ${CMAKE_CURRENT_SOURCE_DIR})
add_subdirectory (src/openvino/inference_engine)
add_subdirectory (src/openvino/offline_transformations)
if (ENABLE_TESTS)
add_subdirectory(src/openvino/test_utils)
endif()
# Check Cython version
if(CYTHON_VERSION VERSION_LESS "0.29")
message(FATAL_ERROR "OpenVINO Python API needs at least Cython version 0.29, found version ${CYTHON_VERSION}")

View File

@ -0,0 +1,45 @@
# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
set(TARGET_NAME "test_utils_api")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PYTHON_BRIDGE_OUTPUT_DIRECTORY}/test_utils)
file(GLOB SOURCE
${CMAKE_CURRENT_SOURCE_DIR}/test_utils_api.pyx
${CMAKE_CURRENT_SOURCE_DIR}/test_utils_api_impl.cpp)
set_source_files_properties(${SOURCE} PROPERTIES CYTHON_IS_CXX ON)
# create target
cython_add_module(${TARGET_NAME} ${SOURCE})
set(INSTALLED_TARGETS ${TARGET_NAME})
add_dependencies(${TARGET_NAME} ie_api)
if(COMMAND ie_add_vs_version_file)
foreach(target IN LISTS INSTALLED_TARGETS)
ie_add_vs_version_file(NAME ${target}
FILEDESCRIPTION "Test Utils Python library")
endforeach()
endif()
if(TARGET commonTestUtils)
list(APPEND InferenceEngine_LIBRARIES commonTestUtils)
else()
list(APPEND InferenceEngine_LIBRARIES IE::commonTestUtils)
endif()
target_include_directories(${TARGET_NAME} PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/../inference_engine")
target_link_libraries(${TARGET_NAME} PRIVATE ${InferenceEngine_LIBRARIES})
# Compatibility with python 2.7 which has deprecated "register" specifier
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
target_compile_options(${TARGET_NAME} PRIVATE "-Wno-error=register")
endif()
# perform copy
add_custom_command(TARGET ${TARGET_NAME}
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${PYTHON_BRIDGE_SRC_ROOT}/src/openvino/test_utils/__init__.py ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/__init__.py
)

View File

@ -0,0 +1,2 @@
from .test_utils_api import *
__all__ = ['CompareNetworks']

View File

@ -0,0 +1,26 @@
"""
Copyright (C) 2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the 'License');
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an 'AS IS' BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .cimport test_utils_api_impl_defs as C
from ..inference_engine.ie_api cimport IENetwork
from libcpp cimport bool
from libcpp.string cimport string
from libcpp.pair cimport pair
def CompareNetworks(IENetwork lhs, IENetwork rhs):
cdef pair[bool, string] c_pair
c_pair = C.CompareNetworks(lhs.impl, rhs.impl)
return c_pair

View File

@ -0,0 +1,14 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils_api_impl.hpp"
#include <string>
#include <common_test_utils/ngraph_test_utils.hpp>
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs,
InferenceEnginePython::IENetwork rhs) {
return compare_functions(lhs.actual->getFunction(), rhs.actual->getFunction(), true, true, false, true);
}

View File

@ -0,0 +1,14 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "Python.h"
#include "ie_api_impl.hpp"
namespace InferenceEnginePython {
std::pair<bool, std::string> CompareNetworks(InferenceEnginePython::IENetwork, InferenceEnginePython::IENetwork);
}; // namespace InferenceEnginePython

View File

@ -0,0 +1,8 @@
from libcpp cimport bool
from libcpp.string cimport string
from libcpp.pair cimport pair
from ..inference_engine.ie_api_impl_defs cimport IENetwork
cdef extern from "test_utils_api_impl.hpp" namespace "InferenceEnginePython":
cdef pair[bool, string] CompareNetworks(IENetwork lhs, IENetwork rhs)

View File

@ -1,5 +1,5 @@
from openvino.inference_engine import IECore, IENetwork
from openvino.offline_transformations import ApplyMOCTransformations
from openvino.offline_transformations import ApplyMOCTransformations, ApplyLowLatencyTransformation
import ngraph as ng
from ngraph.impl.op import Parameter
@ -10,8 +10,7 @@ from conftest import model_path
test_net_xml, test_net_bin = model_path()
def test_offline_api():
def get_test_cnnnetwork():
element_type = Type.f32
param = Parameter(element_type, Shape([1, 3, 22, 22]))
relu = ng.relu(param)
@ -20,9 +19,22 @@ def test_offline_api():
cnnNetwork = IENetwork(caps)
assert cnnNetwork != None
return cnnNetwork
ApplyMOCTransformations(cnnNetwork, False)
func2 = ng.function_from_cnn(cnnNetwork)
assert func2 != None
assert len(func2.get_ops()) == 3
def test_moc_transformations():
net = get_test_cnnnetwork()
ApplyMOCTransformations(net, False)
f = ng.function_from_cnn(net)
assert f != None
assert len(f.get_ops()) == 3
def test_low_latency_transformations():
net = get_test_cnnnetwork()
ApplyLowLatencyTransformation(net)
f = ng.function_from_cnn(net)
assert f != None
assert len(f.get_ops()) == 3

View File

@ -0,0 +1,26 @@
from openvino.inference_engine import IECore, IENetwork
import ngraph as ng
from ngraph.impl.op import Parameter
from ngraph.impl import Function, Shape, Type
def get_test_cnnnetwork():
element_type = Type.f32
param = Parameter(element_type, Shape([1, 3, 22, 22]))
relu = ng.relu(param)
func = Function([relu], [param], 'test')
caps = Function.to_capsule(func)
cnnNetwork = IENetwork(caps)
assert cnnNetwork != None
return cnnNetwork
def test_compare_networks():
try:
from openvino.test_utils import CompareNetworks
net = get_test_cnnnetwork()
status, msg = CompareNetworks(net, net)
assert status
except:
print("openvino.test_utils.CompareNetworks is not available")

View File

@ -11,6 +11,7 @@
#include <ngraph/function.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/op/util/sub_graph_base.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/visualize_tree.hpp>
@ -82,18 +83,6 @@ std::string name(const Node& n) {
return n->get_friendly_name();
}
template <typename Constant>
bool equal(const Constant& c1, const Constant& c2) {
const auto equal_float_str = [](const std::string& s1, const std::string s2) {
return std::abs(std::stof(s1) - std::stof(s2)) < 0.001;
};
const auto& c1v = c1.get_value_strings();
const auto& c2v = c2.get_value_strings();
return c1v.size() == c2v.size() &&
std::equal(begin(c1v), end(c1v), begin(c2v), equal_float_str);
}
} // namespace
std::pair<bool, std::string> compare_functions(
@ -132,8 +121,9 @@ std::pair<bool, std::string> compare_functions(
std::ostringstream err_log;
using ComparedNodes = std::pair<std::shared_ptr<ngraph::Node>, std::shared_ptr<ngraph::Node>>;
using ComparedNodes = std::pair<ngraph::Node*, ngraph::Node*>;
std::queue<ComparedNodes> q;
std::unordered_set<ngraph::Node *> used;
for (size_t i = 0; i < f1_results.size(); ++i) {
if (compareNames) {
@ -144,7 +134,8 @@ std::pair<bool, std::string> compare_functions(
" and " + name(f2_results[i]->get_input_node_shared_ptr(0)));
}
}
q.push({f1_results[i], f2_results[i]});
q.push({ f1_results[i].get(), f2_results[i].get() });
used.insert(f1_results[i].get());
}
while (!q.empty()) {
@ -159,8 +150,8 @@ std::pair<bool, std::string> compare_functions(
return error(typeInfoToStr(type_info1) + " != " + typeInfoToStr(type_info2));
}
auto subgraph1 = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node1);
auto subgraph2 = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(node2);
auto subgraph1 = dynamic_cast<ngraph::op::util::SubGraphOp *>(node1);
auto subgraph2 = dynamic_cast<ngraph::op::util::SubGraphOp *>(node2);
if (subgraph1 && subgraph2) {
auto res = compare_functions(subgraph1->get_function(), subgraph2->get_function(),
@ -197,7 +188,18 @@ std::pair<bool, std::string> compare_functions(
auto const1 = ngraph::as_type_ptr<Constant>(node1->get_input_node_shared_ptr(i));
auto const2 = ngraph::as_type_ptr<Constant>(node2->get_input_node_shared_ptr(i));
if (const1 && const2 && !equal(*const1, *const2)) {
const auto equal = [](std::shared_ptr<Constant> c1, std::shared_ptr<Constant> c2) {
const auto &c1v = c1->cast_vector<double>();
const auto &c2v = c2->cast_vector<double>();
return c1v.size() == c2v.size() &&
std::equal(begin(c1v), end(c1v), begin(c2v),
[](const double &s1, const double & s2) {
return std::abs(s1 - s2) < 0.001;
});
};
if (const1 && const2 && !equal(const1, const2)) {
err_log << "Different Constant values detected\n"
<< node1->description() << " Input(" << i << ") and "
<< node2->description() << " Input(" << i << ")" << std::endl;
@ -239,9 +241,10 @@ std::pair<bool, std::string> compare_functions(
<< std::endl;
}
q.push(
{node1->input_value(i).get_node_shared_ptr(),
node2->input_value(i).get_node_shared_ptr()});
if (!used.count(node1->input_value(i).get_node())) {
q.push({node1->input_value(i).get_node(), node2->input_value(i).get_node()});
used.insert(node1->input_value(i).get_node());
}
}
for (int i = 0; i < node1->outputs().size(); ++i) {