Dynamism in new API (#7742)

* Fixed precisions conversion in new API

* Added tests

* Fixed old IR cases

* Disable FP16

* Fixed regex for CentOS

* Refactored tests to use new API

* Temp

* Fixed tests

* Moved smart reshape related sources to ngraph

* Added tests for invalid names

* Moved reshape to tensor_names

* clang-format

* Fixed CC build

* Removed IEConv, IEDeconv from primitives priority
This commit is contained in:
Ilya Lavrenov 2021-09-30 15:04:24 +03:00 committed by GitHub
parent bd5b1bf99f
commit 757db35528
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 958 additions and 716 deletions

View File

@ -355,11 +355,7 @@ StatusCode CNNNetworkNGraphImpl::reshape(const std::map<std::string, ngraph::Par
ssr_manager.register_pass<ngraph::pass::SmartReshape>(); ssr_manager.register_pass<ngraph::pass::SmartReshape>();
ssr_manager.run_passes(_ngraph_function); ssr_manager.run_passes(_ngraph_function);
std::map<std::string, ngraph::PartialShape> reshapeShapes; reshape(inputShapes);
for (const auto& item : inputShapes) {
reshapeShapes[item.first] = ngraph::PartialShape(item.second);
}
reshape(reshapeShapes);
} catch (std::exception& ex) { } catch (std::exception& ex) {
reshape(originalInputShapes); reshape(originalInputShapes);
return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what(); return DescriptionBuffer(GENERAL_ERROR, responseDesc) << ex.what();
@ -368,7 +364,7 @@ StatusCode CNNNetworkNGraphImpl::reshape(const std::map<std::string, ngraph::Par
return OK; return OK;
} }
StatusCode CNNNetworkNGraphImpl::reshape(const std::map<std::string, std::vector<size_t>>& inputShapes, StatusCode CNNNetworkNGraphImpl::reshape(const std::map<std::string, SizeVector>& inputShapes,
ResponseDesc* responseDesc) noexcept { ResponseDesc* responseDesc) noexcept {
std::map<std::string, ngraph::PartialShape> shapes; std::map<std::string, ngraph::PartialShape> shapes;
for (const auto& shape : inputShapes) for (const auto& shape : inputShapes)

View File

@ -73,8 +73,7 @@ public:
virtual void validate(int = 10); virtual void validate(int = 10);
StatusCode reshape(const std::map<std::string, std::vector<size_t>>& inputShapes, StatusCode reshape(const std::map<std::string, SizeVector>& inputShapes, ResponseDesc* resp) noexcept override;
ResponseDesc* resp) noexcept override;
StatusCode reshape(const std::map<std::string, ngraph::PartialShape>& inputShapes, StatusCode reshape(const std::map<std::string, ngraph::PartialShape>& inputShapes,
ResponseDesc* resp) noexcept override; ResponseDesc* resp) noexcept override;

View File

@ -1298,22 +1298,22 @@ std::map<std::string, Version> Core::get_versions(const std::string& deviceName)
} }
#ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT #ifdef OPENVINO_ENABLE_UNICODE_PATH_SUPPORT
std::shared_ptr<ngraph::Function> Core::read_model(const std::wstring& modelPath, const std::wstring& binPath) const { std::shared_ptr<ov::Function> Core::read_model(const std::wstring& modelPath, const std::wstring& binPath) const {
OV_CORE_CALL_STATEMENT( OV_CORE_CALL_STATEMENT(
return _impl->ReadNetwork(ov::util::wstring_to_string(modelPath), ov::util::wstring_to_string(binPath)) return _impl->ReadNetwork(ov::util::wstring_to_string(modelPath), ov::util::wstring_to_string(binPath))
.getFunction();); .getFunction(););
} }
#endif #endif
std::shared_ptr<ngraph::Function> Core::read_model(const std::string& modelPath, const std::string& binPath) const { std::shared_ptr<ov::Function> Core::read_model(const std::string& modelPath, const std::string& binPath) const {
OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(modelPath, binPath).getFunction();); OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(modelPath, binPath).getFunction(););
} }
std::shared_ptr<ngraph::Function> Core::read_model(const std::string& model, const ie::Blob::CPtr& weights) const { std::shared_ptr<ov::Function> Core::read_model(const std::string& model, const ie::Blob::CPtr& weights) const {
OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(model, weights).getFunction();); OV_CORE_CALL_STATEMENT(return _impl->ReadNetwork(model, weights).getFunction(););
} }
ExecutableNetwork Core::compile_model(const std::shared_ptr<const ngraph::Function>& model, ExecutableNetwork Core::compile_model(const std::shared_ptr<const ov::Function>& model,
const std::string& deviceName, const std::string& deviceName,
const ConfigMap& config) { const ConfigMap& config) {
OV_CORE_CALL_STATEMENT({ OV_CORE_CALL_STATEMENT({
@ -1333,7 +1333,7 @@ ExecutableNetwork Core::compile_model(const std::string& modelPath,
}); });
} }
ExecutableNetwork Core::compile_model(const std::shared_ptr<const ngraph::Function>& model, ExecutableNetwork Core::compile_model(const std::shared_ptr<const ov::Function>& model,
const RemoteContext& context, const RemoteContext& context,
const ConfigMap& config) { const ConfigMap& config) {
OV_CORE_CALL_STATEMENT({ OV_CORE_CALL_STATEMENT({
@ -1382,7 +1382,7 @@ ExecutableNetwork Core::import_model(std::istream& modelStream, const RemoteCont
}); });
} }
SupportedOpsMap Core::query_model(const std::shared_ptr<const ngraph::Function>& model, SupportedOpsMap Core::query_model(const std::shared_ptr<const ov::Function>& model,
const std::string& deviceName, const std::string& deviceName,
const ConfigMap& config) const { const ConfigMap& config) const {
OV_CORE_CALL_STATEMENT({ OV_CORE_CALL_STATEMENT({

View File

@ -19,8 +19,6 @@
#include "multi_device_plugin.hpp" #include "multi_device_plugin.hpp"
#include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset1.hpp"
#include "ngraph_ops/convolution_ie.hpp"
#include "ngraph_ops/deconvolution_ie.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
// ------------------------------MultiDeviceExecutableNetwork---------------------------- // ------------------------------MultiDeviceExecutableNetwork----------------------------
@ -38,9 +36,7 @@ std::string GetNetworkPrecision(const InferenceEngine::CNNNetwork &network) {
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) || if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node) || std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node)) {
std::dynamic_pointer_cast<ngraph::op::ConvolutionIE>(node) ||
std::dynamic_pointer_cast<ngraph::op::DeconvolutionIE>(node)) {
auto layerType = node->input(1).get_element_type().get_type_name(); auto layerType = node->input(1).get_element_type().get_type_name();
if (layerType == "f32") if (layerType == "f32")
return METRIC_VALUE(FP32); return METRIC_VALUE(FP32);

View File

@ -12,8 +12,6 @@
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <transformations/utils/utils.hpp> #include <transformations/utils/utils.hpp>
#include "ngraph_ops/convolution_ie.hpp"
#include "ngraph_ops/deconvolution_ie.hpp"
#include <ie_metric_helpers.hpp> #include <ie_metric_helpers.hpp>
#include <ie_performance_hints.hpp> #include <ie_performance_hints.hpp>
@ -37,9 +35,7 @@ namespace {
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) || if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node) || std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node)) {
std::dynamic_pointer_cast<ngraph::op::ConvolutionIE>(node) ||
std::dynamic_pointer_cast<ngraph::op::DeconvolutionIE>(node)) {
auto layerType = node->input(1).get_element_type().get_type_name(); auto layerType = node->input(1).get_element_type().get_type_name();
if (layerType == "f32") if (layerType == "f32")
return METRIC_VALUE(FP32); return METRIC_VALUE(FP32);

View File

@ -12,8 +12,6 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
/** /**
@ -26,7 +24,7 @@ namespace ngraph {
*/ */
namespace pass { namespace pass {
class TRANSFORMATIONS_API InitNodeInfo; class NGRAPH_API InitNodeInfo;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -19,7 +19,6 @@
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
#include <openvino/core/rtti.hpp> #include <openvino/core/rtti.hpp>
#include <ngraph/attribute_visitor.hpp> #include <ngraph/attribute_visitor.hpp>
#include <transformations_visibility.hpp>
namespace ngraph { namespace ngraph {
@ -29,7 +28,7 @@ namespace ngraph {
* @brief FusedName class represents runtime info attribute that stores * @brief FusedName class represents runtime info attribute that stores
* all operation names that was fully or partially fused into node * all operation names that was fully or partially fused into node
*/ */
class TRANSFORMATIONS_API FusedNames { class NGRAPH_API FusedNames {
private: private:
std::set<std::string> fused_names; std::set<std::string> fused_names;
@ -72,7 +71,7 @@ public:
* @brief getFusedNames return string with operation names separated by coma in alphabetical order * @brief getFusedNames return string with operation names separated by coma in alphabetical order
* @param[in] node The node will be used to get FusedNames attribute * @param[in] node The node will be used to get FusedNames attribute
*/ */
TRANSFORMATIONS_API std::string getFusedNames(const std::shared_ptr<ngraph::Node> & node); NGRAPH_API std::string getFusedNames(const std::shared_ptr<ngraph::Node> & node);
/** /**
* @ingroup ie_runtime_attr_api * @ingroup ie_runtime_attr_api
@ -80,16 +79,16 @@ TRANSFORMATIONS_API std::string getFusedNames(const std::shared_ptr<ngraph::Node
* @param[in] node The node will be used to get FusedNames attribute * @param[in] node The node will be used to get FusedNames attribute
* @return vector of strings * @return vector of strings
*/ */
TRANSFORMATIONS_API std::vector<std::string> getFusedNamesVector(const std::shared_ptr<ngraph::Node> & node); NGRAPH_API std::vector<std::string> getFusedNamesVector(const std::shared_ptr<ngraph::Node> & node);
} // namespace ngraph } // namespace ngraph
namespace ov { namespace ov {
extern template class TRANSFORMATIONS_API VariantImpl<ngraph::FusedNames>; extern template class NGRAPH_API VariantImpl<ngraph::FusedNames>;
template<> template<>
class TRANSFORMATIONS_API VariantWrapper<ngraph::FusedNames> : public VariantImpl<ngraph::FusedNames> { class NGRAPH_API VariantWrapper<ngraph::FusedNames> : public VariantImpl<ngraph::FusedNames> {
public: public:
OPENVINO_RTTI("fused_names", "0"); OPENVINO_RTTI("fused_names", "0");
@ -105,7 +104,7 @@ public:
}; };
template <> template <>
class TRANSFORMATIONS_API AttributeAdapter<std::set<std::string>> : public DirectValueAccessor<std::set<std::string>> { class NGRAPH_API AttributeAdapter<std::set<std::string>> : public DirectValueAccessor<std::set<std::string>> {
public: public:
OPENVINO_RTTI("AttributeAdapter<set<string>>"); OPENVINO_RTTI("AttributeAdapter<set<string>>");
AttributeAdapter(std::set<std::string>& value) : DirectValueAccessor<std::set<std::string>>(value) {} AttributeAdapter(std::set<std::string>& value) : DirectValueAccessor<std::set<std::string>>(value) {}

View File

@ -17,7 +17,6 @@
#include <ngraph/node.hpp> #include <ngraph/node.hpp>
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
#include <transformations_visibility.hpp>
namespace ov { namespace ov {
@ -26,7 +25,7 @@ namespace ov {
* @brief PrimitivesPriority class represents runtime info attribute that * @brief PrimitivesPriority class represents runtime info attribute that
* can be used for plugins specific primitive choice. * can be used for plugins specific primitive choice.
*/ */
class TRANSFORMATIONS_API PrimitivesPriority { class NGRAPH_API PrimitivesPriority {
private: private:
std::string primitives_priority; std::string primitives_priority;
@ -53,12 +52,12 @@ public:
* @brief getPrimitivesPriority return string with primitive priorities value * @brief getPrimitivesPriority return string with primitive priorities value
* @param[in] node The node will be used to get PrimitivesPriority attribute * @param[in] node The node will be used to get PrimitivesPriority attribute
*/ */
TRANSFORMATIONS_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node); NGRAPH_API std::string getPrimitivesPriority(const std::shared_ptr<ngraph::Node> & node);
extern template class TRANSFORMATIONS_API VariantImpl<PrimitivesPriority>; extern template class NGRAPH_API VariantImpl<PrimitivesPriority>;
template<> template<>
class TRANSFORMATIONS_API VariantWrapper<PrimitivesPriority> : public VariantImpl<PrimitivesPriority> { class NGRAPH_API VariantWrapper<PrimitivesPriority> : public VariantImpl<PrimitivesPriority> {
public: public:
OPENVINO_RTTI("primitives_priority", "0"); OPENVINO_RTTI("primitives_priority", "0");

View File

@ -7,16 +7,14 @@
#include <memory> #include <memory>
#include <functional> #include <functional>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API ReshapeAMatMul; class NGRAPH_API ReshapeAMatMul;
class TRANSFORMATIONS_API ReshapeBMatMul; class NGRAPH_API ReshapeBMatMul;
class TRANSFORMATIONS_API TransposeMatMul; class NGRAPH_API TransposeMatMul;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -8,21 +8,12 @@
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API MimicSetBatchSize; class NGRAPH_API MimicSetBatchSize;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -7,15 +7,13 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API Proposal1Scales; class NGRAPH_API Proposal1Scales;
class TRANSFORMATIONS_API Proposal4Scales; class NGRAPH_API Proposal4Scales;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -7,14 +7,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API ReshapeTo1D; class NGRAPH_API ReshapeTo1D;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -7,15 +7,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API SetBatchSize; class NGRAPH_API SetBatchSize;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -7,15 +7,12 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API SmartReshape; class NGRAPH_API SmartReshape;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -7,16 +7,14 @@
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph { namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API StridedSliceSqueeze; class NGRAPH_API StridedSliceSqueeze;
class TRANSFORMATIONS_API SqueezeStridedSlice; class NGRAPH_API SqueezeStridedSlice;
class TRANSFORMATIONS_API SharedSqueeze; class NGRAPH_API SharedSqueeze;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph

View File

@ -1,61 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MimicSetBatchSize, "MimicSetBatchSize", 0);
// MimicSetBatchSize: rewrites static Reshape patterns so that the 0-th ("batch")
// component of the target shape is computed at runtime from the actual input shape,
// scaled by the out/in ratio observed on a constant-folded copy of the function.
// Returns true iff at least one Reshape was rewritten.
bool ngraph::pass::MimicSetBatchSize::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(MimicSetBatchSize);
// extracting ratio of out to in 0-index dimension value from the folded function;
// constant folding is run on a clone so that the analysis never mutates `f`.
auto specialized_function = ngraph::clone_function(*f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(specialized_function);
// Reshape friendly name -> ratio output_dim[0] / input_dim[0] observed on the clone.
std::map<std::string, float> scale;
for (const auto & node : specialized_function->get_ops()) {
if (const auto & reshape = std::dynamic_pointer_cast<opset5::Reshape>(node)) {
const auto in_pshape = reshape->get_input_partial_shape(0), out_pshape = reshape->get_output_partial_shape(0);
// Skip reshapes whose rank or 0-th dimension is dynamic, or that are scalar/1D,
// on either side — the ratio cannot be computed for those.
if (in_pshape.rank().is_dynamic() || in_pshape.rank().get_length() <= 1 || in_pshape[0].is_dynamic() ||
out_pshape.rank().is_dynamic() || out_pshape.rank().get_length() <= 1 || out_pshape[0].is_dynamic())
continue;
const auto & pattern = std::dynamic_pointer_cast<opset5::Constant>(reshape->get_input_node_shared_ptr(1));
// Only consider patterns with an explicit positive 0-th value; non-positive
// values carry special Reshape semantics (presumably 0 = copy, -1 = infer —
// see the Reshape op spec) and need no rewriting.
if (pattern && pattern->cast_vector<int64_t>()[0] > 0) {
scale[reshape->get_friendly_name()] = static_cast<float>(out_pshape[0].get_length()) / static_cast<float>(in_pshape[0].get_length());
}
}
}
// apply transformation to original function: candidates are matched back to the
// original graph by friendly name.
bool transformed = false;
for (auto & reshape : f->get_ops()) {
if (!is_type<opset5::Reshape>(reshape) || !scale.count(reshape->get_friendly_name()) || reshape->get_output_partial_shape(0).rank().is_dynamic())
continue;
// Runtime batch of the data input: Gather(ShapeOf(data), 0), built in the same
// element type as the original reshape pattern.
const auto & shape_of = std::make_shared<opset5::ShapeOf>(reshape->get_input_source_output(0), reshape->get_input_element_type(1));
const auto & new_input_batch = std::make_shared<ngraph::opset5::Gather>(
shape_of, ngraph::opset5::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0}),
ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
// New output batch = ceil(input_batch * scale); the multiply/ceil are done in f32
// and the result is converted back to the pattern's element type.
const std::shared_ptr<Node> & new_output_batch = std::make_shared<opset5::Convert>(
std::make_shared<opset5::Ceiling>(
std::make_shared<opset5::Multiply>(
std::make_shared<opset5::Convert>(new_input_batch, element::f32),
opset5::Constant::create(element::f32, {1}, {scale[reshape->get_friendly_name()]}))),
reshape->get_input_element_type(1));
// Keep the non-batch components (indices 1..rank-1) of the original pattern as-is.
std::vector<int64_t> non_batch_dims(reshape->get_output_partial_shape(0).rank().get_length() - 1);
std::iota(non_batch_dims.begin(), non_batch_dims.end(), 1);
const auto & non_batch_dims_node = std::make_shared<ngraph::opset5::Gather>(
reshape->input_value(1),
ngraph::opset5::Constant::create(ngraph::element::i64, {non_batch_dims.size()}, non_batch_dims),
ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
// Splice the new pattern (computed batch ++ original tail) into the Reshape.
auto new_reshape_pattern = std::make_shared<opset5::Concat>(OutputVector{new_output_batch, non_batch_dims_node}, 0);
reshape->input(1).replace_source_output(new_reshape_pattern->output(0));
transformed = true;
}
return transformed;
}

View File

@ -1,68 +0,0 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "itt.hpp"
#include <transformations/smart_reshape/proposal_scales_stridedslice.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
// Shared callback for the Proposal{1,4}Scales matchers: crops the Proposal node's
// 3rd input (image-info / "scales") down to its leading elements with a StridedSlice.
// The slice length is the static 1-st dimension of the matched Parameter (3 or 4 —
// enforced by the matcher predicates). `pattern_to_output` must hold entries for both
// labels. Always returns true, signalling to the matcher that the graph was modified.
bool crop_scales_for_proposal(const ngraph::pattern::PatternValueMap & pattern_to_output,
std::shared_ptr<ngraph::Node> parameter_label, std::shared_ptr<ngraph::Node> proposal_label) {
const auto & parameter = pattern_to_output.at(parameter_label);
const auto & proposal = pattern_to_output.at(proposal_label).get_node_shared_ptr();
// Slice [0 : parameter.shape[1] : 1] along axis 0; empty begin/end masks ({0})
// mean both bounds are taken from the constants, not from the tensor edges.
auto cropped_scales = std::make_shared<ngraph::opset5::StridedSlice>(
proposal->input_value(2),
ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}),
ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {parameter.get_partial_shape()[1].get_length()}),
ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}),
std::vector<int64_t>{0}, std::vector<int64_t>{0});
proposal->input(2).replace_source_output(cropped_scales->output(0));
return true;
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::Proposal1Scales, "Proposal1Scales", 0);
// Matcher pass: Parameter(rank-2, dim[1] == 3 or 4) -> Reshape(to rank-1) feeding the
// 3rd input of an opset1::Proposal. On match, crops the scales input via
// crop_scales_for_proposal so Proposal receives exactly dim[1] image-info values.
ngraph::pass::Proposal1Scales::Proposal1Scales() {
MATCHER_SCOPE(Proposal1Scales);
// Parameter must be 2D with a static second dimension of 3 or 4 (image-info layout).
auto parameter_label = ngraph::pattern::wrap_type<opset5::Parameter>([](const Output<Node> &output) {
const auto & shape = output.get_partial_shape();
return shape.rank().is_static() && shape.rank().get_length() == 2 && shape[1].is_static() && (shape[1].get_length() == 3 || shape[1].get_length() == 4);
});
// The parameter is flattened to 1D by a Reshape with a constant pattern.
auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>({parameter_label, ngraph::pattern::wrap_type<opset5::Constant>()},
[](const Output<Node> &output) { return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() == 1; });
auto proposal_label = ngraph::pattern::wrap_type<opset1::Proposal>({pattern::any_input(), pattern::any_input(), reshape_label});
matcher_pass_callback callback = [parameter_label, proposal_label](pattern::Matcher &m) -> bool {
return crop_scales_for_proposal(m.get_pattern_value_map(), parameter_label, proposal_label);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(proposal_label, matcher_name);
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::Proposal4Scales, "Proposal4Scales", 0);
// Same pattern and rewrite as Proposal1Scales, but matching the opset4::Proposal
// variant of the operation; the crop itself is shared via crop_scales_for_proposal.
ngraph::pass::Proposal4Scales::Proposal4Scales() {
MATCHER_SCOPE(Proposal4Scales);
// Parameter must be 2D with a static second dimension of 3 or 4 (image-info layout).
auto parameter_label = ngraph::pattern::wrap_type<opset5::Parameter>([](const Output<Node> &output) {
const auto & shape = output.get_partial_shape();
return shape.rank().is_static() && shape.rank().get_length() == 2 && shape[1].is_static() && (shape[1].get_length() == 3 || shape[1].get_length() == 4);
});
// The parameter is flattened to 1D by a Reshape with a constant pattern.
auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>({parameter_label, ngraph::pattern::wrap_type<opset5::Constant>()},
[](const Output<Node> &output) { return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() == 1; });
auto proposal_label = ngraph::pattern::wrap_type<opset4::Proposal>({pattern::any_input(), pattern::any_input(), reshape_label});
matcher_pass_callback callback = [parameter_label, proposal_label](pattern::Matcher &m) -> bool {
return crop_scales_for_proposal(m.get_pattern_value_map(), parameter_label, proposal_label);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(proposal_label, matcher_name);
register_matcher(m, callback);
}

View File

@ -81,97 +81,6 @@ TEST_F(NGraphReshapeTests, ReshapedDynamicShapeLayout) {
ASSERT_FALSE(cnnNetwork.getInputsInfo()["A"]->getInputData()->isDynamic()); ASSERT_FALSE(cnnNetwork.getInputsInfo()["A"]->getInputData()->isDynamic());
} }
// Checks that replacing a Parameter with one of a larger batch (dim 0: 1 -> 2) and
// re-validating propagates the new shape through Relu to the Result.
TEST_F(NGraphReshapeTests, ReshapeBatchReLU) {
std::shared_ptr<ngraph::Function> ngraph;
{
// Build a minimal Parameter -> Relu -> Result function with a static shape.
ngraph::PartialShape shape({1, 3, 22, 22});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
auto relu = std::make_shared<ngraph::op::Relu>(param);
auto result = std::make_shared<ngraph::op::Result>(relu);
ngraph::ParameterVector params = {param};
ngraph::ResultVector results = {result};
ngraph = std::make_shared<ngraph::Function>(results, params);
}
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
{
// Swap in a fresh Parameter with batch 2 and re-run shape inference.
ngraph::PartialShape shape({2, 3, 22, 22});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
ngraph->replace_parameter(0, param);
ngraph->validate_nodes_and_infer_types();
}
// The new batch size must reach both the parameter and the result.
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({2, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({2, 3, 22, 22}));
}
// Same as ReshapeBatchReLU but resizing the spatial dimensions (22x22 -> 25x25)
// instead of the batch, again via replace_parameter + re-validation.
TEST_F(NGraphReshapeTests, ReshapeSpatialReLU) {
std::shared_ptr<ngraph::Function> ngraph;
{
// Build a minimal Parameter -> Relu -> Result function with a static shape.
ngraph::PartialShape shape({1, 3, 22, 22});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
auto relu = std::make_shared<ngraph::op::Relu>(param);
auto result = std::make_shared<ngraph::op::Result>(relu);
ngraph::ParameterVector params = {param};
ngraph::ResultVector results = {result};
ngraph = std::make_shared<ngraph::Function>(results, params);
}
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
{
// Swap in a fresh Parameter with 25x25 spatial dims and re-run shape inference.
ngraph::PartialShape shape({1, 3, 25, 25});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
ngraph->replace_parameter(0, param);
ngraph->validate_nodes_and_infer_types();
}
// The new spatial size must reach both the parameter and the result.
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
}
// Variant of ReshapeSpatialReLU that mutates the existing Parameter in place via
// set_partial_shape instead of replacing it, then re-validates shape propagation.
TEST_F(NGraphReshapeTests, ReshapeSpatialReLUWithoutReplaceParameter) {
std::shared_ptr<ngraph::Function> ngraph;
{
// Build a minimal Parameter -> Relu -> Result function with a static shape.
ngraph::PartialShape shape({1, 3, 22, 22});
ngraph::element::Type type(ngraph::element::Type_t::f32);
auto param = std::make_shared<ngraph::op::Parameter>(type, shape);
auto relu = std::make_shared<ngraph::op::Relu>(param);
auto result = std::make_shared<ngraph::op::Result>(relu);
ngraph::ParameterVector params = {param};
ngraph::ResultVector results = {result};
ngraph = std::make_shared<ngraph::Function>(results, params);
}
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 22, 22}));
{
// Resize the existing parameter directly — no replace_parameter call.
ngraph->get_parameters()[0]->set_partial_shape({1, 3, 25, 25});
ngraph->validate_nodes_and_infer_types();
}
// The new spatial size must reach both the parameter and the result.
ASSERT_EQ(ngraph->get_parameters()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
ASSERT_EQ(ngraph->get_results()[0]->get_shape(), ngraph::Shape({1, 3, 25, 25}));
}
TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) { TEST_F(NGraphReshapeTests, CNNReshapeSpatialReLU) {
std::shared_ptr<const ngraph::Function> ngraph; std::shared_ptr<const ngraph::Function> ngraph;
{ {

View File

@ -24,6 +24,7 @@ std::shared_ptr<ngraph::Function> getFunction1() {
auto params = ngraph::builder::makeParams(ngPrc, {inputShape}); auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
params.front()->set_friendly_name("Param_1"); params.front()->set_friendly_name("Param_1");
params.front()->get_output_tensor(0).set_names({"Tensor_1"});
auto in2add = ngraph::builder::makeConstant(ngPrc, {1, 4, 1, 1}, std::vector<float>{}, true); auto in2add = ngraph::builder::makeConstant(ngPrc, {1, 4, 1, 1}, std::vector<float>{}, true);
auto add = ngraph::builder::makeEltwise(params[0], in2add, ngraph::helpers::EltwiseTypes::ADD); auto add = ngraph::builder::makeEltwise(params[0], in2add, ngraph::helpers::EltwiseTypes::ADD);
@ -40,6 +41,7 @@ std::shared_ptr<ngraph::Function> getFunction2() {
auto params = ngraph::builder::makeParams(ngPrc, {inputShape}); auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
params.front()->set_friendly_name("Param_1"); params.front()->set_friendly_name("Param_1");
params.front()->get_output_tensor(0).set_names({"Tensor_1"});
auto split = ngraph::builder::makeSplit(params[0], ngPrc, 2, 1); auto split = ngraph::builder::makeSplit(params[0], ngPrc, 2, 1);
auto in2add = ngraph::builder::makeConstant(ngPrc, {1, 2, 1, 1}, std::vector<float>{}, true); auto in2add = ngraph::builder::makeConstant(ngPrc, {1, 2, 1, 1}, std::vector<float>{}, true);

View File

@ -6,16 +6,19 @@
#include <chrono> #include <chrono>
#include <future> #include <future>
#include <gtest/gtest.h>
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include <string> #include <string>
#include <memory> #include <memory>
#include "functional_test_utils/ov_plugin_cache.hpp"
#include "ie_extension.h" #include "ie_extension.h"
#include <condition_variable> #include <condition_variable>
#include "openvino/core/shape.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp" #include "shared_test_classes/base/layer_test_utils.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp" #include "ngraph_functions/utils/ngraph_helpers.hpp"
#include "ngraph_functions/builders.hpp" #include "ngraph_functions/builders.hpp"
#include <ngraph/opsets/opset6.hpp> #include "transformations/utils/utils.hpp"
#include <string> #include <string>
#include <ie_core.hpp> #include <ie_core.hpp>
#include <thread> #include <thread>
@ -30,7 +33,7 @@
namespace BehaviorTestsDefinitions { namespace BehaviorTestsDefinitions {
typedef std::tuple< typedef std::tuple<
std::shared_ptr<ngraph::Function>, // ngraph function std::shared_ptr<ov::Function>, // ov function
std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>>, // input/expected output shapes per inference std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>>, // input/expected output shapes per inference
std::string, // Device name std::string, // Device name
std::map<std::string, std::string> // Config std::map<std::string, std::string> // Config
@ -40,7 +43,7 @@ class InferRequestDynamicTests : public testing::WithParamInterface<InferRequest
public CommonTestUtils::TestsCommon { public CommonTestUtils::TestsCommon {
public: public:
static std::string getTestCaseName(testing::TestParamInfo<InferRequestDynamicParams> obj) { static std::string getTestCaseName(testing::TestParamInfo<InferRequestDynamicParams> obj) {
std::shared_ptr<ngraph::Function> func; std::shared_ptr<ov::Function> func;
std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>> inOutShapes; std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>> inOutShapes;
std::string targetDevice; std::string targetDevice;
std::map<std::string, std::string> configuration; std::map<std::string, std::string> configuration;
@ -69,276 +72,241 @@ protected:
void TearDown() override { void TearDown() override {
if (!configuration.empty()) { if (!configuration.empty()) {
PluginCache::get().reset(); ov::test::PluginCache::get().reset();
} }
function.reset(); function.reset();
} }
std::shared_ptr<InferenceEngine::Core> ie = PluginCache::get().ie(); std::shared_ptr<ov::runtime::Core> ie = ov::test::PluginCache::get().core();
std::shared_ptr<ngraph::Function> function; std::shared_ptr<ov::Function> function;
std::string targetDevice; std::string targetDevice;
std::map<std::string, std::string> configuration; std::map<std::string, std::string> configuration;
std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>> inOutShapes; std::vector<std::pair<std::vector<size_t>, std::vector<size_t>>> inOutShapes;
}; };
TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithoutSetShape) { TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithoutSetShape) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension::dynamic(), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension::dynamic(), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
} }
TEST_P(InferRequestDynamicTests, InferDynamicNetworkBoundWithoutSetShape) { TEST_P(InferRequestDynamicTests, InferDynamicNetworkBoundWithoutSetShape) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension(0, 5), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension(0, 5), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
} }
TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithGetBlob) { TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithGetTensor) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension::dynamic(), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension::dynamic(), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
//ASSERT_NO_THROW(req.SetShape(param_name, {1, 4, 20, 20})); //ASSERT_NO_THROW(req.SetShape(tensor_name, {1, 4, 20, 20}));
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape({1, 4, 20, 20})); ASSERT_NO_THROW(tensor.set_shape({1, 4, 20, 20}));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_EQ(tensor.get_shape(), refShape);
req.Infer(); ASSERT_NO_THROW(req.infer());
req.StartAsync(); ASSERT_NO_THROW(req.start_async());
InferenceEngine::StatusCode sts; req.wait();
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts); ASSERT_EQ(tensor.get_shape(), refOutShape);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape);
} }
TEST_P(InferRequestDynamicTests, InferUpperBoundNetworkWithGetBlob) { TEST_P(InferRequestDynamicTests, InferUpperBoundNetworkWithGetTensor) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension(0, 19), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension(0, 19), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
//ASSERT_NO_THROW(req.SetShape(param_name, {1, 4, 20, 20})); //ASSERT_NO_THROW(req.SetShape(tensor_name, {1, 4, 20, 20}));
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape({1, 4, 20, 20})); ASSERT_NO_THROW(tensor.set_shape({1, 4, 20, 20}));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_EQ(tensor.get_shape(), refShape);
req.Infer(); ASSERT_NO_THROW(req.infer());
req.StartAsync(); ASSERT_NO_THROW(req.start_async());
InferenceEngine::StatusCode sts; req.wait();
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts); ASSERT_EQ(tensor.get_shape(), refOutShape);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape);
} }
TEST_P(InferRequestDynamicTests, InferOutOfRangeShapeNetworkWithGetBlobLower) { TEST_P(InferRequestDynamicTests, InferOutOfRangeShapeNetworkWithGetTensorLower) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension(2, 3), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension(2, 3), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape({1, 4, 20, 20})); ASSERT_NO_THROW(tensor.set_shape({1, 4, 20, 20}));
// Plugin may or may not throw in case if input tensor has dimensions that are out of bounds // Plugin may or may not throw in case if input tensor has dimensions that are out of bounds
//ASSERT_THROW(req.Infer(), InferenceEngine::Exception); //ASSERT_THROW(req.infer(), ov::Exception);
} }
TEST_P(InferRequestDynamicTests, InferOutOfRangeShapeNetworkWithGetBlobUpper) { TEST_P(InferRequestDynamicTests, InferOutOfRangeShapeNetworkWithGetTensorUpper) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension(1, 2), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension(1, 2), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape({3, 4, 20, 20})); ASSERT_NO_THROW(tensor.set_shape({3, 4, 20, 20}));
// Plugin may or may not throw in case if input tensor has dimensions that are out of bounds // Plugin may or may not throw in case if input tensor has dimensions that are out of bounds
// ASSERT_THROW(req.Infer(), InferenceEngine::Exception); // ASSERT_THROW(req.infer(), ov::Exception);
} }
TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithGetBlob2times) { TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithGetTensor2times) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refShape2 = inOutShapes[1].first; const ov::Shape refShape2 = inOutShapes[1].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
const InferenceEngine::SizeVector refOutShape2 = inOutShapes[1].second; const ov::Shape refOutShape2 = inOutShapes[1].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension::dynamic(), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension::dynamic(), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape(refShape)); ASSERT_NO_THROW(tensor.set_shape(refShape));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_EQ(tensor.get_shape(), refShape);
req.Infer(); ASSERT_NO_THROW(req.infer());
req.StartAsync(); ASSERT_NO_THROW(req.start_async());
InferenceEngine::StatusCode sts; req.wait();
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts); ASSERT_EQ(tensor.get_shape(), refOutShape);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape(refShape2)); ASSERT_NO_THROW(tensor.set_shape(refShape2));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape2); ASSERT_EQ(tensor.get_shape(), refShape2);
req.Infer(); ASSERT_NO_THROW(req.infer());
req.StartAsync(); ASSERT_NO_THROW(req.start_async());
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); req.wait();
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts); ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first)); ASSERT_EQ(tensor.get_shape(), refOutShape2);
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape2);
} }
TEST_P(InferRequestDynamicTests, GetSameBlob2times) { TEST_P(InferRequestDynamicTests, GetSameTensor2times) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension::dynamic(), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension::dynamic(), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob; ov::runtime::Tensor tensor;
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_NO_THROW(blob->setShape(refShape)); ASSERT_NO_THROW(tensor.set_shape(refShape));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_EQ(tensor.get_shape(), refShape);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getInputsInfo().begin()->first)); ASSERT_NO_THROW(tensor = req.get_tensor(function->get_parameters().back()->get_friendly_name()));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_EQ(tensor.get_shape(), refShape);
} }
TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithSetBlob) { TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithSetTensor) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension::dynamic(), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension::dynamic(), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob = make_blob_with_precision({InferenceEngine::Precision::FP32, refShape, InferenceEngine::Layout::NCHW}); ov::runtime::Tensor tensor(ov::element::f32, refShape);
blob->allocate(); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req.set_tensor(function->get_parameters().back()->get_friendly_name(), tensor));
ASSERT_NO_THROW(req.SetBlob(cnnNet.getInputsInfo().begin()->first, blob)); ASSERT_EQ(tensor.get_shape(), refShape);
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_NO_THROW(req.infer());
req.Infer(); ASSERT_NO_THROW(req.start_async());
req.StartAsync(); ASSERT_NO_THROW(req.wait());
InferenceEngine::StatusCode sts; ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); ASSERT_EQ(tensor.get_shape(), refOutShape);
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape);
} }
TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithSetBlob2times) { TEST_P(InferRequestDynamicTests, InferDynamicNetworkWithSetTensor2times) {
const std::string param_name = "Param_1"; const std::string tensor_name = "Tensor_1";
const InferenceEngine::SizeVector refShape = inOutShapes[0].first; const ov::Shape refShape = inOutShapes[0].first;
const InferenceEngine::SizeVector refShape2 = inOutShapes[1].first; const ov::Shape refShape2 = inOutShapes[1].first;
const InferenceEngine::SizeVector refOutShape = inOutShapes[0].second; const ov::Shape refOutShape = inOutShapes[0].second;
const InferenceEngine::SizeVector refOutShape2 = inOutShapes[1].second; const ov::Shape refOutShape2 = inOutShapes[1].second;
// Create CNNNetwork from ngrpah::Function std::map<std::string, ov::PartialShape> shapes;
InferenceEngine::CNNNetwork cnnNet(function); shapes[tensor_name] = {ov::Dimension::dynamic(), 4, 20, 20};
std::map<std::string, ngraph::PartialShape> shapes; ASSERT_NO_THROW(function->reshape(shapes));
shapes[param_name] = {ngraph::Dimension::dynamic(), 4, 20, 20}; // Load ov::Function to target plugins
cnnNet.reshape(shapes); auto execNet = ie->compile_model(function, targetDevice, configuration);
// Load CNNNetwork to target plugins
auto execNet = ie->LoadNetwork(cnnNet, targetDevice, configuration);
// Create InferRequest // Create InferRequest
InferenceEngine::InferRequest req; ov::runtime::InferRequest req;
InferenceEngine::Blob::Ptr blob = make_blob_with_precision({InferenceEngine::Precision::FP32, refShape, InferenceEngine::Layout::NCHW}); ov::runtime::Tensor tensor(ov::element::f32, refShape);
blob->allocate();
ASSERT_NO_THROW(req = execNet.CreateInferRequest()); ASSERT_NO_THROW(req = execNet.create_infer_request());
ASSERT_NO_THROW(req.SetBlob(cnnNet.getInputsInfo().begin()->first, blob)); ASSERT_NO_THROW(req.set_tensor(function->get_parameters().back()->get_friendly_name(), tensor));
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape); ASSERT_EQ(tensor.get_shape(), refShape);
req.Infer(); ASSERT_NO_THROW(req.infer());
req.StartAsync(); ASSERT_NO_THROW(req.start_async());
InferenceEngine::StatusCode sts; ASSERT_NO_THROW(req.wait());
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts); ASSERT_EQ(tensor.get_shape(), refOutShape);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape);
blob = make_blob_with_precision({InferenceEngine::Precision::FP32, refShape2, InferenceEngine::Layout::NCHW}); tensor = ov::runtime::Tensor(ov::element::f32, refShape2);
blob->allocate(); ASSERT_NO_THROW(req.set_tensor(function->get_parameters().back()->get_friendly_name(), tensor));
ASSERT_NO_THROW(req.SetBlob(cnnNet.getInputsInfo().begin()->first, blob)); ASSERT_EQ(tensor.get_shape(), refShape2);
ASSERT_EQ(blob->getTensorDesc().getDims(), refShape2); ASSERT_NO_THROW(req.infer());
req.Infer(); ASSERT_NO_THROW(req.start_async());
req.StartAsync(); ASSERT_NO_THROW(req.wait());
sts = req.Wait(InferenceEngine::InferRequest::WaitMode::RESULT_READY); ASSERT_NO_THROW(tensor = req.get_tensor(ngraph::op::util::create_ie_output_name(function->get_results().front()->input_value(0))));
ASSERT_EQ(InferenceEngine::StatusCode::OK, sts); ASSERT_EQ(tensor.get_shape(), refOutShape2);
ASSERT_NO_THROW(blob = req.GetBlob(cnnNet.getOutputsInfo().begin()->first));
ASSERT_EQ(blob->getTensorDesc().getDims(), refOutShape2);
} }
} // namespace BehaviorTestsDefinitions } // namespace BehaviorTestsDefinitions

View File

@ -43,6 +43,7 @@ inline std::shared_ptr<ngraph::Function> makeSplitConvConcat(std::vector<size_t>
ngraph::element::Type_t ngPrc = ngraph::element::Type_t::f32) { ngraph::element::Type_t ngPrc = ngraph::element::Type_t::f32) {
auto params = ngraph::builder::makeParams(ngPrc, {inputShape}); auto params = ngraph::builder::makeParams(ngPrc, {inputShape});
params.front()->set_friendly_name("Param_1"); params.front()->set_friendly_name("Param_1");
params.front()->get_output_tensor(0).set_names({"Tensor_1"});
auto split = ngraph::builder::makeSplit(params[0], ngPrc, 2, 1); auto split = ngraph::builder::makeSplit(params[0], ngPrc, 2, 1);
auto conv1 = ngraph::builder::makeConvolution(split->output(0), ngPrc, {3, 3}, {1, 1}, {0, 0}, {0, 0}, {1, 1}, auto conv1 = ngraph::builder::makeConvolution(split->output(0), ngPrc, {3, 3}, {1, 1}, {0, 0}, {0, 0}, {1, 1},

View File

@ -86,9 +86,15 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
set_target_properties(ngraph PROPERTIES LINK_FLAGS "/IGNORE:4217,4286") set_target_properties(ngraph PROPERTIES LINK_FLAGS "/IGNORE:4217,4286")
endif() endif()
# some sources are located in ngraph, while headers are in inference_engine_transformations
file(GLOB_RECURSE smart_reshape_srcs ${CMAKE_CURRENT_SOURCE_DIR}/src/pass/smart_reshape/*.cpp)
file(GLOB_RECURSE rt_info_srcs ${CMAKE_CURRENT_SOURCE_DIR}/src/pass/rt_info/*.cpp)
set_source_files_properties("${CMAKE_CURRENT_SOURCE_DIR}/src/pass/convert_precision.cpp" set_source_files_properties("${CMAKE_CURRENT_SOURCE_DIR}/src/pass/convert_precision.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/pass/convert_fp32_to_fp16.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/pass/convert_fp32_to_fp16.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/pass/init_node_info.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/op/type_relaxed.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/src/op/type_relaxed.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/src/function.cpp" # for SmartReshape
${smart_reshape_srcs} ${rt_info_srcs}
PROPERTIES INCLUDE_DIRECTORIES $<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>) PROPERTIES INCLUDE_DIRECTORIES $<TARGET_PROPERTY:inference_engine_transformations,INTERFACE_INCLUDE_DIRECTORIES>)
# Defines macro in C++ to load backend plugin # Defines macro in C++ to load backend plugin

View File

@ -23,7 +23,7 @@
namespace ov { namespace ov {
/// A user-defined function. /// A user-defined function.
class OPENVINO_API Function { class OPENVINO_API Function : public std::enable_shared_from_this<Function> {
public: public:
static constexpr ngraph::DiscreteTypeInfo type_info{"Function", 0}; static constexpr ngraph::DiscreteTypeInfo type_info{"Function", 0};
const ngraph::DiscreteTypeInfo& get_type_info() const { const ngraph::DiscreteTypeInfo& get_type_info() const {
@ -111,6 +111,8 @@ public:
ov::Output<const ov::Node> input(size_t i) const; ov::Output<const ov::Node> input(size_t i) const;
ov::Output<const ov::Node> input(const std::string& tensor_name) const; ov::Output<const ov::Node> input(const std::string& tensor_name) const;
void reshape(const std::map<std::string, ngraph::PartialShape>& partial_shapes);
/// Return the element type of output i /// Return the element type of output i
const ngraph::element::Type& get_output_element_type(size_t i) const; const ngraph::element::Type& get_output_element_type(size_t i) const;

View File

@ -7,6 +7,8 @@
#include <algorithm> #include <algorithm>
#include <list> #include <list>
#include <memory> #include <memory>
#include <string>
#include <unordered_map>
#include "itt.hpp" #include "itt.hpp"
#include "ngraph/graph_util.hpp" #include "ngraph/graph_util.hpp"
@ -14,10 +16,15 @@
#include "ngraph/ops.hpp" #include "ngraph/ops.hpp"
#include "ngraph/opsets/opset7.hpp" #include "ngraph/opsets/opset7.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/core/except.hpp" #include "openvino/core/except.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/util/op_types.hpp" #include "openvino/op/util/op_types.hpp"
#include "openvino/op/util/variable_context.hpp" #include "openvino/op/util/variable_context.hpp"
#include "openvino/op/util/variable_extension.hpp" #include "openvino/op/util/variable_extension.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/smart_reshape/smart_reshape.hpp"
using namespace std; using namespace std;
@ -640,3 +647,66 @@ ov::Output<ov::Node> ov::Function::input(const std::string& tensor_name) {
} }
throw ov::Exception("Input for tensor name " + tensor_name + " was not found."); throw ov::Exception("Input for tensor name " + tensor_name + " was not found.");
} }
void ov::Function::reshape(const std::map<std::string, ov::PartialShape>& partial_shapes) {
if (partial_shapes.empty())
return;
const auto& params = get_parameters();
std::unordered_map<std::string, std::shared_ptr<ov::op::v0::Parameter>> tensor_param_map;
// Check that we need to do reshape only if input shapes will be changed
bool need_reshape = false;
for (const auto& partial_shape : partial_shapes) {
bool shape_is_used = false;
for (const auto& param : params) {
const auto& tensor_names = param->get_output_tensor(0).get_names();
if (tensor_names.count(partial_shape.first)) {
shape_is_used = true;
tensor_param_map[partial_shape.first] = param;
if (param->get_output_partial_shape(0).is_dynamic() ||
param->get_output_partial_shape(0) != partial_shape.second) {
need_reshape = true;
}
break;
}
}
OPENVINO_ASSERT(shape_is_used,
"PartialShape for tensor with name '",
partial_shape.first,
"' is not used in ov::Function::reshape");
}
if (!need_reshape)
return;
// save original parameters shape
std::map<std::string, ov::PartialShape> original_input_shapes;
for (const auto& param : params) {
std::string any_tensor_name = *param->get_output_tensor(0).get_names().begin();
original_input_shapes[any_tensor_name] = param->get_output_partial_shape(0);
}
auto reshape_only = [&](const std::map<std::string, ov::PartialShape>& pshapes) {
for (const auto& pshape : pshapes) {
tensor_param_map[pshape.first]->set_partial_shape(pshape.second);
}
validate_nodes_and_infer_types();
};
try {
ov::pass::Manager ssr_manager;
ssr_manager.register_pass<ngraph::pass::SmartReshape>();
ssr_manager.run_passes(shared_from_this());
reshape_only(partial_shapes);
} catch (std::exception& ex) {
// restore shapes to original ones
reshape_only(original_input_shapes);
throw ex;
}
}

View File

@ -2,56 +2,54 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "itt.hpp"
#include "transformations/init_node_info.hpp" #include "transformations/init_node_info.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/rt_info/primitives_priority_attribute.hpp"
#include <memory> #include <memory>
#include <vector>
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
#include <vector>
#include "itt.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/rt_info/primitives_priority_attribute.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::InitNodeInfo, "InitNodeInfo", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::InitNodeInfo, "InitNodeInfo", 0);
bool ngraph::pass::InitNodeInfo::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::InitNodeInfo::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(InitNodeInfo); // TODO: enable conditional compile
std::vector<std::shared_ptr<Variant> > attributes { // RUN_ON_FUNCTION_SCOPE(InitNodeInfo);
std::make_shared<VariantWrapper<ngraph::FusedNames> >(ngraph::FusedNames()) std::vector<std::shared_ptr<Variant>> attributes{
}; std::make_shared<VariantWrapper<ngraph::FusedNames>>(ngraph::FusedNames())};
using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>; using VariantCreator = std::function<std::shared_ptr<Variant>(const std::string&)>;
std::map<std::string, VariantCreator> update_attributes { std::map<std::string, VariantCreator> update_attributes{
{"PrimitivesPriority", {"PrimitivesPriority", [](const std::string& value) -> std::shared_ptr<Variant> {
[](const std::string & value) -> std::shared_ptr<Variant> { return std::make_shared<VariantWrapper<ov::PrimitivesPriority>>(ov::PrimitivesPriority(value));
return std::make_shared<VariantWrapper<ov::PrimitivesPriority> >(ov::PrimitivesPriority(value)); }}};
}
}
};
for (auto & node : f->get_ops()) { for (auto& node : f->get_ops()) {
// Recursively apply transformation for sub-graph based operations // Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) { if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) { if (auto sub_graph = sub_graph_node->get_function()) {
run_on_function(sub_graph); run_on_function(sub_graph);
} }
} }
auto & rtInfo = node->get_rt_info(); auto& rtInfo = node->get_rt_info();
// Default attributes initialization // Default attributes initialization
for (auto & attr : attributes) { for (auto& attr : attributes) {
// Skip initialization if attribute has been already set // Skip initialization if attribute has been already set
if (rtInfo.count(attr->get_type_info())) continue; if (rtInfo.count(attr->get_type_info()))
continue;
if (auto init_attr = attr->init(node)) { if (auto init_attr = attr->init(node)) {
rtInfo[attr->get_type_info()] = init_attr; rtInfo[attr->get_type_info()] = init_attr;
} }
} }
// Convert manually set attributes to appropriate VariantWrapper class instances // Convert manually set attributes to appropriate VariantWrapper class instances
// all manually set attributes must belong to VariantWrapper<std::string> class // all manually set attributes must belong to VariantWrapper<std::string> class
for (auto & attr : update_attributes) { for (auto& attr : update_attributes) {
if (rtInfo.count(attr.first)) { if (rtInfo.count(attr.first)) {
if (auto variant_string = std::dynamic_pointer_cast<VariantWrapper<std::string> >(rtInfo[attr.first])) { if (auto variant_string = std::dynamic_pointer_cast<VariantWrapper<std::string>>(rtInfo[attr.first])) {
rtInfo.erase(attr.first); rtInfo.erase(attr.first);
auto res = attr.second(variant_string->get()); auto res = attr.second(variant_string->get());
rtInfo[res->get_type_info()] = res; rtInfo[res->get_type_info()] = res;

View File

@ -2,23 +2,23 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include <assert.h> #include "transformations/rt_info/fused_names_attribute.hpp"
#include <functional>
#include <memory>
#include <iterator>
#include <ostream>
#include <assert.h>
#include <functional>
#include <iterator>
#include <memory>
#include <ngraph/node.hpp> #include <ngraph/node.hpp>
#include <ngraph/variant.hpp> #include <ngraph/variant.hpp>
#include <ostream>
#include "transformations/rt_info/fused_names_attribute.hpp"
using namespace ngraph; using namespace ngraph;
using namespace ov; using namespace ov;
std::string FusedNames::getNames() const { std::string FusedNames::getNames() const {
std::string res; std::string res;
for (auto &name : fused_names) { for (auto& name : fused_names) {
res += (res.empty() ? name : "," + name); res += (res.empty() ? name : "," + name);
} }
return res; return res;
@ -28,58 +28,62 @@ std::vector<std::string> FusedNames::getVectorNames() const {
return std::vector<std::string>(fused_names.begin(), fused_names.end()); return std::vector<std::string>(fused_names.begin(), fused_names.end());
} }
void FusedNames::fuseWith(const FusedNames &names) { void FusedNames::fuseWith(const FusedNames& names) {
for (const auto & name : names.fused_names) { for (const auto& name : names.fused_names) {
fused_names.insert(name); fused_names.insert(name);
} }
} }
std::string ngraph::getFusedNames(const std::shared_ptr<ngraph::Node> &node) { std::string ngraph::getFusedNames(const std::shared_ptr<ngraph::Node>& node) {
const auto &rtInfo = node->get_rt_info(); const auto& rtInfo = node->get_rt_info();
using FusedNamesWrapper = VariantWrapper<FusedNames>; using FusedNamesWrapper = VariantWrapper<FusedNames>;
if (!rtInfo.count(FusedNamesWrapper::get_type_info_static())) return {}; if (!rtInfo.count(FusedNamesWrapper::get_type_info_static()))
return {};
const auto &attr = rtInfo.at(FusedNamesWrapper::get_type_info_static()); const auto& attr = rtInfo.at(FusedNamesWrapper::get_type_info_static());
FusedNames fusedNames = ov::as_type_ptr<FusedNamesWrapper>(attr)->get(); FusedNames fusedNames = ov::as_type_ptr<FusedNamesWrapper>(attr)->get();
return fusedNames.getNames(); return fusedNames.getNames();
} }
std::vector<std::string> ngraph::getFusedNamesVector(const std::shared_ptr<ngraph::Node> &node) { std::vector<std::string> ngraph::getFusedNamesVector(const std::shared_ptr<ngraph::Node>& node) {
if (!node) return {}; if (!node)
return {};
const auto &rtInfo = node->get_rt_info(); const auto& rtInfo = node->get_rt_info();
using FusedNamesWrapper = VariantWrapper<FusedNames>; using FusedNamesWrapper = VariantWrapper<FusedNames>;
if (!rtInfo.count(FusedNamesWrapper::get_type_info_static())) return {}; if (!rtInfo.count(FusedNamesWrapper::get_type_info_static()))
return {};
const auto &attr = rtInfo.at(FusedNamesWrapper::get_type_info_static()); const auto& attr = rtInfo.at(FusedNamesWrapper::get_type_info_static());
FusedNames fusedNames = ov::as_type_ptr<FusedNamesWrapper>(attr)->get(); FusedNames fusedNames = ov::as_type_ptr<FusedNamesWrapper>(attr)->get();
return fusedNames.getVectorNames(); return fusedNames.getVectorNames();
} }
template class ov::VariantImpl<FusedNames>; template class ov::VariantImpl<FusedNames>;
std::shared_ptr<ngraph::Variant> VariantWrapper<FusedNames>::merge(const ngraph::NodeVector & nodes) { std::shared_ptr<ngraph::Variant> VariantWrapper<FusedNames>::merge(const ngraph::NodeVector& nodes) {
FusedNames mergedNames; FusedNames mergedNames;
for (auto &node : nodes) { for (auto& node : nodes) {
const auto &rtInfo = node->get_rt_info(); const auto& rtInfo = node->get_rt_info();
if (!rtInfo.count(VariantWrapper<FusedNames>::get_type_info_static())) continue; if (!rtInfo.count(VariantWrapper<FusedNames>::get_type_info_static()))
continue;
const auto attr = rtInfo.at(VariantWrapper<FusedNames>::get_type_info_static()); const auto attr = rtInfo.at(VariantWrapper<FusedNames>::get_type_info_static());
if (auto fusedNames = std::dynamic_pointer_cast<VariantWrapper<FusedNames> >(attr)) { if (auto fusedNames = std::dynamic_pointer_cast<VariantWrapper<FusedNames>>(attr)) {
mergedNames.fuseWith(fusedNames->get()); mergedNames.fuseWith(fusedNames->get());
} }
} }
return std::make_shared<VariantWrapper<FusedNames> >(mergedNames); return std::make_shared<VariantWrapper<FusedNames>>(mergedNames);
} }
std::shared_ptr<ngraph::Variant> VariantWrapper<FusedNames>::init(const std::shared_ptr<ngraph::Node> & node) { std::shared_ptr<ngraph::Variant> VariantWrapper<FusedNames>::init(const std::shared_ptr<ngraph::Node>& node) {
return std::make_shared<VariantWrapper<FusedNames> > (FusedNames(node->get_friendly_name())); return std::make_shared<VariantWrapper<FusedNames>>(FusedNames(node->get_friendly_name()));
} }
bool VariantWrapper<FusedNames>::visit_attributes(AttributeVisitor &visitor) { bool VariantWrapper<FusedNames>::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("value", m_value.fused_names); visitor.on_attribute("value", m_value.fused_names);
return true; return true;
} }

View File

@ -2,19 +2,17 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include <assert.h>
#include <functional>
#include <memory>
#include <iterator>
#include <ostream>
#include <ngraph/node.hpp>
#include <ngraph/variant.hpp>
#include <ngraph/opsets/opset1.hpp>
#include "transformations/rt_info/primitives_priority_attribute.hpp" #include "transformations/rt_info/primitives_priority_attribute.hpp"
#include "ngraph_ops/convolution_ie.hpp"
#include "ngraph_ops/deconvolution_ie.hpp" #include <assert.h>
#include <functional>
#include <iterator>
#include <memory>
#include <ngraph/node.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/variant.hpp>
#include <ostream>
using namespace ov; using namespace ov;
using namespace ngraph; using namespace ngraph;
@ -23,27 +21,26 @@ std::string PrimitivesPriority::getPrimitivesPriority() const {
return primitives_priority; return primitives_priority;
} }
std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node> &node) { std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node) {
const auto &rtInfo = node->get_rt_info(); const auto& rtInfo = node->get_rt_info();
using PrimitivesPriorityWrapper = VariantWrapper<PrimitivesPriority>; using PrimitivesPriorityWrapper = VariantWrapper<PrimitivesPriority>;
if (!rtInfo.count(PrimitivesPriorityWrapper::get_type_info_static())) return ""; if (!rtInfo.count(PrimitivesPriorityWrapper::get_type_info_static()))
return "";
const auto &attr = rtInfo.at(PrimitivesPriorityWrapper::get_type_info_static()); const auto& attr = rtInfo.at(PrimitivesPriorityWrapper::get_type_info_static());
PrimitivesPriority pp = ov::as_type_ptr<PrimitivesPriorityWrapper>(attr)->get(); PrimitivesPriority pp = ov::as_type_ptr<PrimitivesPriorityWrapper>(attr)->get();
return pp.getPrimitivesPriority(); return pp.getPrimitivesPriority();
} }
template class ov::VariantImpl<PrimitivesPriority>; template class ov::VariantImpl<PrimitivesPriority>;
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const ngraph::NodeVector & nodes) { std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const ngraph::NodeVector& nodes) {
auto isConvolutionBased = [](const std::shared_ptr<Node> & node) -> bool { auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) || if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) || std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node) || std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node)) {
std::dynamic_pointer_cast<ngraph::op::ConvolutionIE>(node) ||
std::dynamic_pointer_cast<ngraph::op::DeconvolutionIE>(node)) {
return true; return true;
} }
return false; return false;
@ -51,10 +48,11 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const
std::set<std::string> unique_pp; std::set<std::string> unique_pp;
for (auto &node : nodes) { for (auto& node : nodes) {
if (isConvolutionBased(node)) { if (isConvolutionBased(node)) {
std::string pp = getPrimitivesPriority(node); std::string pp = getPrimitivesPriority(node);
if (!pp.empty()) unique_pp.insert(pp); if (!pp.empty())
unique_pp.insert(pp);
} }
} }
@ -66,14 +64,14 @@ std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::merge(const
if (unique_pp.size() == 1) { if (unique_pp.size() == 1) {
final_primitives_priority = *unique_pp.begin(); final_primitives_priority = *unique_pp.begin();
} }
return std::make_shared<VariantWrapper<PrimitivesPriority> >(PrimitivesPriority(final_primitives_priority)); return std::make_shared<VariantWrapper<PrimitivesPriority>>(PrimitivesPriority(final_primitives_priority));
} }
std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::init(const std::shared_ptr<ngraph::Node> & node) { std::shared_ptr<ngraph::Variant> VariantWrapper<PrimitivesPriority>::init(const std::shared_ptr<ngraph::Node>& node) {
throw ngraph_error(std::string(get_type_info()) + " has no default initialization."); throw ngraph_error(std::string(get_type_info()) + " has no default initialization.");
} }
bool VariantWrapper<PrimitivesPriority>::visit_attributes(AttributeVisitor &visitor) { bool VariantWrapper<PrimitivesPriority>::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("value", m_value.primitives_priority); visitor.on_attribute("value", m_value.primitives_priority);
return true; return true;
} }

View File

@ -2,43 +2,45 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "itt.hpp"
#include "transformations/smart_reshape/matmul_sr.hpp" #include "transformations/smart_reshape/matmul_sr.hpp"
#include <memory>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
#include <numeric>
#include "itt.hpp"
#include "transformations/smart_reshape/utils.hpp" #include "transformations/smart_reshape/utils.hpp"
#include <numeric> bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap& pattern_to_output,
#include <memory> const std::shared_ptr<ngraph::Node>& matmul_label,
const std::shared_ptr<ngraph::Node>& reshape_label,
#include <ngraph/ngraph.hpp> const std::shared_ptr<ngraph::Node>& other_input_label,
#include <ngraph/pattern/matcher.hpp> const std::shared_ptr<ngraph::Node>& reshape_pattern_label,
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset4.hpp>
bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap & pattern_to_output,
const std::shared_ptr<ngraph::Node> & matmul_label,
const std::shared_ptr<ngraph::Node> & reshape_label,
const std::shared_ptr<ngraph::Node> & other_input_label,
const std::shared_ptr<ngraph::Node> & reshape_pattern_label,
bool reshape_is_A_input) { bool reshape_is_A_input) {
const auto & reshape_rank = pattern_to_output.at(reshape_label).get_partial_shape().rank(); const auto& reshape_rank = pattern_to_output.at(reshape_label).get_partial_shape().rank();
const auto & matmul = std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr()); const auto& matmul =
std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
if (!matmul || reshape_rank.is_dynamic() || reshape_rank.get_length() != 2) if (!matmul || reshape_rank.is_dynamic() || reshape_rank.get_length() != 2)
return false; return false;
const auto &shape_source = pattern_to_output.at(other_input_label); const auto& shape_source = pattern_to_output.at(other_input_label);
if (ngraph::is_type<ngraph::opset4::Transpose>(shape_source.get_node_shared_ptr()) || if (ngraph::is_type<ngraph::opset4::Transpose>(shape_source.get_node_shared_ptr()) ||
ngraph::is_type<ngraph::opset4::Reshape>(shape_source.get_node_shared_ptr())) ngraph::is_type<ngraph::opset4::Reshape>(shape_source.get_node_shared_ptr()))
// avoiding loop creation // avoiding loop creation
return false; return false;
const auto & raw_idx = reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1); const auto& raw_idx =
const auto & idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank); reshape_is_A_input ? (matmul->get_transpose_b() ? -1 : -2) : (matmul->get_transpose_a() ? -2 : -1);
const auto & C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx); const auto& idx = ngraph::normalize_axes(matmul->description(), {raw_idx}, reshape_rank);
const auto & N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1}); const auto& C = ngraph::op::util::node_to_get_shape_value_of_indices_from_shape_source(shape_source, idx);
const auto & pattern_vector = reshape_is_A_input ? const auto& N = ngraph::opset4::Constant::create(ngraph::element::i64, {1}, {-1});
(matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C})) : const auto& pattern_vector =
(matmul->get_transpose_b() ? ngraph::OutputVector({N, C}) : ngraph::OutputVector({C, N})); reshape_is_A_input ? (matmul->get_transpose_a() ? ngraph::OutputVector({C, N}) : ngraph::OutputVector({N, C}))
const auto & new_reshape_pattern = std::make_shared<ngraph::opset4::Concat>(pattern_vector, 0); : (matmul->get_transpose_b() ? ngraph::OutputVector({N, C}) : ngraph::OutputVector({C, N}));
const auto& new_reshape_pattern = std::make_shared<ngraph::opset4::Concat>(pattern_vector, 0);
auto reshape_pattern = pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr(); auto reshape_pattern = pattern_to_output.at(reshape_pattern_label).get_node_shared_ptr();
new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name()); new_reshape_pattern->set_friendly_name(reshape_pattern->get_friendly_name());
@ -50,60 +52,76 @@ bool relax_hc_reshape_followed_by_matmul(const ngraph::pattern::PatternValueMap
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeAMatMul, "ReshapeAMatMul", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeAMatMul, "ReshapeAMatMul", 0);
ngraph::pass::ReshapeAMatMul::ReshapeAMatMul() { ngraph::pass::ReshapeAMatMul::ReshapeAMatMul() {
MATCHER_SCOPE(ReshapeAMatMul); // TODO: enable conditional compile
// MATCHER_SCOPE(ReshapeAMatMul);
auto other_input_label = pattern::any_input(); auto other_input_label = pattern::any_input();
auto reshape_input_label = pattern::any_input(); auto reshape_input_label = pattern::any_input();
auto reshape_pattern_label = pattern::any_input(); auto reshape_pattern_label = pattern::any_input();
auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label}); auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label});
auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({reshape_label, other_input_label}); auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({reshape_label, other_input_label});
matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
const auto & pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
return relax_hc_reshape_followed_by_matmul(pattern_to_output, matmul_label, reshape_label, other_input_label, reshape_pattern_label, true); return relax_hc_reshape_followed_by_matmul(pattern_to_output,
matmul_label,
reshape_label,
other_input_label,
reshape_pattern_label,
true);
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, matcher_name); auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label /*, matcher_name */);
register_matcher(m, callback); register_matcher(m, callback);
} }
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeBMatMul, "ReshapeBMatMul", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeBMatMul, "ReshapeBMatMul", 0);
ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() { ngraph::pass::ReshapeBMatMul::ReshapeBMatMul() {
MATCHER_SCOPE(ReshapeBMatMul); // TODO: enable conditional compile
// MATCHER_SCOPE(ReshapeBMatMul);
auto other_input_label = pattern::any_input(); auto other_input_label = pattern::any_input();
auto reshape_input_label = pattern::any_input(); auto reshape_input_label = pattern::any_input();
auto reshape_pattern_label = pattern::any_input(); auto reshape_pattern_label = pattern::any_input();
auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label}); auto reshape_label = ngraph::pattern::wrap_type<opset4::Reshape>({reshape_input_label, reshape_pattern_label});
auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({other_input_label, reshape_label}); auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>({other_input_label, reshape_label});
matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
const auto & pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
return relax_hc_reshape_followed_by_matmul(pattern_to_output, matmul_label, reshape_label, other_input_label, reshape_pattern_label, false); return relax_hc_reshape_followed_by_matmul(pattern_to_output,
matmul_label,
reshape_label,
other_input_label,
reshape_pattern_label,
false);
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, matcher_name); auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label /*, matcher_name */);
register_matcher(m, callback); register_matcher(m, callback);
} }
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeMatMul, "TransposeMatMul", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeMatMul, "TransposeMatMul", 0);
ngraph::pass::TransposeMatMul::TransposeMatMul() { ngraph::pass::TransposeMatMul::TransposeMatMul() {
MATCHER_SCOPE(TransposeMatMul); // TODO: enable conditional compile
// MATCHER_SCOPE(TransposeMatMul);
auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>(); auto matmul_label = ngraph::pattern::wrap_type<opset4::MatMul>();
matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool { matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
const auto & pattern_to_output = m.get_pattern_value_map(); const auto& pattern_to_output = m.get_pattern_value_map();
auto matmul = std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr()); auto matmul =
std::dynamic_pointer_cast<ngraph::opset4::MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
if (!matmul) if (!matmul)
return false; return false;
auto transpose_is_fusable = [](const std::shared_ptr<ngraph::Node>& input) { auto transpose_is_fusable = [](const std::shared_ptr<ngraph::Node>& input) {
const auto & input_rank = input->get_output_partial_shape(0).rank(); const auto& input_rank = input->get_output_partial_shape(0).rank();
if (input_rank.is_static() && input_rank.get_length() >= 2) { if (input_rank.is_static() && input_rank.get_length() >= 2) {
if (auto transpose = std::dynamic_pointer_cast<ngraph::opset4::Transpose>(input)) { if (auto transpose = std::dynamic_pointer_cast<ngraph::opset4::Transpose>(input)) {
if (auto order = std::dynamic_pointer_cast<opset4::Constant>(transpose->get_input_node_shared_ptr(1))) { if (auto order =
const auto & order_vector = order->cast_vector<int64_t>(); std::dynamic_pointer_cast<opset4::Constant>(transpose->get_input_node_shared_ptr(1))) {
const auto& order_vector = order->cast_vector<int64_t>();
std::vector<int64_t> fusable_order(input_rank.get_length()); std::vector<int64_t> fusable_order(input_rank.get_length());
std::iota(fusable_order.begin(), fusable_order.end(), 0); std::iota(fusable_order.begin(), fusable_order.end(), 0);
std::swap(fusable_order[input_rank.get_length() - 1], fusable_order[input_rank.get_length() - 2]); std::swap(fusable_order[input_rank.get_length() - 1],
fusable_order[input_rank.get_length() - 2]);
return order_vector == fusable_order; return order_vector == fusable_order;
} }
} }
@ -138,6 +156,6 @@ ngraph::pass::TransposeMatMul::TransposeMatMul() {
} }
return false; return false;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label, matcher_name); auto m = std::make_shared<ngraph::pattern::Matcher>(matmul_label /*, matcher_name */);
register_matcher(m, callback); register_matcher(m, callback);
} }

View File

@ -0,0 +1,70 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::MimicSetBatchSize, "MimicSetBatchSize", 0);
bool ngraph::pass::MimicSetBatchSize::run_on_function(std::shared_ptr<ngraph::Function> f) {
// TODO: enable conditional compile
// RUN_ON_FUNCTION_SCOPE(MimicSetBatchSize);
// extracting ratio of out to in 0-index dimension value from the folded function
auto specialized_function = ngraph::clone_function(*f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(specialized_function);
std::map<std::string, float> scale;
for (const auto& node : specialized_function->get_ops()) {
if (const auto& reshape = std::dynamic_pointer_cast<opset5::Reshape>(node)) {
const auto in_pshape = reshape->get_input_partial_shape(0),
out_pshape = reshape->get_output_partial_shape(0);
if (in_pshape.rank().is_dynamic() || in_pshape.rank().get_length() <= 1 || in_pshape[0].is_dynamic() ||
out_pshape.rank().is_dynamic() || out_pshape.rank().get_length() <= 1 || out_pshape[0].is_dynamic())
continue;
const auto& pattern = std::dynamic_pointer_cast<opset5::Constant>(reshape->get_input_node_shared_ptr(1));
if (pattern && pattern->cast_vector<int64_t>()[0] > 0) {
scale[reshape->get_friendly_name()] =
static_cast<float>(out_pshape[0].get_length()) / static_cast<float>(in_pshape[0].get_length());
}
}
}
// apply transformation to original function
bool transformed = false;
for (auto& reshape : f->get_ops()) {
if (!is_type<opset5::Reshape>(reshape) || !scale.count(reshape->get_friendly_name()) ||
reshape->get_output_partial_shape(0).rank().is_dynamic())
continue;
const auto& shape_of =
std::make_shared<opset5::ShapeOf>(reshape->get_input_source_output(0), reshape->get_input_element_type(1));
const auto& new_input_batch = std::make_shared<ngraph::opset5::Gather>(
shape_of,
ngraph::opset5::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0}),
ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
const std::shared_ptr<Node>& new_output_batch = std::make_shared<opset5::Convert>(
std::make_shared<opset5::Ceiling>(std::make_shared<opset5::Multiply>(
std::make_shared<opset5::Convert>(new_input_batch, element::f32),
opset5::Constant::create(element::f32, {1}, {scale[reshape->get_friendly_name()]}))),
reshape->get_input_element_type(1));
std::vector<int64_t> non_batch_dims(reshape->get_output_partial_shape(0).rank().get_length() - 1);
std::iota(non_batch_dims.begin(), non_batch_dims.end(), 1);
const auto& non_batch_dims_node = std::make_shared<ngraph::opset5::Gather>(
reshape->input_value(1),
ngraph::opset5::Constant::create(ngraph::element::i64, {non_batch_dims.size()}, non_batch_dims),
ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
auto new_reshape_pattern =
std::make_shared<opset5::Concat>(OutputVector{new_output_batch, non_batch_dims_node}, 0);
reshape->input(1).replace_source_output(new_reshape_pattern->output(0));
transformed = true;
}
return transformed;
}

View File

@ -0,0 +1,83 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations/smart_reshape/proposal_scales_stridedslice.hpp>
#include "itt.hpp"
bool crop_scales_for_proposal(const ngraph::pattern::PatternValueMap& pattern_to_output,
std::shared_ptr<ngraph::Node> parameter_label,
std::shared_ptr<ngraph::Node> proposal_label) {
const auto& parameter = pattern_to_output.at(parameter_label);
const auto& proposal = pattern_to_output.at(proposal_label).get_node_shared_ptr();
auto cropped_scales = std::make_shared<ngraph::opset5::StridedSlice>(
proposal->input_value(2),
ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}),
ngraph::opset5::Constant::create(ngraph::element::i64,
ngraph::Shape{1},
{parameter.get_partial_shape()[1].get_length()}),
ngraph::opset5::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}),
std::vector<int64_t>{0},
std::vector<int64_t>{0});
proposal->input(2).replace_source_output(cropped_scales->output(0));
return true;
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::Proposal1Scales, "Proposal1Scales", 0);
ngraph::pass::Proposal1Scales::Proposal1Scales() {
// TODO: enable conditional compile
// MATCHER_SCOPE(Proposal1Scales);
auto parameter_label = ngraph::pattern::wrap_type<opset5::Parameter>([](const Output<Node>& output) {
const auto& shape = output.get_partial_shape();
return shape.rank().is_static() && shape.rank().get_length() == 2 && shape[1].is_static() &&
(shape[1].get_length() == 3 || shape[1].get_length() == 4);
});
auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>(
{parameter_label, ngraph::pattern::wrap_type<opset5::Constant>()},
[](const Output<Node>& output) {
return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() == 1;
});
auto proposal_label =
ngraph::pattern::wrap_type<opset1::Proposal>({pattern::any_input(), pattern::any_input(), reshape_label});
matcher_pass_callback callback = [parameter_label, proposal_label](pattern::Matcher& m) -> bool {
return crop_scales_for_proposal(m.get_pattern_value_map(), parameter_label, proposal_label);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(proposal_label /*, matcher_name */);
register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::Proposal4Scales, "Proposal4Scales", 0);
ngraph::pass::Proposal4Scales::Proposal4Scales() {
// TODO: enable conditional compile
// MATCHER_SCOPE(Proposal4Scales);
auto parameter_label = ngraph::pattern::wrap_type<opset5::Parameter>([](const Output<Node>& output) {
const auto& shape = output.get_partial_shape();
return shape.rank().is_static() && shape.rank().get_length() == 2 && shape[1].is_static() &&
(shape[1].get_length() == 3 || shape[1].get_length() == 4);
});
auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>(
{parameter_label, ngraph::pattern::wrap_type<opset5::Constant>()},
[](const Output<Node>& output) {
return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() == 1;
});
auto proposal_label =
ngraph::pattern::wrap_type<opset4::Proposal>({pattern::any_input(), pattern::any_input(), reshape_label});
matcher_pass_callback callback = [parameter_label, proposal_label](pattern::Matcher& m) -> bool {
return crop_scales_for_proposal(m.get_pattern_value_map(), parameter_label, proposal_label);
};
auto m = std::make_shared<ngraph::pattern::Matcher>(proposal_label /*, matcher_name */);
register_matcher(m, callback);
}

View File

@ -2,26 +2,30 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "itt.hpp"
#include <transformations/smart_reshape/reshape_to_1D.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset5.hpp> #include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/matcher.hpp> #include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp> #include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <transformations/smart_reshape/reshape_to_1D.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeTo1D, "ReshapeTo1D", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeTo1D, "ReshapeTo1D", 0);
ngraph::pass::ReshapeTo1D::ReshapeTo1D() { ngraph::pass::ReshapeTo1D::ReshapeTo1D() {
MATCHER_SCOPE(ReshapeTo1D); // TODO: enable conditional compile
auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>({pattern::any_input(), ngraph::pattern::wrap_type<opset5::Constant>()}, // MATCHER_SCOPE(ReshapeTo1D);
[](const Output<Node> & output) { return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() == 1; }); auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>(
{pattern::any_input(), ngraph::pattern::wrap_type<opset5::Constant>()},
[](const Output<Node>& output) {
return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() == 1;
});
matcher_pass_callback callback = [](pattern::Matcher &m) -> bool { matcher_pass_callback callback = [](pattern::Matcher& m) -> bool {
m.get_match_root()->input(1).replace_source_output(ngraph::opset5::Constant::create(ngraph::element::i64, {1}, {-1})); m.get_match_root()->input(1).replace_source_output(
ngraph::opset5::Constant::create(ngraph::element::i64, {1}, {-1}));
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_label, matcher_name); auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_label /*, matcher_name*/);
register_matcher(m, callback); register_matcher(m, callback);
} }

View File

@ -2,13 +2,11 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include <memory>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/init_node_info.hpp>
#include <itt.hpp> #include <itt.hpp>
#include <memory>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp> #include <transformations/smart_reshape/mimic_set_batch_size.hpp>
#include <transformations/smart_reshape/reshape_to_1D.hpp> #include <transformations/smart_reshape/reshape_to_1D.hpp>
#include <transformations/smart_reshape/set_batch_size.hpp> #include <transformations/smart_reshape/set_batch_size.hpp>
@ -17,8 +15,9 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::SetBatchSize, "SetBatchSize", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::SetBatchSize, "SetBatchSize", 0);
bool ngraph::pass::SetBatchSize::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::SetBatchSize::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(SetBatchSize); // TODO: enable conditional compile
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::SetBatchSize"); // RUN_ON_FUNCTION_SCOPE(SetBatchSize);
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "ngraph::pass::SetBatchSize");
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
// This pass must be called first in pipeline // This pass must be called first in pipeline
@ -30,4 +29,3 @@ bool ngraph::pass::SetBatchSize::run_on_function(std::shared_ptr<ngraph::Functio
manager.run_passes(f); manager.run_passes(f);
return true; return true;
} }

View File

@ -2,23 +2,23 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// //
#include "itt.hpp"
#include <memory> #include <memory>
#include <ngraph/pass/manager.hpp> #include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp> #include <transformations/init_node_info.hpp>
#include <transformations/smart_reshape/matmul_sr.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>
#include <transformations/smart_reshape/proposal_scales_stridedslice.hpp> #include <transformations/smart_reshape/proposal_scales_stridedslice.hpp>
#include <transformations/smart_reshape/reshape_to_1D.hpp> #include <transformations/smart_reshape/reshape_to_1D.hpp>
#include <transformations/smart_reshape/matmul_sr.hpp>
#include <transformations/smart_reshape/smart_reshape.hpp> #include <transformations/smart_reshape/smart_reshape.hpp>
#include <transformations/smart_reshape/strided_slice_squeeze.hpp> #include <transformations/smart_reshape/strided_slice_squeeze.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>
#include "itt.hpp"
NGRAPH_RTTI_DEFINITION(ngraph::pass::SmartReshape, "SmartReshape", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::SmartReshape, "SmartReshape", 0);
bool ngraph::pass::SmartReshape::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::SmartReshape::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(SmartReshape); // TODO: enable conditional compile
// RUN_ON_FUNCTION_SCOPE(SmartReshape);
ngraph::pass::Manager static_manager; ngraph::pass::Manager static_manager;
// This pass must be called first in pipeline // This pass must be called first in pipeline
static_manager.register_pass<ngraph::pass::InitNodeInfo>(); static_manager.register_pass<ngraph::pass::InitNodeInfo>();

View File

@ -3,24 +3,26 @@
// //
#include <itt.hpp> #include <itt.hpp>
#include <transformations/smart_reshape/strided_slice_squeeze.hpp>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset5.hpp> #include <ngraph/opsets/opset5.hpp>
#include <ngraph/pattern/matcher.hpp> #include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp> #include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
#include <transformations/smart_reshape/strided_slice_squeeze.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceSqueeze, "ngraph::pass::StridedSliceSqueeze", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceSqueeze, "ngraph::pass::StridedSliceSqueeze", 0);
ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() { ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() {
MATCHER_SCOPE(StridedSliceSqueeze); // TODO: enable conditional compile
// MATCHER_SCOPE(StridedSliceSqueeze);
auto ss_label = ngraph::pattern::wrap_type<opset5::StridedSlice>(pattern::consumers_count(1)); auto ss_label = ngraph::pattern::wrap_type<opset5::StridedSlice>(pattern::consumers_count(1));
auto squeeze_label = ngraph::pattern::wrap_type<opset5::Squeeze>({ss_label, ngraph::pattern::wrap_type<opset5::Constant>()}); auto squeeze_label =
ngraph::pattern::wrap_type<opset5::Squeeze>({ss_label, ngraph::pattern::wrap_type<opset5::Constant>()});
matcher_pass_callback callback = [](pattern::Matcher &m) -> bool { matcher_pass_callback callback = [](pattern::Matcher& m) -> bool {
const auto & squeeze = m.get_match_root(); const auto& squeeze = m.get_match_root();
const auto & const_axes = std::dynamic_pointer_cast<ngraph::opset5::Constant>(squeeze->get_input_node_shared_ptr(1)); const auto& const_axes =
std::dynamic_pointer_cast<ngraph::opset5::Constant>(squeeze->get_input_node_shared_ptr(1));
auto slice = std::dynamic_pointer_cast<ngraph::opset5::StridedSlice>(squeeze->get_input_node_shared_ptr(0)); auto slice = std::dynamic_pointer_cast<ngraph::opset5::StridedSlice>(squeeze->get_input_node_shared_ptr(0));
if (!const_axes || !slice) if (!const_axes || !slice)
return false; return false;
@ -36,25 +38,36 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() {
auto strides_vec = strides->cast_vector<int64_t>(); auto strides_vec = strides->cast_vector<int64_t>();
auto begin_mask = slice->get_begin_mask(); auto begin_mask = slice->get_begin_mask();
auto end_mask = slice->get_end_mask(); auto end_mask = slice->get_end_mask();
auto new_axis_mask = slice->get_new_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_new_axis_mask(); auto new_axis_mask = slice->get_new_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0)
auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_shrink_axis_mask(); : slice->get_new_axis_mask();
auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_ellipsis_mask(); auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0)
: slice->get_shrink_axis_mask();
auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0)
: slice->get_ellipsis_mask();
auto is_zero_vec = [](const std::vector<int64_t> & mask){ return std::all_of(mask.begin(), mask.end(), [](const int64_t& i){ return i == 0; }); }; auto is_zero_vec = [](const std::vector<int64_t>& mask) {
return std::all_of(mask.begin(), mask.end(), [](const int64_t& i) {
return i == 0;
});
};
if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask)) if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask))
return false; return false;
if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i){ return i == 1; })) if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i) {
return i == 1;
}))
return false; return false;
const auto & axes = normalize_axes(squeeze->description(), const_axes->cast_vector<int64_t>(), squeeze->get_input_partial_shape(0).rank()); const auto& axes = normalize_axes(squeeze->description(),
for (const auto & axis : axes) { const_axes->cast_vector<int64_t>(),
if (begin_mask[axis]) { // corresponding dimension of the begin input is ignored. starting from 0 squeeze->get_input_partial_shape(0).rank());
for (const auto& axis : axes) {
if (begin_mask[axis]) { // corresponding dimension of the begin input is ignored. starting from 0
begin_vec[axis] = 0; begin_vec[axis] = 0;
end_vec[axis] = 1; end_vec[axis] = 1;
begin_mask[axis] = 0; begin_mask[axis] = 0;
end_mask[axis] = 0; end_mask[axis] = 0;
} else { // corresponding dimension of the begin input is used for slicing start } else { // corresponding dimension of the begin input is used for slicing start
if (begin_vec[axis] == -1) { // slicing the latest slice if (begin_vec[axis] == -1) { // slicing the latest slice
end_mask[axis] = 1; end_mask[axis] = 1;
} else { } else {
end_vec[axis] = begin_vec[axis] + 1; end_vec[axis] = begin_vec[axis] + 1;
@ -65,32 +78,40 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() {
} }
auto new_slice = std::make_shared<opset5::StridedSlice>( auto new_slice = std::make_shared<opset5::StridedSlice>(
slice->input_value(0), slice->input_value(0),
opset5::Constant::create(element::i64, {begin_vec.size()}, begin_vec), opset5::Constant::create(element::i64, {begin_vec.size()}, begin_vec),
opset5::Constant::create(element::i64, {end_vec.size()}, end_vec), opset5::Constant::create(element::i64, {end_vec.size()}, end_vec),
opset5::Constant::create(element::i64, {strides_vec.size()}, strides_vec), opset5::Constant::create(element::i64, {strides_vec.size()}, strides_vec),
begin_mask, end_mask, new_axis_mask, shrink_axis_mask, ellipsis_mask); begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
replace_node(squeeze, new_slice); replace_node(squeeze, new_slice);
new_slice->set_friendly_name(slice->get_friendly_name()); new_slice->set_friendly_name(slice->get_friendly_name());
copy_runtime_info(slice, new_slice); copy_runtime_info(slice, new_slice);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(squeeze_label, matcher_name); auto m = std::make_shared<ngraph::pattern::Matcher>(squeeze_label /*, matcher_name */);
register_matcher(m, callback); register_matcher(m, callback);
} }
NGRAPH_RTTI_DEFINITION(ngraph::pass::SqueezeStridedSlice, "ngraph::pass::SqueezeStridedSlice", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::SqueezeStridedSlice, "ngraph::pass::SqueezeStridedSlice", 0);
ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() { ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
MATCHER_SCOPE(SqueezeStridedSlice); // TODO: enable conditional compile
// MATCHER_SCOPE(SqueezeStridedSlice);
auto squeeze_label = ngraph::pattern::wrap_type<opset5::Squeeze>( auto squeeze_label = ngraph::pattern::wrap_type<opset5::Squeeze>(
{pattern::any_input(), ngraph::pattern::wrap_type<opset5::Constant>()}, pattern::consumers_count(1)); {pattern::any_input(), ngraph::pattern::wrap_type<opset5::Constant>()},
auto ss_label = ngraph::pattern::wrap_type<opset5::StridedSlice>({squeeze_label, pattern::any_input(), pattern::any_input(), pattern::any_input()}); pattern::consumers_count(1));
auto ss_label = ngraph::pattern::wrap_type<opset5::StridedSlice>(
{squeeze_label, pattern::any_input(), pattern::any_input(), pattern::any_input()});
matcher_pass_callback callback = [](pattern::Matcher &m) -> bool { matcher_pass_callback callback = [](pattern::Matcher& m) -> bool {
auto slice = std::dynamic_pointer_cast<ngraph::opset5::StridedSlice>(m.get_match_root()); auto slice = std::dynamic_pointer_cast<ngraph::opset5::StridedSlice>(m.get_match_root());
auto squeeze = slice->get_input_node_shared_ptr(0); auto squeeze = slice->get_input_node_shared_ptr(0);
const auto & const_axes = std::dynamic_pointer_cast<ngraph::opset5::Constant>(squeeze->get_input_node_shared_ptr(1)); const auto& const_axes =
std::dynamic_pointer_cast<ngraph::opset5::Constant>(squeeze->get_input_node_shared_ptr(1));
if (!const_axes || !slice) if (!const_axes || !slice)
return false; return false;
@ -105,19 +126,30 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
auto strides_vec = strides->cast_vector<int64_t>(); auto strides_vec = strides->cast_vector<int64_t>();
auto begin_mask = slice->get_begin_mask(); auto begin_mask = slice->get_begin_mask();
auto end_mask = slice->get_end_mask(); auto end_mask = slice->get_end_mask();
auto new_axis_mask = slice->get_new_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_new_axis_mask(); auto new_axis_mask = slice->get_new_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0)
auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_shrink_axis_mask(); : slice->get_new_axis_mask();
auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_ellipsis_mask(); auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0)
: slice->get_shrink_axis_mask();
auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0)
: slice->get_ellipsis_mask();
auto is_zero_vec = [](const std::vector<int64_t> & mask){ return std::all_of(mask.begin(), mask.end(), [](const int64_t& i){ return i == 0; }); }; auto is_zero_vec = [](const std::vector<int64_t>& mask) {
return std::all_of(mask.begin(), mask.end(), [](const int64_t& i) {
return i == 0;
});
};
if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask)) if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask))
return false; return false;
if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i){ return i == 1; })) if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i) {
return i == 1;
}))
return false; return false;
auto axes = normalize_axes(squeeze->description(), const_axes->cast_vector<int64_t>(), squeeze->get_input_partial_shape(0).rank()); auto axes = normalize_axes(squeeze->description(),
const_axes->cast_vector<int64_t>(),
squeeze->get_input_partial_shape(0).rank());
std::sort(axes.begin(), axes.end()); std::sort(axes.begin(), axes.end());
for (const auto & axis : axes) { for (const auto& axis : axes) {
begin_vec.insert(begin_vec.begin() + axis, 0); begin_vec.insert(begin_vec.begin() + axis, 0);
end_vec.insert(end_vec.begin() + axis, 1); end_vec.insert(end_vec.begin() + axis, 1);
strides_vec.insert(strides_vec.begin() + axis, 1); strides_vec.insert(strides_vec.begin() + axis, 1);
@ -129,24 +161,29 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
} }
auto new_slice = std::make_shared<opset5::StridedSlice>( auto new_slice = std::make_shared<opset5::StridedSlice>(
slice->get_input_node_shared_ptr(0)->input_value(0), slice->get_input_node_shared_ptr(0)->input_value(0),
opset5::Constant::create(element::i64, {begin_vec.size()}, begin_vec), opset5::Constant::create(element::i64, {begin_vec.size()}, begin_vec),
opset5::Constant::create(element::i64, {end_vec.size()}, end_vec), opset5::Constant::create(element::i64, {end_vec.size()}, end_vec),
opset5::Constant::create(element::i64, {strides_vec.size()}, strides_vec), opset5::Constant::create(element::i64, {strides_vec.size()}, strides_vec),
begin_mask, end_mask, new_axis_mask, shrink_axis_mask, ellipsis_mask); begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
ellipsis_mask);
replace_node(slice, new_slice); replace_node(slice, new_slice);
new_slice->set_friendly_name(slice->get_friendly_name()); new_slice->set_friendly_name(slice->get_friendly_name());
copy_runtime_info(slice, new_slice); copy_runtime_info(slice, new_slice);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(ss_label, matcher_name); auto m = std::make_shared<ngraph::pattern::Matcher>(ss_label /*, matcher_name */);
register_matcher(m, callback); register_matcher(m, callback);
} }
NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedSqueeze, "ngraph::pass::SharedSqueeze", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedSqueeze, "ngraph::pass::SharedSqueeze", 0);
bool squeezes_perform_the_same(std::shared_ptr<ngraph::opset5::Squeeze> lhs, std::shared_ptr<ngraph::opset5::Squeeze> rhs) { bool squeezes_perform_the_same(std::shared_ptr<ngraph::opset5::Squeeze> lhs,
std::shared_ptr<ngraph::opset5::Squeeze> rhs) {
size_t l_input_size = lhs->inputs().size(), r_input_size = rhs->inputs().size(); size_t l_input_size = lhs->inputs().size(), r_input_size = rhs->inputs().size();
if (l_input_size != r_input_size) if (l_input_size != r_input_size)
return false; return false;
@ -164,13 +201,14 @@ bool squeezes_perform_the_same(std::shared_ptr<ngraph::opset5::Squeeze> lhs, std
} }
bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr<ngraph::Function> f) { bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr<ngraph::Function> f) {
RUN_ON_FUNCTION_SCOPE(SharedSqueeze); // TODO: enable conditional compile
OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::SharedSqueeze"); // RUN_ON_FUNCTION_SCOPE(SharedSqueeze);
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, "ngraph::pass::SharedSqueeze");
bool graph_rewritten = false; bool graph_rewritten = false;
std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::opset5::Squeeze>>> source_to_squeeze; std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::opset5::Squeeze>>> source_to_squeeze;
for (const auto & node : f->get_ordered_ops()) { for (const auto& node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations // Recursively apply transformation for sub-graph based operations
if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) { if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
if (auto sub_graph = sub_graph_node->get_function()) { if (auto sub_graph = sub_graph_node->get_function()) {
@ -187,7 +225,8 @@ bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr<ngraph::Functi
continue; continue;
auto root_squeeze = item.second[0]; auto root_squeeze = item.second[0];
for (auto& child_squeeze : item.second) { for (auto& child_squeeze : item.second) {
if (root_squeeze->get_instance_id() != child_squeeze->get_instance_id() && squeezes_perform_the_same(root_squeeze, child_squeeze)) { if (root_squeeze->get_instance_id() != child_squeeze->get_instance_id() &&
squeezes_perform_the_same(root_squeeze, child_squeeze)) {
graph_rewritten |= replace_output_update_name(child_squeeze->output(0), root_squeeze->output(0)); graph_rewritten |= replace_output_update_name(child_squeeze->output(0), root_squeeze->output(0));
} }
} }

View File

@ -6,6 +6,7 @@
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "openvino/core/partial_shape.hpp"
#include "openvino/opsets/opset8.hpp" #include "openvino/opsets/opset8.hpp"
TEST(function, get_input_by_tensor_name) { TEST(function, get_input_by_tensor_name) {
@ -511,3 +512,243 @@ TEST(function, DISABLED_create_function_with_incorrect_tensor_names_from_const_f
auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0}); auto f = std::make_shared<const ov::Function>(relu, ov::ParameterVector{arg0});
ASSERT_THROW(f->validate_nodes_and_infer_types(), ov::Exception); ASSERT_THROW(f->validate_nodes_and_infer_types(), ov::Exception);
} }
TEST(function_reshape, ReshapedDynamicShapeLayout) {
    // Build a single-Relu model whose input batch dimension is dynamic (-1).
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1, 3, 22, 22});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        model = std::make_shared<ov::Function>(act, ov::ParameterVector{arg});
    }

    EXPECT_TRUE(model->input().get_partial_shape().is_dynamic());

    // Reshape by tensor name to a fully static shape: the model must become static.
    std::map<std::string, ov::PartialShape> target_shapes;
    target_shapes["tensor"] = ov::Shape{1, 3, 22, 22};
    ASSERT_NO_THROW(model->reshape(target_shapes));

    EXPECT_FALSE(model->input().get_partial_shape().is_dynamic());
    EXPECT_FALSE(model->get_parameters().front()->get_partial_shape().is_dynamic());
}
TEST(function_reshape, ReshapeBatchReLU) {
    // Parameter -> Relu -> Result with a static [1,3,22,22] shape.
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 22, 22});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        auto res = std::make_shared<ov::op::v0::Result>(act);
        model = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{arg});
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));

    // Changing only the batch dimension must propagate through to the result.
    {
        std::map<std::string, ov::PartialShape> target_shapes;
        target_shapes["tensor"] = ov::PartialShape{2, 3, 22, 22};
        ASSERT_NO_THROW(model->reshape(target_shapes));
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({2, 3, 22, 22}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({2, 3, 22, 22}));
}
TEST(function_reshape, ReshapeSpatialReLU) {
    // Parameter -> Relu -> Result; reshape only the spatial dimensions.
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 22, 22});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        auto res = std::make_shared<ov::op::v0::Result>(act);
        model = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{arg});
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));

    {
        std::map<std::string, ov::PartialShape> target_shapes;
        target_shapes["tensor"] = ov::PartialShape{1, 3, 25, 25};
        ASSERT_NO_THROW(model->reshape(target_shapes));
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 25, 25}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({1, 3, 25, 25}));
}
TEST(function_reshape, ReshapeSpatialReLUWithoutReplaceParameter) {
    // Model without tensor names: the shape is updated directly on the parameter
    // instead of going through Function::reshape.
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 22, 22});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        auto res = std::make_shared<ov::op::v0::Result>(act);
        model = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{arg});
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));

    // Mutate the parameter shape in place, then re-run shape inference.
    {
        model->get_parameters()[0]->set_partial_shape({1, 3, 25, 25});
        model->validate_nodes_and_infer_types();
    }

    ASSERT_EQ(model->input().get_partial_shape(), ov::Shape({1, 3, 25, 25}));
    ASSERT_EQ(model->output().get_partial_shape(), ov::Shape({1, 3, 25, 25}));
}
TEST(function_reshape, ReshapeSpatialReLUStaticToDynamic) {
    // A fully static model may be reshaped to a partially dynamic target shape.
    const ov::PartialShape target{1, 3, ov::Dimension::dynamic(), 25};
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 22, 22});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        auto res = std::make_shared<ov::op::v0::Result>(act);
        model = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{arg});
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));

    {
        std::map<std::string, ov::PartialShape> target_shapes;
        target_shapes["tensor"] = target;
        ASSERT_NO_THROW(model->reshape(target_shapes));
    }

    ASSERT_TRUE(model->input(0).get_partial_shape().is_dynamic());
    ASSERT_TRUE(model->output(0).get_partial_shape().is_dynamic());
    ASSERT_EQ(model->input(0).get_partial_shape(), target);
    ASSERT_EQ(model->output(0).get_partial_shape(), target);
}
TEST(function_reshape, ReshapeSpatialReLUStaticToFullyDynamic) {
    // A fully static model may be reshaped to a rank-dynamic shape.
    const ov::PartialShape target = ov::PartialShape::dynamic();
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{1, 3, 22, 22});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        auto res = std::make_shared<ov::op::v0::Result>(act);
        model = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{arg});
    }

    ASSERT_EQ(model->get_parameters()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));
    ASSERT_EQ(model->get_results()[0]->get_shape(), ov::Shape({1, 3, 22, 22}));

    {
        std::map<std::string, ov::PartialShape> target_shapes;
        target_shapes["tensor"] = target;
        ASSERT_NO_THROW(model->reshape(target_shapes));
    }

    ASSERT_TRUE(model->input().get_partial_shape().is_dynamic());
    ASSERT_TRUE(model->output().get_partial_shape().is_dynamic());
    ASSERT_EQ(model->input().get_partial_shape(), target);
    ASSERT_EQ(model->output().get_partial_shape(), target);
}
TEST(function_reshape, ReshapeSpatialReLUDynamicToDynamic) {
    // Reshape a model that is already partially dynamic to a different
    // partially dynamic shape.
    const ov::PartialShape target{1, 3, ov::Dimension::dynamic(), 25};
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32,
                                                           ov::PartialShape{1, 3, 22, ov::Dimension::dynamic()});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto act = std::make_shared<ov::op::v0::Relu>(arg);
        auto res = std::make_shared<ov::op::v0::Result>(act);
        model = std::make_shared<ov::Function>(ov::ResultVector{res}, ov::ParameterVector{arg});
    }

    ASSERT_EQ(model->input().get_partial_shape(), ov::PartialShape({1, 3, 22, ov::Dimension::dynamic()}));
    ASSERT_EQ(model->output().get_partial_shape(), ov::PartialShape({1, 3, 22, ov::Dimension::dynamic()}));

    {
        std::map<std::string, ov::PartialShape> target_shapes;
        target_shapes["tensor"] = target;
        ASSERT_NO_THROW(model->reshape(target_shapes));
    }

    ASSERT_TRUE(model->input().get_partial_shape().is_dynamic());
    ASSERT_TRUE(model->output().get_partial_shape().is_dynamic());
    ASSERT_EQ(model->input().get_partial_shape(), target);
    ASSERT_EQ(model->output().get_partial_shape(), target);
}
TEST(function_reshape, TestInvalidReshape) {
    // Reshaping the input so the downstream Reshape op cannot be satisfied
    // must throw, and the model must keep its original shapes afterwards.
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 1000, 4});
        arg->get_output_tensor(0).set_names({"tensor"});
        auto pattern = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 4000});
        auto op = std::make_shared<ov::op::v1::Reshape>(arg, pattern, true);
        model = std::make_shared<ov::Function>(ov::OutputVector{op}, ov::ParameterVector{arg});
    }

    ASSERT_ANY_THROW(model->reshape({{"tensor", ov::Shape({4})}}));

    // The failed reshape must not have modified the parameter.
    auto arg0 = model->get_parameters().front();
    ASSERT_EQ(arg0->get_output_shape(0), ov::Shape({1, 1000, 4}));

    ASSERT_NO_THROW(model->reshape({{"tensor", ov::Shape({1, 1000, 4})}}));
}
TEST(function_reshape, TestReshapeWithInvalidTensorName) {
    // Function::reshape accepts tensor names only; operation (friendly) names
    // must be rejected.
    std::shared_ptr<ov::Function> model;
    {
        auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 1000, 4});
        arg->set_friendly_name("param");
        arg->get_output_tensor(0).set_names({"tensor"});
        auto pattern = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 4000});
        auto op = std::make_shared<ov::op::v1::Reshape>(arg, pattern, true);
        model = std::make_shared<ov::Function>(ov::OutputVector{op}, ov::ParameterVector{arg});
    }

    // A mix of an operation name and a tensor name is rejected.
    ASSERT_ANY_THROW(model->reshape({{"param", ov::Shape({4, 4, 4})}, {"tensor", ov::Shape({4, 4, 4})}}));
    // A bare operation name alone does not work either.
    ASSERT_ANY_THROW(model->reshape({{"param", ov::Shape({4, 4, 4})}}));
}