Merge remote-tracking branch 'upstream/master'
This commit is contained in:
commit
a6b2800be2
@ -14,6 +14,7 @@
|
||||
|
||||
namespace TemplatePlugin {
|
||||
|
||||
// forward declaration
|
||||
class Plugin;
|
||||
|
||||
/**
|
||||
|
@ -8,17 +8,6 @@
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
#include <ie_blob.h>
|
||||
#include <description_buffer.hpp>
|
||||
#include <debug.h>
|
||||
#include <ie_layouts.h>
|
||||
#include <threading/ie_executor_manager.hpp>
|
||||
#include <blob_transform.hpp>
|
||||
#include <ie_parallel.hpp>
|
||||
#include <ie_memcpy.h>
|
||||
#include <precision_utils.h>
|
||||
|
||||
#include "template/template_config.hpp"
|
||||
#include "template_infer_request.hpp"
|
||||
#include "template_executable_network.hpp"
|
||||
#include "template_plugin.hpp"
|
||||
|
@ -7,22 +7,21 @@
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <chrono>
|
||||
|
||||
#include <ie_common.h>
|
||||
#include <cpp_interfaces/impl/ie_executable_network_internal.hpp>
|
||||
#include <threading/ie_itask_executor.hpp>
|
||||
#include <openvino/itt.hpp>
|
||||
|
||||
#include <ie_input_info.hpp>
|
||||
#include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
|
||||
|
||||
#include <ngraph/runtime/tensor.hpp>
|
||||
#include <executable.hpp>
|
||||
|
||||
#include "template_config.hpp"
|
||||
|
||||
|
||||
namespace TemplatePlugin {
|
||||
|
||||
// forward declaration
|
||||
class ExecutableNetwork;
|
||||
|
||||
// ! [infer_request:header]
|
||||
|
@ -81,50 +81,19 @@ InferenceEngine::ExecutableNetworkInternal::Ptr Plugin::LoadExeNetworkImpl(const
|
||||
const ConfigMap &config) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::LoadExeNetworkImpl");
|
||||
|
||||
auto cfg = Configuration{ config, _cfg };
|
||||
InferenceEngine::InputsDataMap networkInputs = network.getInputsInfo();
|
||||
InferenceEngine::OutputsDataMap networkOutputs = network.getOutputsInfo();
|
||||
|
||||
// TODO: check with precisions supported by Template device
|
||||
|
||||
for (auto networkOutput : networkOutputs) {
|
||||
auto output_precision = networkOutput.second->getPrecision();
|
||||
|
||||
if (output_precision != InferenceEngine::Precision::FP32 &&
|
||||
output_precision != InferenceEngine::Precision::FP16 &&
|
||||
output_precision != InferenceEngine::Precision::U8) {
|
||||
IE_THROW() << "Template device supports only U8, FP16 and FP32 output precision.";
|
||||
}
|
||||
}
|
||||
|
||||
for (auto networkInput : networkInputs) {
|
||||
auto input_precision = networkInput.second->getTensorDesc().getPrecision();
|
||||
|
||||
if (input_precision != InferenceEngine::Precision::FP32 &&
|
||||
input_precision != InferenceEngine::Precision::FP16 &&
|
||||
input_precision != InferenceEngine::Precision::I16 &&
|
||||
input_precision != InferenceEngine::Precision::U8) {
|
||||
IE_THROW() << "Input image format " << input_precision << " is not supported yet.\n"
|
||||
<< "Supported formats are: FP32, FP16, I16 and U8.";
|
||||
}
|
||||
}
|
||||
|
||||
auto function = network.getFunction();
|
||||
if (function == nullptr) {
|
||||
IE_THROW() << "TEMPLATE plugin can compile only IR v10 networks";
|
||||
}
|
||||
|
||||
return std::make_shared<ExecutableNetwork>(function, cfg, std::static_pointer_cast<Plugin>(shared_from_this()));
|
||||
auto fullConfig = Configuration{ config, _cfg };
|
||||
return std::make_shared<ExecutableNetwork>(network.getFunction(), fullConfig,
|
||||
std::static_pointer_cast<Plugin>(shared_from_this()));
|
||||
}
|
||||
// ! [plugin:load_exe_network_impl]
|
||||
|
||||
// ! [plugin:import_network_impl]
|
||||
InferenceEngine::ExecutableNetworkInternal::Ptr
|
||||
Plugin::ImportNetworkImpl(std::istream& model, const std::map<std::string, std::string>& config) {
|
||||
Plugin::ImportNetworkImpl(std::istream& modelStream, const std::map<std::string, std::string>& config) {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::ImportNetworkImpl");
|
||||
|
||||
Configuration cfg(config);
|
||||
return std::make_shared<ExecutableNetwork>(model, cfg,
|
||||
auto fullConfig = Configuration{ config, _cfg };
|
||||
return std::make_shared<ExecutableNetwork>(modelStream, fullConfig,
|
||||
std::static_pointer_cast<Plugin>(shared_from_this()));
|
||||
}
|
||||
// ! [plugin:import_network_impl]
|
||||
@ -133,13 +102,8 @@ Plugin::ImportNetworkImpl(std::istream& model, const std::map<std::string, std::
|
||||
InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::CNNNetwork &network, const ConfigMap& config) const {
|
||||
OV_ITT_SCOPED_TASK(itt::domains::TemplatePlugin, "Plugin::QueryNetwork");
|
||||
|
||||
InferenceEngine::QueryNetworkResult res;
|
||||
Configuration cfg{config, _cfg, false};
|
||||
|
||||
Configuration fullConfig{config, _cfg, false};
|
||||
auto function = network.getFunction();
|
||||
if (function == nullptr) {
|
||||
IE_THROW() << "Template Plugin supports only ngraph cnn network representation";
|
||||
}
|
||||
|
||||
// 1. First of all we should store initial input operation set
|
||||
std::unordered_set<std::string> originalOps;
|
||||
@ -207,6 +171,7 @@ InferenceEngine::QueryNetworkResult Plugin::QueryNetwork(const InferenceEngine::
|
||||
}
|
||||
|
||||
// 7. Produce the result
|
||||
InferenceEngine::QueryNetworkResult res;
|
||||
for (auto&& layerName : supported) {
|
||||
res.supportedLayersMap.emplace(layerName, GetName());
|
||||
}
|
||||
|
@ -114,16 +114,6 @@ install(FILES samples/CMakeLists.txt
|
||||
DESTINATION ${IE_CPACK_IE_DIR}/samples/c
|
||||
COMPONENT c_samples)
|
||||
|
||||
# install Python samples
|
||||
|
||||
if(ENABLE_PYTHON)
|
||||
ie_cpack_add_component(python_samples DEPENDS core)
|
||||
|
||||
install(DIRECTORY ${ie_python_api_SOURCE_DIR}/sample/
|
||||
DESTINATION ${IE_CPACK_IE_DIR}/samples/python
|
||||
COMPONENT python_samples)
|
||||
endif()
|
||||
|
||||
# install speech demo files
|
||||
|
||||
if(SPEECH_LIBS_AND_DEMOS)
|
||||
|
@ -90,4 +90,12 @@ install(PROGRAMS src/openvino/__init__.py
|
||||
DESTINATION ${PYTHON_BRIDGE_CPACK_PATH}/${PYTHON_VERSION}/openvino
|
||||
COMPONENT ${PYTHON_VERSION})
|
||||
|
||||
ie_cpack(${PYTHON_VERSION})
|
||||
# install Python samples
|
||||
|
||||
ie_cpack_add_component(python_samples)
|
||||
|
||||
install(DIRECTORY sample/
|
||||
DESTINATION ${IE_CPACK_IE_DIR}/samples/python
|
||||
COMPONENT python_samples)
|
||||
|
||||
ie_cpack(${PYTHON_VERSION} python_samples)
|
||||
|
@ -4,11 +4,9 @@
|
||||
|
||||
#include "test_utils_api_impl.hpp"
|
||||
|
||||
#include <common_test_utils/ngraph_test_utils.hpp>
|
||||
#include <string>
|
||||
|
||||
#include <common_test_utils/ngraph_test_utils.hpp>
|
||||
|
||||
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs,
|
||||
InferenceEnginePython::IENetwork rhs) {
|
||||
std::pair<bool, std::string> InferenceEnginePython::CompareNetworks(InferenceEnginePython::IENetwork lhs, InferenceEnginePython::IENetwork rhs) {
|
||||
return compare_functions(lhs.actual->getFunction(), rhs.actual->getFunction(), true, true, false, true);
|
||||
}
|
||||
|
@ -286,6 +286,11 @@ struct QueryNetworkResult {
|
||||
*/
|
||||
using ConstOutputsDataMap = std::map<std::string, CDataPtr>;
|
||||
|
||||
/**
|
||||
* @brief A collection that contains string as key, and Data smart pointer as value
|
||||
*/
|
||||
using OutputsDataMap = std::map<std::string, DataPtr>;
|
||||
|
||||
namespace details {
|
||||
struct INFERENCE_ENGINE_DEPRECATED("Use InferRequest::Exception")
|
||||
INFERENCE_ENGINE_API_CLASS(InferenceEngineException) : public std::runtime_error {
|
||||
|
@ -238,7 +238,7 @@ public:
|
||||
* @brief Returns devices available for neural networks inference
|
||||
*
|
||||
* @return A vector of devices. The devices are returned as { CPU, FPGA.0, FPGA.1, MYRIAD }
|
||||
If there more than one device of specific type, they are enumerated with .# suffix.
|
||||
* If there more than one device of specific type, they are enumerated with .# suffix.
|
||||
*/
|
||||
std::vector<std::string> GetAvailableDevices() const;
|
||||
|
||||
|
@ -34,11 +34,6 @@ class Function;
|
||||
|
||||
namespace InferenceEngine {
|
||||
|
||||
/**
|
||||
* @brief A collection that contains string as key, and Data smart pointer as value
|
||||
*/
|
||||
using OutputsDataMap = std::map<std::string, DataPtr>;
|
||||
|
||||
/**
|
||||
* @deprecated Use InferenceEngine::CNNNetwork wrapper instead
|
||||
* @interface ICNNNetwork
|
||||
|
@ -591,6 +591,43 @@ public:
|
||||
return copyParameterValue(GetCPPPluginByName(parsed._deviceName).GetMetric(name, parsed._config));
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns devices available for neural networks inference
|
||||
*
|
||||
* @return A vector of devices. The devices are returned as { CPU, FPGA.0, FPGA.1, MYRIAD }
|
||||
* If there more than one device of specific type, they are enumerated with .# suffix.
|
||||
*/
|
||||
std::vector<std::string> GetAvailableDevices() const override {
|
||||
std::vector<std::string> devices;
|
||||
const std::string propertyName = METRIC_KEY(AVAILABLE_DEVICES);
|
||||
|
||||
for (auto&& deviceName : GetListOfDevicesInRegistry()) {
|
||||
std::vector<std::string> devicesIDs;
|
||||
try {
|
||||
const Parameter p = GetMetric(deviceName, propertyName);
|
||||
devicesIDs = p.as<std::vector<std::string>>();
|
||||
} catch (Exception&) {
|
||||
// plugin is not created by e.g. invalid env
|
||||
} catch (const std::exception& ex) {
|
||||
IE_THROW() << "An exception is thrown while trying to create the " << deviceName
|
||||
<< " device and call GetMetric: " << ex.what();
|
||||
} catch (...) {
|
||||
IE_THROW() << "Unknown exception is thrown while trying to create the " << deviceName
|
||||
<< " device and call GetMetric";
|
||||
}
|
||||
|
||||
if (devicesIDs.size() > 1) {
|
||||
for (auto&& deviceID : devicesIDs) {
|
||||
devices.push_back(deviceName + '.' + deviceID);
|
||||
}
|
||||
} else if (!devicesIDs.empty()) {
|
||||
devices.push_back(deviceName);
|
||||
}
|
||||
}
|
||||
|
||||
return devices;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Returns reference to CPP plugin wrapper by a device name
|
||||
* @param deviceName A name of device
|
||||
@ -1007,35 +1044,7 @@ Parameter Core::GetMetric(const std::string& deviceName, const std::string& name
|
||||
}
|
||||
|
||||
std::vector<std::string> Core::GetAvailableDevices() const {
|
||||
std::vector<std::string> devices;
|
||||
|
||||
std::string propertyName = METRIC_KEY(AVAILABLE_DEVICES);
|
||||
|
||||
for (auto&& deviceName : _impl->GetListOfDevicesInRegistry()) {
|
||||
std::vector<std::string> devicesIDs;
|
||||
try {
|
||||
Parameter p = GetMetric(deviceName, propertyName);
|
||||
devicesIDs = p.as<std::vector<std::string>>();
|
||||
} catch (Exception&) {
|
||||
// plugin is not created by e.g. invalid env
|
||||
} catch (const std::exception& ex) {
|
||||
IE_THROW() << "An exception is thrown while trying to create the " << deviceName
|
||||
<< " device and call GetMetric: " << ex.what();
|
||||
} catch (...) {
|
||||
IE_THROW() << "Unknown exception is thrown while trying to create the " << deviceName
|
||||
<< " device and call GetMetric";
|
||||
}
|
||||
|
||||
if (devicesIDs.size() > 1) {
|
||||
for (auto&& deviceID : devicesIDs) {
|
||||
devices.push_back(deviceName + '.' + deviceID);
|
||||
}
|
||||
} else if (!devicesIDs.empty()) {
|
||||
devices.push_back(deviceName);
|
||||
}
|
||||
}
|
||||
|
||||
return devices;
|
||||
return _impl->GetAvailableDevices();
|
||||
}
|
||||
|
||||
void Core::RegisterPlugin(const std::string& pluginName, const std::string& deviceName) {
|
||||
|
@ -100,6 +100,14 @@ public:
|
||||
*/
|
||||
virtual Parameter GetMetric(const std::string& deviceName, const std::string& name) const = 0;
|
||||
|
||||
/**
|
||||
* @brief Returns devices available for neural networks inference
|
||||
*
|
||||
* @return A vector of devices. The devices are returned as { CPU, FPGA.0, FPGA.1, MYRIAD }
|
||||
* If there more than one device of specific type, they are enumerated with .# suffix.
|
||||
*/
|
||||
virtual std::vector<std::string> GetAvailableDevices() const = 0;
|
||||
|
||||
/**
|
||||
* @brief Default virtual destructor
|
||||
*/
|
||||
|
@ -700,13 +700,6 @@ V10Parser::V10Parser::GenericLayerParams XmlDeserializer::parseGenericParams(
|
||||
port.dims.push_back(dim);
|
||||
}
|
||||
|
||||
ngraph::element::Type type(ngraph::element::Type_t::undefined);
|
||||
// Input port hasn't precision
|
||||
if (!input) {
|
||||
const std::string& preStr = GetStrAttr(parentNode, "precision");
|
||||
type = InferenceEngine::details::convertPrecision(preStr);
|
||||
}
|
||||
port.precision = type;
|
||||
std::vector<std::string> names;
|
||||
if (getParameters<std::string>(parentNode, "names", names)) {
|
||||
for (size_t i = 0; i < names.size(); i++) {
|
||||
|
@ -67,7 +67,6 @@ public:
|
||||
struct GenericLayerParams {
|
||||
struct LayerPortData {
|
||||
size_t portId;
|
||||
ngraph::element::Type_t precision;
|
||||
SizeVector dims;
|
||||
std::unordered_set<std::string> names;
|
||||
};
|
||||
|
@ -57,5 +57,7 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*(LPT/StridedSliceTransformation).*)",
|
||||
// TODO: Issue: 48106
|
||||
R"(.*ConstantResultSubgraphTest.*inPrc=I16.*)",
|
||||
// TODO: Issue: 54436
|
||||
R"(.*LSTMSequence.*CompareWithRefs.*mode=PURE_SEQ_RAND_SEQ_LEN_PARAM.*direction=bidirectional_clip=0.7_netPRC=FP32.*)",
|
||||
};
|
||||
}
|
||||
|
@ -28,6 +28,7 @@ public:
|
||||
const InferenceEngine::CNNNetwork&, const std::string&, const std::map<std::string, std::string>&));
|
||||
|
||||
MOCK_QUALIFIED_METHOD2(GetMetric, const, InferenceEngine::Parameter(const std::string&, const std::string&));
|
||||
MOCK_QUALIFIED_METHOD0(GetAvailableDevices, const, std::vector<std::string>());
|
||||
|
||||
~MockICore() = default;
|
||||
};
|
||||
|
@ -3,6 +3,8 @@
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <gtest/gtest.h>
|
||||
#include <legacy/layer_transform.hpp>
|
||||
#include "gna_matcher.hpp"
|
||||
@ -18,6 +20,15 @@ class GNAAOTTests : public GNATest<>{
|
||||
files_to_remove.push_back(file_to_remove);
|
||||
return file_to_remove;
|
||||
}
|
||||
|
||||
std::string generateFileName(const std::string& baseName) const {
|
||||
using namespace std::chrono;
|
||||
std::stringstream ss;
|
||||
auto ts = duration_cast<microseconds>(high_resolution_clock::now().time_since_epoch());
|
||||
ss << std::this_thread::get_id() << "_" << ts.count() << "_" << baseName;
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
for (auto & file : files_to_remove) {
|
||||
std::remove(file.c_str());
|
||||
@ -30,7 +41,7 @@ class GNAAOTTests : public GNATest<>{
|
||||
|
||||
TEST_F(GNAAOTTests, DISABLED_AffineWith2AffineOutputs_canbe_export_imported) {
|
||||
|
||||
const std::string X = registerFileForRemove("unit_tests.bin");
|
||||
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
|
||||
|
||||
// running export to a file
|
||||
export_network(AffineWith2AffineOutputsModel())
|
||||
@ -52,7 +63,7 @@ TEST_F(GNAAOTTests, DISABLED_AffineWith2AffineOutputs_canbe_imported_verify_stru
|
||||
save_args().onInferModel(AffineWith2AffineOutputsModel())
|
||||
.inNotCompactMode().withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f).from().gna().propagate_forward().to(&nnet_type);
|
||||
|
||||
const std::string X = registerFileForRemove("unit_tests.bin");
|
||||
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
|
||||
|
||||
// running export to a file
|
||||
export_network(AffineWith2AffineOutputsModel())
|
||||
@ -70,7 +81,7 @@ TEST_F(GNAAOTTests, TwoInputsModel_canbe_export_imported) {
|
||||
GTEST_SKIP();
|
||||
#endif
|
||||
|
||||
const std::string X = registerFileForRemove("unit_tests.bin");
|
||||
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
|
||||
|
||||
// running export to a file
|
||||
export_network(TwoInputsModelForIO())
|
||||
@ -90,7 +101,7 @@ TEST_F(GNAAOTTests, PermuteModel_canbe_export_imported) {
|
||||
GTEST_SKIP();
|
||||
#endif
|
||||
|
||||
const std::string X = registerFileForRemove("unit_tests.bin");
|
||||
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
|
||||
|
||||
// running export to a file
|
||||
export_network(PermuteModelForIO())
|
||||
@ -107,7 +118,7 @@ TEST_F(GNAAOTTests, PoolingModel_canbe_export_imported) {
|
||||
GTEST_SKIP();
|
||||
#endif
|
||||
|
||||
const std::string X = registerFileForRemove("unit_tests.bin");
|
||||
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
|
||||
|
||||
// running export to a file
|
||||
export_network(maxpoolAfterRelu())
|
||||
@ -127,7 +138,7 @@ TEST_F(GNAAOTTests, DISABLED_CanConvertFromAOTtoSueModel) {
|
||||
.inNotCompactMode().inNotCompactMode().withGNAConfig(GNA_CONFIG_KEY(SCALE_FACTOR), 1.0f)
|
||||
.from().gna().propagate_forward().to(&nnet_type);
|
||||
|
||||
const std::string X = registerFileForRemove("unit_tests.bin");
|
||||
const std::string X = registerFileForRemove(generateFileName("unit_tests.bin"));
|
||||
|
||||
// running export to a file
|
||||
export_network(AffineWith2AffineOutputsModel())
|
||||
|
@ -337,7 +337,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
in_port = len(Node(graph, node_name).in_nodes())
|
||||
|
||||
Node(graph, node_name).add_input_port(in_port)
|
||||
Node(graph, in_node_id).add_output_port(out_port)
|
||||
Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
|
||||
|
||||
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
|
||||
elif tokens[0] == b'output-node':
|
||||
@ -528,7 +528,7 @@ def parse_specifier(string, graph, layer_node_map):
|
||||
const_node = Const(graph, {'name': scale_const_name, 'value': float_array([scale_value])}).create_node()
|
||||
|
||||
node = Node(graph, node_name)
|
||||
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_name.id, const_node.id))
|
||||
graph.create_edge(const_node, scale_node, 0, 0, create_edge_attrs(const_node.id, scale_node.id, const_node.id))
|
||||
out_port = len(node.out_nodes())
|
||||
graph.create_edge(node, scale_node, out_port, 1, create_edge_attrs(node_name, scale_node.id, node_name, 1, out_port))
|
||||
else:
|
||||
|
@ -203,3 +203,31 @@ class TestKaldiModelsLoading(unittest.TestCase):
|
||||
)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'tdnn1.relu')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_component_map_loading_scale(self):
|
||||
test_map = "input-node name=input dim=16\n" + \
|
||||
"component-node name=lda component=lda input=Scale(0.1, input)\n" + \
|
||||
"\n"
|
||||
graph = Graph(name="test_graph_component_map_loading_scale")
|
||||
|
||||
test_top_map = load_topology_map(io.BytesIO(bytes(test_map, 'ascii')), graph)
|
||||
|
||||
ref_map = {b"lda": ["lda"]}
|
||||
self.assertEqual(test_top_map, ref_map)
|
||||
self.assertTrue("input" in graph.nodes())
|
||||
self.assertListEqual(list(Node(graph, 'input')['shape']), [1, 16])
|
||||
|
||||
ref_graph = build_graph({'input': {'shape': np.array([1, 16]), 'kind': 'op', 'op': 'Parameter'},
|
||||
'lda': {'kind': 'op'},
|
||||
'mul': {'kind': 'op'},
|
||||
'scale_const': {'kind': 'op', 'op': 'Const'},
|
||||
},
|
||||
[
|
||||
('input', 'mul', {'in': 0}),
|
||||
('scale_const', 'mul', {'in': 1}),
|
||||
('mul', 'lda', {'out': 0}),
|
||||
]
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'lda')
|
||||
self.assertTrue(flag, resp)
|
||||
|
@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <map>
|
||||
|
||||
#include "ngraph/descriptor/tensor.hpp"
|
||||
#include "ngraph/partial_shape.hpp"
|
||||
@ -23,6 +24,8 @@ namespace ngraph
|
||||
{
|
||||
};
|
||||
|
||||
class Variant;
|
||||
|
||||
/// \brief A handle for one of a node's inputs.
|
||||
template <>
|
||||
class NGRAPH_API Input<Node>
|
||||
@ -58,6 +61,12 @@ namespace ngraph
|
||||
/// \param new_source_output A handle for the output that will replace this input's source.
|
||||
void replace_source_output(const Output<Node>& new_source_output) const;
|
||||
|
||||
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
||||
/// \return The reference to runtime info map
|
||||
RTMap& get_rt_info();
|
||||
/// \return The constant reference to runtime info map
|
||||
const RTMap& get_rt_info() const;
|
||||
|
||||
bool operator==(const Input& other) const;
|
||||
bool operator!=(const Input& other) const;
|
||||
bool operator<(const Input& other) const;
|
||||
@ -101,6 +110,10 @@ namespace ngraph
|
||||
/// \return true if this input is relevant to its node's output values; else false.
|
||||
bool get_is_relevant_to_values() const;
|
||||
|
||||
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
||||
/// \return The constant reference to runtime info map
|
||||
const RTMap& get_rt_info() const;
|
||||
|
||||
bool operator==(const Input& other) const;
|
||||
bool operator!=(const Input& other) const;
|
||||
bool operator<(const Input& other) const;
|
||||
|
@ -219,11 +219,29 @@ namespace ngraph
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
||||
protected:
|
||||
bool apply_matcher_passes(std::shared_ptr<Function> f,
|
||||
std::deque<std::shared_ptr<Node>> nodes_to_run);
|
||||
|
||||
bool m_enable_shape_inference = false;
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::pass::MatcherPass>> m_matchers;
|
||||
};
|
||||
|
||||
class NGRAPH_API BackwardGraphRewrite : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
BackwardGraphRewrite() = default;
|
||||
|
||||
explicit BackwardGraphRewrite(const std::shared_ptr<MatcherPass>& pass)
|
||||
: GraphRewrite(pass)
|
||||
{
|
||||
}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
|
||||
};
|
||||
|
||||
class NGRAPH_API RecurrentGraphRewrite : public ngraph::pass::FunctionPass
|
||||
{
|
||||
public:
|
||||
|
@ -82,6 +82,20 @@ namespace ngraph
|
||||
{
|
||||
}
|
||||
|
||||
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
||||
|
||||
RTMap& Input<Node>::get_rt_info() { return m_node->m_outputs.at(m_index).get_rt_info(); }
|
||||
|
||||
const RTMap& Input<Node>::get_rt_info() const
|
||||
{
|
||||
return m_node->m_outputs.at(m_index).get_rt_info();
|
||||
}
|
||||
|
||||
const RTMap& Input<const Node>::get_rt_info() const
|
||||
{
|
||||
return m_node->m_outputs.at(m_index).get_rt_info();
|
||||
}
|
||||
|
||||
const Node* Input<const Node>::get_node() const { return m_node; }
|
||||
size_t Input<const Node>::get_index() const { return m_index; }
|
||||
const element::Type& Input<const Node>::get_element_type() const
|
||||
|
@ -54,6 +54,8 @@ using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::GraphRewrite, "ngraph::pass::GraphRewrite", 0);
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BackwardGraphRewrite, "ngraph::pass::BackwardGraphRewrite", 0);
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatcherPass, "ngraph::pass::MatcherPass", 0);
|
||||
|
||||
namespace ngraph
|
||||
@ -71,19 +73,35 @@ namespace ngraph
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
bool pass::BackwardGraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
|
||||
{
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
|
||||
|
||||
bool rewritten = false;
|
||||
const auto& pass_config = get_pass_config();
|
||||
// Initialize execution queue with nodes in topological order
|
||||
deque<std::shared_ptr<Node>> nodes_to_run;
|
||||
for (auto& node : f->get_ordered_ops())
|
||||
{
|
||||
nodes_to_run.emplace_front(node);
|
||||
}
|
||||
return apply_matcher_passes(f, std::move(nodes_to_run));
|
||||
}
|
||||
|
||||
bool pass::GraphRewrite::run_on_function(std::shared_ptr<ngraph::Function> f)
|
||||
{
|
||||
// Initialize execution queue with nodes in topological order
|
||||
deque<std::shared_ptr<Node>> nodes_to_run;
|
||||
for (auto& node : f->get_ordered_ops())
|
||||
{
|
||||
nodes_to_run.emplace_back(node);
|
||||
}
|
||||
return apply_matcher_passes(f, std::move(nodes_to_run));
|
||||
}
|
||||
|
||||
bool pass::GraphRewrite::apply_matcher_passes(shared_ptr<Function> f,
|
||||
deque<std::shared_ptr<Node>> nodes_to_run)
|
||||
{
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::GraphRewrite::run_on_function");
|
||||
|
||||
bool rewritten = false;
|
||||
const auto& pass_config = get_pass_config();
|
||||
|
||||
// Check that all Matchers in MatcherPasses has type bases root node
|
||||
bool all_roots_has_type = true;
|
||||
|
@ -39,6 +39,23 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class GatherNodesPass : public ngraph::pass::MatcherPass
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
GatherNodesPass(NodeVector & order)
|
||||
: MatcherPass()
|
||||
{
|
||||
ngraph::matcher_pass_callback callback = [&order](pattern::Matcher& m) {
|
||||
order.push_back(m.get_match_root());
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::any_input(), "GatherNodesPass");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
};
|
||||
|
||||
class Anchor : public ngraph::pass::GraphRewrite
|
||||
{
|
||||
public:
|
||||
@ -51,6 +68,7 @@ public:
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(TestPass, "TestPass", 0);
|
||||
NGRAPH_RTTI_DEFINITION(Anchor, "Anchor", 0);
|
||||
NGRAPH_RTTI_DEFINITION(GatherNodesPass, "GatherNodesPass", 0);
|
||||
|
||||
std::shared_ptr<Function> get_function()
|
||||
{
|
||||
@ -77,6 +95,34 @@ ngraph::pass::param_callback get_callback()
|
||||
};
|
||||
}
|
||||
|
||||
TEST(GraphRewriteOrderTest, MatcherPass)
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
NodeVector order;
|
||||
ngraph::pass::Manager m;
|
||||
auto pass = m.register_pass<pass::GraphRewrite>();
|
||||
pass->add_matcher<GatherNodesPass>(order);
|
||||
m.run_passes(f);
|
||||
|
||||
ASSERT_EQ(order, f->get_ordered_ops());
|
||||
}
|
||||
|
||||
TEST(BackwardGraphRewriteOrderTest, MatcherPass)
|
||||
{
|
||||
auto f = get_function();
|
||||
|
||||
NodeVector order;
|
||||
ngraph::pass::Manager m;
|
||||
auto pass = m.register_pass<pass::BackwardGraphRewrite>();
|
||||
pass->add_matcher<GatherNodesPass>(order);
|
||||
m.run_passes(f);
|
||||
|
||||
auto ref_order = f->get_ordered_ops();
|
||||
std::reverse(ref_order.begin(), ref_order.end());
|
||||
ASSERT_EQ(order, ref_order);
|
||||
}
|
||||
|
||||
TEST(GraphRewriteTest, MatcherPassCallback)
|
||||
{
|
||||
auto f = get_function();
|
||||
|
@ -124,9 +124,25 @@ TEST(op, variant)
|
||||
EXPECT_EQ(ship.y, 4);
|
||||
|
||||
auto node = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
// Check Node RTInfo
|
||||
node->get_rt_info()["A"] = var_ship;
|
||||
auto node_var_ship = node->get_rt_info().at("A");
|
||||
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_var_ship)));
|
||||
Ship& node_ship = as_type_ptr<VariantWrapper<Ship>>(node_var_ship)->get();
|
||||
EXPECT_EQ(&node_ship, &ship);
|
||||
|
||||
// Check Node Input<Node> RTInfo
|
||||
auto relu = make_shared<op::Relu>(node);
|
||||
relu->input(0).get_rt_info()["A"] = var_ship;
|
||||
auto node_input_var_ship = node->get_rt_info().at("A");
|
||||
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_input_var_ship)));
|
||||
Ship& node_input_ship = as_type_ptr<VariantWrapper<Ship>>(node_input_var_ship)->get();
|
||||
EXPECT_EQ(&node_input_ship, &ship);
|
||||
|
||||
// Check Node Input<Node> RTInfo
|
||||
node->output(0).get_rt_info()["A"] = var_ship;
|
||||
auto node_output_var_ship = node->get_rt_info().at("A");
|
||||
ASSERT_TRUE((is_type<VariantWrapper<Ship>>(node_output_var_ship)));
|
||||
Ship& node_output_ship = as_type_ptr<VariantWrapper<Ship>>(node_input_var_ship)->get();
|
||||
EXPECT_EQ(&node_output_ship, &ship);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user