Merge remote-tracking branch 'upstream/master'

Steve Yoo 2021-04-28 07:50:13 +09:00
commit a6b2800be2
25 changed files with 258 additions and 132 deletions

View File

@@ -14,6 +14,7 @@
namespace TemplatePlugin {
// forward declaration
class Plugin;
/**

View File

@@ -8,17 +8,6 @@
#include <string>
#include <map>
#include <ie_blob.h>
#include <description_buffer.hpp>
#include <debug.h>
#include <ie_layouts.h>
#include <threading/ie_executor_manager.hpp>
#include <blob_transform.hpp>
#include <ie_parallel.hpp>
#include <ie_memcpy.h>
#include <precision_utils.h>
#include "template/template_config.hpp"
#include "template_infer_request.hpp"
#include "template_executable_network.hpp"
#include "template_plugin.hpp"

View File

@@ -7,22 +7,21 @@
#include <map>
#include <string>
#include <vector>
#include <array>
#include <memory>
#include <unordered_map>
#include <chrono>
#include <ie_common.h>
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
#include <threading/ie_itask_executor.hpp>
#include <openvino/itt.hpp>
#include <ie_input_info.hpp>
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
#include <ngraph/runtime/tensor.hpp>
#include <executable.hpp>
#include "template_config.hpp"
namespace TemplatePlugin {
// forward declaration
class ExecutableNetwork;
// ! [infer_request:header]

View File

@@ -81,50 +81,19 @@ InferenceEngine::ExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(const
const ConfigMap &config) {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::LoadExeNetworkImpl");
auto cfg = Configuration{ config, _cfg };
InferenceEngine::InputsDataMap networkInputs = network.getInputsInfo();
InferenceEngine::OutputsDataMap networkOutputs = network.getOutputsInfo();
// TODO: check with precisions supported by Template device
for (auto networkOutput : networkOutputs) {
auto output_precision = networkOutput.second->getPrecision();
if (output_precision != InferenceEngine::Precision::FP32 &&
output_precision != InferenceEngine::Precision::FP16 &&
output_precision != InferenceEngine::Precision::U8) {
IE_THROW() << "Template device supports only U8, FP16 and FP32 output precision.";
}
}
for (auto networkInput : networkInputs) {
auto input_precision = networkInput.second->getTensorDesc().getPrecision();
if (input_precision != InferenceEngine::Precision::FP32 &&
input_precision != InferenceEngine::Precision::FP16 &&
input_precision != InferenceEngine::Precision::I16 &&
input_precision != InferenceEngine::Precision::U8) {
IE_THROW() << "Input image format " << input_precision << " is not supported yet.\n"
<< "Supported formats are: FP32, FP16, I16 and U8.";
}
}
auto function = network.getFunction();
if (function == nullptr) {
IE_THROW() << "TEMPLATE plugin can compile only IR v10 networks";
}
return std::make_shared<ExecutableNetwork>(function, cfg, std::static_pointer_cast<Plugin>(shared_from_this()));
auto fullConfig = Configuration{ config, _cfg };
return std::make_shared<ExecutableNetwork>(network.getFunction(), fullConfig,
std::static_pointer_cast<Plugin>(shared_from_this()));
}
// ! [plugin:load_exe_network_impl]
// ! [plugin:import_network_impl]
InferenceEngine::ExecutableNetworkInternal::Ptr
Plugin::ImportNetworkImpl(std::istream& model, const std::map<std::string, std::string>& config) {
Plugin::ImportNetworkImpl(std::istream& modelStream, const std::map<std::string, std::string>& config) {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::ImportNetworkImpl");
Configuration cfg(config);
return std::make_shared<ExecutableNetwork>(model, cfg,
auto fullConfig = Configuration{ config, _cfg };
return std::make_shared<ExecutableNetwork>(modelStream, fullConfig,
std::static_pointer_cast<Plugin>(shared_from_this()));
}
// ! [plugin:import_network_impl]
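Both entry points above are reached through the public Core API. A minimal caller-side sketch for orientation (the file names and the "TEMPLATE" device name are illustrative, not part of this diff):

#include <fstream>

#include <ie_core.hpp>

int main() {
    InferenceEngine::Core ie;
    // Compilation path: dispatches to Plugin::LoadExeNetworkImpl.
    auto execNet = ie.LoadNetwork(ie.ReadNetwork("model.xml"), "TEMPLATE");
    // Import path: dispatches to Plugin::ImportNetworkImpl with the opened stream.
    std::ifstream blob("model.blob", std::ios::binary);
    auto imported = ie.ImportNetwork(blob, "TEMPLATE");
    return 0;
}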
@@ -133,13 +102,8 @@ Plugin::ImportNetworkImpl(std::istream& model, const std::map<std::string, std::
InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::CNNNetwork &network, const ConfigMap& config) const {
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::QueryNetwork");
InferenceEngine::QueryNetworkResult res;
Configuration cfg{config, _cfg, false};
Configuration fullConfig{config, _cfg, false};
auto function = network.getFunction();
if (function == nullptr) {
IE_THROW() << "Template Plugin supports only ngraph cnn network representation";
}
// 1. First of all we should store initial input operation set
std::unordered_set<std::string> originalOps;
@@ -207,6 +171,7 @@ InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::
}
// 7. Produce the result
InferenceEngine::QueryNetworkResult res;
for (auto&& layerName : supported) {
res.supportedLayersMap.emplace(layerName, GetName());
}
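The map filled here is what callers receive back from Core::QueryNetwork. A consumer-side sketch (device name illustrative):

#include <iostream>

#include <ie_core.hpp>

void printSupportedLayers(const InferenceEngine::CNNNetwork& network) {
    InferenceEngine::Core ie;
    auto res = ie.QueryNetwork(network, "TEMPLATE");
    for (const auto& layer : res.supportedLayersMap)
        std::cout << layer.first << " -> " << layer.second << std::endl;  // layer name -> device name
}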

View File

@@ -114,16 +114,6 @@ install(FILES samples/CMakeLists.txt
DESTINATION ${IE_CPACK_IE_DIR}/samples/c
COMPONENT c_samples)
# install Python samples
if(ENABLE_PYTHON)
ie_cpack_add_component(python_samples DEPENDS core)
install(DIRECTORY ${ie_python_api_SOURCE_DIR}/sample/
DESTINATION ${IE_CPACK_IE_DIR}/samples/python
COMPONENT python_samples)
endif()
# install speech demo files
if(SPEECH_LIBS_AND_DEMOS)

View File

@@ -90,4 +90,12 @@ install(PROGRAMS src/openvino/__init__.py
DESTINATION ${PYTHON_BRIDGE_CPACK_PATH}/${PYTHON_VERSION}/openvino
COMPONENT ${PYTHON_VERSION})
ie_cpack(${PYTHON_VERSION})
# install Python samples
ie_cpack_add_component(python_samples)
install(DIRECTORY sample/
DESTINATION ${IE_CPACK_IE_DIR}/samples/python
COMPONENT python_samples)
ie_cpack(${PYTHON_VERSION} python_samples)

View File

@@ -4,11 +4,9 @@
#include "test_utils_api_impl.hpp"
#include <common_test_utils/ngraph_test_utils.hpp>
#include <string>
#include <common_test_utils/ngraph_test_utils.hpp>
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs,
InferenceEnginePython::IENetwork rhs) {
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs, InferenceEnginePython::IENetwork rhs) {
return compare_functions(lhs.actual->getFunction(), rhs.actual->getFunction(), true, true, false, true);
}

View File

@@ -286,6 +286,11 @@ struct QueryNetworkResult {
*/
using ConstOutputsDataMap = std::map<std::string, CDataPtr>;
/**
* @brief A collection that contains string as key, and Data smart pointer as value
*/
using OutputsDataMap = std::map<std::string, DataPtr>;
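The typedef itself is unchanged, only relocated (the matching removal appears in a file below, next to ICNNNetwork). For illustration, a hypothetical helper that walks a network's outputs through this map:

#include <iostream>

#include <cpp/ie_cnn_network.h>

void dumpOutputs(const InferenceEngine::CNNNetwork& network) {
    InferenceEngine::OutputsDataMap outputs = network.getOutputsInfo();
    for (const auto& out : outputs)  // out.first: name, out.second: DataPtr
        std::cout << out.first << ": " << out.second->getPrecision().name() << std::endl;
}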
namespace details {
struct INFERENCE_ENGINE_DEPRECATED("Use InferRequest::Exception")
INFERENCE_ENGINE_API_CLASS(InferenceEngineException) : public std::runtime_error {

View File

@@ -238,7 +238,7 @@ public:
* @brief Returns devices available for neural network inference
*
* @return A vector of devices. The devices are returned as { CPU, FPGA.0, FPGA.1, MYRIAD }
If there is more than one device of a given type, they are enumerated with the .# suffix.
* If there is more than one device of a given type, they are enumerated with the .# suffix.
*/
std::vector<std::string> GetAvailableDevices() const;

View File

@@ -34,11 +34,6 @@ class Function;
namespace InferenceEngine {
/**
* @brief A collection that contains string as key, and Data smart pointer as value
*/
using OutputsDataMap = std::map<std::string, DataPtr>;
/**
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead
* @interface ICNNNetwork

View File

@@ -591,6 +591,43 @@ public:
return copyParameterValue(GetCPPPluginByName(parsed._deviceName).GetMetric(name, parsed._config));
}
/**
* @brief Returns devices available for neural network inference
*
* @return A vector of devices. The devices are returned as { CPU, FPGA.0, FPGA.1, MYRIAD }
* If there is more than one device of a given type, they are enumerated with the .# suffix.
*/
std::vector<std::string> GetAvailableDevices() const override {
std::vector<std::string> devices;
const std::string propertyName = METRIC_KEY(AVAILABLE_DEVICES);
for (auto&& deviceName : GetListOfDevicesInRegistry()) {
std::vector<std::string> devicesIDs;
try {
const Parameter p = GetMetric(deviceName, propertyName);
devicesIDs = p.as<std::vector<std::string>>();
} catch (Exception&) {
// plugin is not created, e.g. because of an invalid env
} catch (const std::exception& ex) {
IE_THROW() << "An exception is thrown while trying to create the " << deviceName
<< " device and call GetMetric: " << ex.what();
} catch (...) {
IE_THROW() << "Unknown exception is thrown while trying to create the " << deviceName
<< " device and call GetMetric";
}
if (devicesIDs.size() > 1) {
for (auto&& deviceID : devicesIDs) {
devices.push_back(deviceName + '.' + deviceID);
}
} else if (!devicesIDs.empty()) {
devices.push_back(deviceName);
}
}
return devices;
}
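Behaviour is unchanged from the caller's perspective; the enumeration described in the doc comment surfaces through Core::GetAvailableDevices:

#include <iostream>

#include <ie_core.hpp>

int main() {
    InferenceEngine::Core ie;
    for (const auto& device : ie.GetAvailableDevices())
        std::cout << device << std::endl;  // e.g. CPU, GPU.0, GPU.1
    return 0;
}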
/**
* @brief Returns reference to CPP plugin wrapper by a device name
* @param deviceName A name of device
@@ -1007,35 +1044,7 @@ Parameter Core::GetMetric(const std::string& deviceName, const std::string& name
}
std::vector<std::string> Core::GetAvailableDevices() const {
std::vector<std::string> devices;
std::string propertyName = METRIC_KEY(AVAILABLE_DEVICES);
for (auto&& deviceName : _impl->GetListOfDevicesInRegistry()) {
std::vector<std::string> devicesIDs;
try {
Parameter p = GetMetric(deviceName, propertyName);
devicesIDs = p.as<std::vector<std::string>>();
} catch (Exception&) {
// plugin is not created, e.g. because of an invalid env
} catch (const std::exception& ex) {
IE_THROW() << "An exception is thrown while trying to create the " << deviceName
<< " device and call GetMetric: " << ex.what();
} catch (...) {
IE_THROW() << "Unknown exception is thrown while trying to create the " << deviceName
<< " device and call GetMetric";
}
if (devicesIDs.size() > 1) {
for (auto&& deviceID : devicesIDs) {
devices.push_back(deviceName + '.' + deviceID);
}
} else if (!devicesIDs.empty()) {
devices.push_back(deviceName);
}
}
return devices;
return _impl->GetAvailableDevices();
}
void Core::RegisterPlugin(const std::string& pluginName, const std::string& deviceName) {

View File

@@ -100,6 +100,14 @@ public:
*/
virtual Parameter GetMetric(const std::string& deviceName, const std::string& name) const = 0;
/**
* @brief Returns devices available for neural network inference
*
* @return A vector of devices. The devices are returned as { CPU, FPGA.0, FPGA.1, MYRIAD }
* If there is more than one device of a given type, they are enumerated with the .# suffix.
*/
virtual std::vector<std::string> GetAvailableDevices() const = 0;
/**
* @brief Default virtual destructor
*/

View File

@@ -700,13 +700,6 @@ V10Parser::V10Parser::GenericLayerParams XmlDeserializer::parseGenericParams(
port.dims.push_back(dim);
}
ngraph::element::Type type(ngraph::element::Type_t::undefined);
// Input ports don't have a precision attribute
if (!input) {
const std::string& preStr = GetStrAttr(parentNode, "precision");
type = InferenceEngine::details::convertPrecision(preStr);
}
port.precision = type;
std::vector<std::string> names;
if (getParameters<std::string>(parentNode, "names", names)) {
for (size_t i = 0; i < names.size(); i++) {

View File

@@ -67,7 +67,6 @@ public:
struct GenericLayerParams {
struct LayerPortData {
size_t portId;
ngraph::element::Type_t precision;
SizeVector dims;
std::unordered_set<std::string> names;
};

View File

@@ -57,5 +57,7 @@ std::vector<std::string> disabledTestPatterns() {
R"(.*(LPT/StridedSliceTransformation).*)",
// TODO: Issue: 48106
R"(.*ConstantResultSubgraphTest.*inPrc=I16.*)",
// TODO: Issue: 54436
R"(.*LSTMSequence.*CompareWithRefs.*mode=PURE_SEQ_RAND_SEQ_LEN_PARAM.*direction=bidirectional_clip=0.7_netPRC=FP32.*)",
};
}

View File

@@ -28,6 +28,7 @@ public:
const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
MOCK_QUALIFIED_METHOD2(GetMetric, const, InferenceEngine::Parameter(const std::string&, const std::string&));
MOCK_QUALIFIED_METHOD0(GetAvailableDevices, const, std::vector<std::string>());
~MockICore() = default;
};

View File

@@ -3,6 +3,8 @@
//
#include <vector>
#include <thread>
#include <chrono>
#include <sstream>
#include <gtest/gtest.h>
#include <legacy/layer_transform.hpp>
#include "gna_matcher.hpp"
@@ -18,6 +20,15 @@ class GNAAOTTests : public GNATest<>{
files_to_remove.push_back(file_to_remove);
return file_to_remove;
}
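// Builds a file name that is unique per thread and per call (thread id +
// microsecond timestamp prefix) so concurrently running tests don't
// overwrite each other's exported files.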
std::string generateFileName(const std::string& baseName) const {
using namespace std::chrono;
std::stringstream ss;
auto ts = duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch());
ss << std::this_thread::get_id() << "_" << ts.count() << "_" << baseName;
return ss.str();
}
void TearDown() override {
for (auto & file : files_to_remove) {
std::remove(file.c_str());
@@ -30,7 +41,7 @@ class GNAAOTTests : public GNATest<>{
TEST_F(GNAAOTTests, DISABLED_AffineWith2AffineOutputs_canbe_export_imported) {
const std::string X = registerFileForRemove("unit_tests.bin");
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
// running export to a file
export_network(AffineWith2AffineOutputsModel())
@@ -52,7 +63,7 @@ TEST_F(GNAAOTTests, DISABLED_AffineWith2AffineOutputs_canbe_imported_verify_stru
save_args().onInferModel(AffineWith2AffineOutputsModel())
.inNotCompactMode().withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).from().gna().propagate_forward().to(&nnet_type);
const std::string X = registerFileForRemove("unit_tests.bin");
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
// running export to a file
export_network(AffineWith2AffineOutputsModel())
@@ -70,7 +81,7 @@ TEST_F(GNAAOTTests, TwoInputsModel_canbe_export_imported) {
GTEST_SKIP();
#endif
const std::string X = registerFileForRemove("unit_tests.bin");
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
// running export to a file
export_network(TwoInputsModelForIO())
@@ -90,7 +101,7 @@ TEST_F(GNAAOTTests, PermuteModel_canbe_export_imported) {
GTEST_SKIP();
#endif
const std::string X = registerFileForRemove("unit_tests.bin");
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
// running export to a file
export_network(PermuteModelForIO())
@@ -107,7 +118,7 @@ TEST_F(GNAAOTTests, PoolingModel_canbe_export_imported) {
GTEST_SKIP();
#endif
const std::string X = registerFileForRemove("unit_tests.bin");
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
// running export to a file
export_network(maxpoolAfterRelu())
@@ -127,7 +138,7 @@ TEST_F(GNAAOTTests, DISABLED_CanConvertFromAOTtoSueModel) {
.inNotCompactMode().withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f)
.from().gna().propagate_forward().to(&nnet_type);
const std::string X = registerFileForRemove("unit_tests.bin");
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
// running export to a file
export_network(AffineWith2AffineOutputsModel())

View File

@@ -337,7 +337,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
in_port = len(Node(graph, node_name).in_nodes())
Node(graph, node_name).add_input_port(in_port)
Node(graph, in_node_id).add_output_port(out_port)
Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
elif tokens[0] == b'output-node':
@@ -528,7 +528,7 @@ def parse_specifier(string, graph, layer_node_map):
const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node()
node = Node(graph, node_name)
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_name.id, const_node.id))
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_node.id, const_node.id))
out_port = len(node.out_nodes())
graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port))
else:

View File

@@ -203,3 +203,31 @@ class TestKaldiModelsLoading(unittest.TestCase):
)
(flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu')
self.assertTrue(flag, resp)
def test_component_map_loading_scale(self):
test_map = "input-node name=input dim=16\n" + \
"component-node name=lda component=lda input=Scale(0.1, input)\n" + \
"\n"
graph = Graph(name="test_graph_component_map_loading_scale")
test_top_map = load_topology_map(io.BytesIO(bytes(test_map, 'ascii')), graph)
ref_map = {b"lda": ["lda"]}
self.assertEqual(test_top_map, ref_map)
self.assertTrue("input" in graph.nodes())
self.assertListEqual(list(Node(graph, 'input')['shape']), [1, 16])
ref_graph = build_graph({'input': {'shape': np.array([1, 16]), 'kind': 'op', 'op': 'Parameter'},
'lda': {'kind': 'op'},
'mul': {'kind': 'op'},
'scale_const': {'kind': 'op', 'op': 'Const'},
},
[
('input', 'mul', {'in': 0}),
('scale_const', 'mul', {'in': 1}),
('mul', 'lda', {'out': 0}),
]
)
(flag, resp) = compare_graphs(graph, ref_graph, 'lda')
self.assertTrue(flag, resp)

View File

@@ -5,6 +5,7 @@
#pragma once
#include <cstring>
#include <map>
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/partial_shape.hpp"
@@ -23,6 +24,8 @@ namespace ngraph
{
};
class Variant;
/// \brief A handle for one of a node's inputs.
template <>
class NGRAPH_API Input<Node>
@@ -58,6 +61,12 @@
/// \param new_source_output A handle for the output that will replace this input's source.
void replace_source_output(const Output<Node>& new_source_output) const;
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
/// \return The reference to runtime info map
RTMap& get_rt_info();
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
@@ -101,6 +110,10 @@
/// \return true if this input is relevant to its node's output values; else false.
bool get_is_relevant_to_values() const;
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
/// \return The constant reference to runtime info map
const RTMap& get_rt_info() const;
bool operator==(const Input& other) const;
bool operator!=(const Input& other) const;
bool operator<(const Input& other) const;
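These accessors let metadata be attached to individual ports rather than whole nodes. A minimal sketch, assuming the stock VariantWrapper<std::string> specialization (the "note" key is illustrative):

#include <iostream>
#include <memory>

#include <ngraph/ngraph.hpp>
#include <ngraph/variant.hpp>

int main() {
    using namespace ngraph;
    auto param = std::make_shared<op::Parameter>(element::f32, Shape{1});
    auto relu = std::make_shared<op::Relu>(param);
    // Write into the input port's map, then read back from the same port.
    relu->input(0).get_rt_info()["note"] =
        std::make_shared<VariantWrapper<std::string>>("illustrative");
    std::cout << relu->input(0).get_rt_info().count("note") << std::endl;  // prints 1
    return 0;
}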

View File

@@ -219,11 +219,29 @@ namespace ngraph
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
protected:
bool apply_matcher_passes(std::shared_ptr<Function> f,
std::deque<std::shared_ptr<Node>> nodes_to_run);
bool m_enable_shape_inference = false;
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
};
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
NGRAPH_RTTI_DECLARATION;
BackwardGraphRewrite() = default;
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass)
: GraphRewrite(pass)
{
}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
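Registration mirrors a plain GraphRewrite; only the traversal direction differs (matcher passes see ops from results back to parameters). A sketch, where MyMatcherPass stands in for any MatcherPass subclass:

#include <memory>

#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/pass/manager.hpp>

void runBackward(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    auto rewrite = manager.register_pass<ngraph::pass::BackwardGraphRewrite>();
    rewrite->add_matcher<MyMatcherPass>();  // hypothetical MatcherPass subclass
    manager.run_passes(f);
}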
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass
{
public:

View File

@@ -82,6 +82,20 @@ namespace ngraph
{
}
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& Input<Node>::get_rt_info() { return m_node->m_inputs.at(m_index).get_rt_info(); }
const RTMap& Input<Node>::get_rt_info() const
{
return m_node->m_inputs.at(m_index).get_rt_info();
}
const RTMap& Input<const Node>::get_rt_info() const
{
return m_node->m_inputs.at(m_index).get_rt_info();
}
const Node* Input<const Node>::get_node() const { return m_node; }
size_t Input<const Node>::get_index() const { return m_index; }
const element::Type& Input<const Node>::get_element_type() const

View File

@@ -54,6 +54,8 @@ using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0);
namespace ngraph
@@ -71,19 +73,35 @@ namespace ngraph
} // namespace pass
} // namespace ngraph
bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false;
const auto& pass_config = get_pass_config();
// Initialize execution queue with nodes in reverse topological order (emplace_front reverses get_ordered_ops())
deque<std::shared_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops())
{
nodes_to_run.emplace_front(node);
}
return apply_matcher_passes(f, std::move(nodes_to_run));
}
bool pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
{
// Initialize execution queue with nodes in topological order
deque<std::shared_ptr<Node>> nodes_to_run;
for (auto& node : f->get_ordered_ops())
{
nodes_to_run.emplace_back(node);
}
return apply_matcher_passes(f, std::move(nodes_to_run));
}
bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f,
deque<std::shared_ptr<Node>> nodes_to_run)
{
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
bool rewritten = false;
const auto& pass_config = get_pass_config();
// Check that all Matchers in MatcherPasses have type-based root nodes
bool all_roots_has_type = true;

View File

@@ -39,6 +39,23 @@ public:
}
};
class GatherNodesPass : public ngraph::pass::MatcherPass
{
public:
NGRAPH_RTTI_DECLARATION;
GatherNodesPass(NodeVector & order)
: MatcherPass()
{
ngraph::matcher_pass_callback callback = [&order](pattern::Matcher& m) {
order.push_back(m.get_match_root());
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::any_input(), "GatherNodesPass");
this->register_matcher(m, callback);
}
};
class Anchor : public ngraph::pass::GraphRewrite
{
public:
@@ -51,6 +68,7 @@ public:
NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
NGRAPH_RTTI_DEFINITION(GatherNodesPass, "GatherNodesPass", 0);
std::shared_ptr<Function> get_function()
{
@@ -77,6 +95,34 @@ ngraph::pass::param_callback get_callback()
};
}
TEST(GraphRewriteOrderTest, MatcherPass)
{
auto f = get_function();
NodeVector order;
ngraph::pass::Manager m;
auto pass = m.register_pass<pass::GraphRewrite>();
pass->add_matcher<GatherNodesPass>(order);
m.run_passes(f);
ASSERT_EQ(order, f->get_ordered_ops());
}
TEST(BackwardGraphRewriteOrderTest, MatcherPass)
{
auto f = get_function();
NodeVector order;
ngraph::pass::Manager m;
auto pass = m.register_pass<pass::BackwardGraphRewrite>();
pass->add_matcher<GatherNodesPass>(order);
m.run_passes(f);
auto ref_order = f->get_ordered_ops();
std::reverse(ref_order.begin(), ref_order.end());
ASSERT_EQ(order, ref_order);
}
TEST(GraphRewriteTest, MatcherPassCallback)
{
auto f = get_function();

View File

@@ -124,9 +124,25 @@ TEST(op, variant)
EXPECT_EQ(ship.y, 4);
auto node = make_shared<op::Parameter>(element::f32, Shape{1});
// Check Node RTInfo
node->get_rt_info()["A"] = var_ship;
auto node_var_ship = node->get_rt_info().at("A");
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_var_ship)));
Ship& node_ship = as_type_ptr<VariantWrapper<Ship>>(node_var_ship)->get();
EXPECT_EQ(&node_ship, &ship);
// Check Node Input<Node> RTInfo
auto relu = make_shared<op::Relu>(node);
relu->input(0).get_rt_info()["A"] = var_ship;
auto node_input_var_ship = relu->input(0).get_rt_info().at("A");
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_input_var_ship)));
Ship& node_input_ship = as_type_ptr<VariantWrapper<Ship>>(node_input_var_ship)->get();
EXPECT_EQ(&node_input_ship, &ship);
// Check Node Output<Node> RTInfo
node->output(0).get_rt_info()["A"] = var_ship;
auto node_output_var_ship = node->output(0).get_rt_info().at("A");
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_output_var_ship)));
Ship& node_output_ship = as_type_ptr<VariantWrapper<Ship>>(node_output_var_ship)->get();
EXPECT_EQ(&node_output_ship, &ship);
}