Xuejun/remove api model (#15924)
* [Remove APIs] remove api m_transformation_callback
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* [Remove APIs] remove api run_on_function(), replaced by run_on_model()
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* [Remove APIs] remove set_callback(), use get_pass_config() to configure transformation pipeline
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* [Remove APIs] remove api add_matcher()
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* Fix format issue
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* [Remove APIs] Fix review comments
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* Fix formast issue
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* Fix merge master error
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* Fix CI compiler error
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* Update ONNX Runtime from rel-1.8.1 to rel-1.14.0
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
* Revert "Update ONNX Runtime from rel-1.8.1 to rel-1.14.0"
This reverts commit e31a9e04b7
.
---------
Signed-off-by: Zhai, Xuejun <xuejun.zhai@intel.com>
This commit is contained in:
parent
bb59672639
commit
a9bd5f741d
@ -17,9 +17,9 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
namespace device {
|
||||
|
||||
class ConvertOpSet1ToDeviceSpecific: public ngraph::pass::FunctionPass {
|
||||
class ConvertOpSet1ToDeviceSpecific: public ov::pass::ModelPass {
|
||||
public:
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
|
||||
bool run_on_model(const std::shared_ptr<ngraph::Function>& f) override {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
@ -24,6 +24,11 @@ OV_CC_DOMAINS(ov_pass);
|
||||
# define ADD_MATCHER(obj, region, ...) obj->add_matcher<region>(__VA_ARGS__);
|
||||
# define REGISTER_PASS(obj, region, ...) obj.register_pass<region>(__VA_ARGS__);
|
||||
# define REGISTER_DISABLED_PASS(obj, region, ...) obj.register_pass<region, false>(__VA_ARGS__);
|
||||
|
||||
# define OV_PASS_CALLBACK(matcher) \
|
||||
openvino::itt::handle_t m_callback_handle; \
|
||||
m_callback_handle = openvino::itt::handle(matcher->get_name()); \
|
||||
OV_ITT_SCOPED_TASK(SIMPLE_ov_pass, m_callback_handle)
|
||||
#elif defined(SELECTIVE_BUILD)
|
||||
|
||||
# define MATCHER_SCOPE_(scope, region) \
|
||||
@ -70,6 +75,7 @@ OV_CC_DOMAINS(ov_pass);
|
||||
# define REGISTER_DISABLED_PASS(obj, region, ...) \
|
||||
OV_PP_CAT(REGISTER_PASS_WITH_FALSE_, OV_CC_SCOPE_IS_ENABLED(OV_PP_CAT3(ov_pass, _, region))) \
|
||||
(obj, region, __VA_ARGS__)
|
||||
# define OV_PASS_CALLBACK(matcher)
|
||||
#else
|
||||
|
||||
# define MATCHER_SCOPE(region) const std::string matcher_name(OV_PP_TOSTRING(region))
|
||||
@ -79,6 +85,7 @@ OV_CC_DOMAINS(ov_pass);
|
||||
# define ADD_MATCHER(obj, region, ...) obj->add_matcher<region>(__VA_ARGS__);
|
||||
# define REGISTER_PASS(obj, region, ...) obj.register_pass<region>(__VA_ARGS__);
|
||||
# define REGISTER_DISABLED_PASS(obj, region, ...) obj.register_pass<region, false>(__VA_ARGS__);
|
||||
# define OV_PASS_CALLBACK(matcher)
|
||||
#endif
|
||||
|
||||
#define ADD_MATCHER_FOR_THIS(region, ...) ADD_MATCHER(this, region, __VA_ARGS__)
|
||||
|
@ -5,7 +5,6 @@
|
||||
#include <low_precision/layer_transformation.hpp>
|
||||
#include <low_precision/network_helper.hpp>
|
||||
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
@ -448,9 +447,24 @@ void LayerTransformation::addPattern(ngraph::pass::GraphRewrite& pass, Transform
|
||||
};
|
||||
// TODO: better name for matcher? required?
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(patternRoot, matcher_name);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
pass.add_matcher(m, internal_callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, internal_callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
bool status = internal_callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
pass.add_matcher(match_pass);
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
|
@ -5,7 +5,6 @@
|
||||
#include "low_precision/low_precision.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@ -134,9 +133,24 @@ void make_matcher_type_relaxed(ngraph::pass::GraphRewrite* transformation) {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(p_node, matcher_name);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
transformation->add_matcher(m, callback, ngraph::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
transformation->add_matcher(match_pass);
|
||||
}
|
||||
|
||||
ngraph::pass::low_precision::TypeRelaxedReplacer::TypeRelaxedReplacer() {
|
||||
|
@ -80,7 +80,7 @@ public:
|
||||
0);
|
||||
|
||||
ngraph::pass::low_precision::TypeRelaxedReplacer pass;
|
||||
pass.run_on_function(actualFunction);
|
||||
pass.run_on_model(actualFunction);
|
||||
|
||||
auto supportedPrecisionsOnActivation = std::vector<ngraph::pass::low_precision::PrecisionsRestriction>(
|
||||
{ngraph::pass::low_precision::PrecisionsRestriction::create<ngraph::opset1::Convolution>(
|
||||
@ -129,7 +129,7 @@ public:
|
||||
};
|
||||
|
||||
TEST_P(MarkupAvgPoolPrecisionsTransformation, CompareFunctions) {
|
||||
ov::pass::InitNodeInfo().run_on_function(actualFunction);
|
||||
ov::pass::InitNodeInfo().run_on_model(actualFunction);
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
|
||||
const auto avgPoolOperations = LayerTransformation::get<opset1::AvgPool>(actualFunction);
|
||||
|
@ -267,7 +267,7 @@ TEST_F(TransformationTestsF, PropagateMasksBasic) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksBasic.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -352,7 +352,7 @@ TEST_F(TransformationTestsF, PropagateMasksDynamicConvolution) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksDynamicConvolution.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -403,7 +403,7 @@ TEST(TransformationTests, PropagateMasksDynamicReshape) {
|
||||
auto function = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksDynamicReshape.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -448,7 +448,7 @@ TEST(TransformationTests, PropagateMasksDynamicGroupConvolution) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksDynamicGroupConvolution.svg")
|
||||
.run_on_function(f);
|
||||
.run_on_model(f);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -486,7 +486,7 @@ TEST(TransformationTests, PropagateMasksEmpty) {
|
||||
auto f = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksEmpty.svg").run_on_function(f);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksEmpty.svg").run_on_model(f);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -583,7 +583,7 @@ TEST_F(TransformationTestsF, PropagateMaskPassThrough) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMaskPassThrough.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -735,7 +735,7 @@ TEST_F(TransformationTestsF, PropagateMasksHardDependencies) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksHardDependencies.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -886,7 +886,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolution) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksQuantizedGroupConvolution.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -1053,7 +1053,7 @@ TEST_F(TransformationTestsF, PropagateMasksQuantizedGroupConvolutionWithShapeOf)
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) +
|
||||
"PropagateMasksQuantizedGroupConvolutionWithShapeOf.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -1185,7 +1185,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerTensor) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerTensor.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -1269,7 +1269,7 @@ TEST(TransformationTests, PropagateMasksFakeQuantizePerTensor1DScale) {
|
||||
auto function = std::make_shared<Function>(NodeVector{conv2}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerTensor1DScale.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -1387,7 +1387,7 @@ TEST_F(TransformationTestsF, PropagateMasksFakeQuantizePerChannel) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksFakeQuantizePerChannel.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
// Masks for fq input parammeters didn't saved after
|
||||
@ -1530,7 +1530,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagation) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "TestConcatMaskPropagation.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -1673,7 +1673,7 @@ TEST_F(TransformationTestsF, TestConcatMaskPropagationUp) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "TestConcatMaskPropagationUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -1744,7 +1744,7 @@ TEST(TransformationTests, TestConcatMaskPropagationUpEmpty) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "TestConcatMaskPropagationUpEmpty.svg")
|
||||
.run_on_function(f);
|
||||
.run_on_model(f);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -1806,7 +1806,7 @@ TEST_F(TransformationTestsF, PruneConvIsClosingAndInGroup) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneConvIsClosingAndInGroup.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
auto input = std::make_shared<opset10::Parameter>(element::f32, inputShapes);
|
||||
auto weights = create_constant_with_zeros(
|
||||
@ -1922,7 +1922,7 @@ TEST(TransformationTests, PruneBranchingStopOp) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneBranchingStopOp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -1977,7 +1977,7 @@ TEST(TransformationTests, PruneStopOpUp) {
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{end_conv}, ParameterVector{input}, "StopOpUp");
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneStopOpUp.svg").run_on_function(function);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneStopOpUp.svg").run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -2044,8 +2044,7 @@ TEST_F(TransformationTestsF, PruneReducelayerUp) {
|
||||
function_ref = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneReducelayerUp.svg")
|
||||
.run_on_function(function);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneReducelayerUp.svg").run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2142,7 +2141,7 @@ TEST_F(TransformationTestsF, PruneReduceLayerDown) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneReduceLayerDown.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2194,7 +2193,7 @@ TEST(TransformationTests, PruneStopReducelayerUp) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneStopReducelayerUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -2252,7 +2251,7 @@ TEST(TransformationTests, PruneStopReduceLayerDown) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneStopReduceLayerDown.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -2327,7 +2326,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUp) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2438,7 +2437,7 @@ TEST_P(TransformationTestsBoolParamF, MaskPropagationReshapeUpWithShapeOf) {
|
||||
const auto postfix = use_shape_of ? "True" : "False";
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUpWithShapeOf" + postfix +
|
||||
".svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -2550,7 +2549,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUpShapeSubGraph) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUpShapeSubGraph.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2642,7 +2641,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeExtend) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeExtend.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2749,7 +2748,7 @@ TEST_F(DISABLED_TransformationTestsF, MaskPropagationReshapeDownMul) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeDownMul.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2853,7 +2852,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeDownAdd) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeDownAdd.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2902,7 +2901,7 @@ TEST(TransformationTests, MaskPropagationStopReshapeUp) {
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{conv_1}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopReshapeUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -2959,7 +2958,7 @@ TEST(TransformationTests, MaskPropagationStopReshapeDown) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopReshapeDown.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3017,7 +3016,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUnsqueezeUp) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUnsqueezeUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3079,7 +3078,7 @@ TEST_F(TransformationTestsF, MaskPropagationReshapeUnsqueezeDown) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapeUnsqueezeDown.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3140,7 +3139,7 @@ TEST(TransformationTests, MaskPropagationWrongDimsElementwise) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationWrongDimsElementwise.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -3251,7 +3250,7 @@ TEST_F(TransformationTestsF, PruneSEBlock) {
|
||||
function_ref = std::make_shared<ngraph::Function>(OutputVector{end_conv}, ParameterVector{input}, "SEBlock");
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneSEBlock.svg").run_on_function(function);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneSEBlock.svg").run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3343,7 +3342,7 @@ TEST_F(TransformationTestsF, PropagateMasksLinear) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksLinear.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3395,7 +3394,7 @@ TEST(TransformationTests, MaskPropagationMatMulStopEmptyABranch) {
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{mul_left}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationMatMulStopEmptyABranch.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3462,7 +3461,7 @@ TEST(TransformationTests, PruneLinearUp) {
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_linear}, ParameterVector{input});
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneLinearUp.svg").run_on_function(function);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneLinearUp.svg").run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -3519,8 +3518,7 @@ TEST(TransformationTests, PruneConvUpShort) {
|
||||
auto function = std::make_shared<ngraph::Function>(OutputVector{last_conv}, ParameterVector{input});
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneConvUpShort.svg")
|
||||
.run_on_function(function);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneConvUpShort.svg").run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::Pruning>();
|
||||
@ -3595,7 +3593,7 @@ TEST_F(TransformationTestsF, MaskPropagationLinearOuterDims) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationLinearOuterDims.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3674,7 +3672,7 @@ TEST(TransformationTests, MaskPropagationStopLinearOuterDims) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationStopLinearOuterDims.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3765,7 +3763,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulColsStopRowsUp) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneMasksMatMulColsStopRowsUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3854,7 +3852,7 @@ TEST_F(TransformationTestsF, PruneMasksMatMulRowsStopColsUp) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PruneMasksMatMulRowsStopColsUp.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -3949,8 +3947,7 @@ TEST_F(TransformationTestsF, PropagateFlattenUp) {
|
||||
function_ref = std::make_shared<Function>(NodeVector{linear}, ParameterVector{input});
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateFlattenUp.svg")
|
||||
.run_on_function(function);
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateFlattenUp.svg").run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -4025,7 +4022,7 @@ TEST_F(TransformationTestsF, PropagateFlattenDown) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateFlattenDown.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -4084,7 +4081,7 @@ TEST_F(TransformationTestsF, PropagateMasksTranspose) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksTranspose.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -4157,7 +4154,7 @@ TEST_F(TransformationTestsF, PropagateMasksTransposeComplex) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksTransposeComplex.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -4197,7 +4194,7 @@ TEST(TransformationTests, PropagateMasksTransposeStop) {
|
||||
auto function = std::make_shared<Function>(NodeVector{last_mul}, ParameterVector{input});
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksTransposeStop.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -4323,7 +4320,7 @@ TEST_F(DISABLED_TransformationTestsF, PropagateMasksBroadcastedEltwiseWithInputs
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE) {
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksBroadcastedEltwiseWithInputs.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -4500,7 +4497,7 @@ TEST_F(TransformationTestsF, PropagateMasksBroadcastedEltwise) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE) {
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "PropagateMasksBroadcastedEltwise.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -4663,7 +4660,7 @@ TEST_F(TransformationTestsF, MaskPropagationComplexReshape) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE) {
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationComplexReshape.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -4856,7 +4853,7 @@ TEST_P(TransformationTestsBoolParamF, MaskPropagationReshapedPassThroughP) {
|
||||
auto postfix = (add_shape_of) ? "True" : "False";
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationReshapedPassThroughP" + postfix +
|
||||
".svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -4981,7 +4978,7 @@ TEST_P(TransformationTestsBoolParamF, MaskPropagationBroadcastedSameRankEltwiseS
|
||||
auto postfix = (reverse_mul) ? "True" : "False";
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) +
|
||||
"MaskPropagationBroadcastedSameRankEltwiseSwappedLayoutP" + postfix + ".svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -5028,7 +5025,7 @@ TEST(TransformationTests, MaskPropagationBroadcastedEltwiseInputAndWeightsBroadc
|
||||
if (VISUALIZE_TESTS_TREE) {
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) +
|
||||
"MaskPropagationBroadcastedEltwiseInputAndWeightsBroadcasted.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -5078,7 +5075,7 @@ TEST(TransformationTests, MaskPropagationBroadcastedEltwiseWrongBroadcastingMode
|
||||
if (VISUALIZE_TESTS_TREE) {
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) +
|
||||
"MaskPropagationBroadcastedEltwiseWrongBroadcastingMode.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
}
|
||||
{
|
||||
pass::Manager m;
|
||||
@ -5143,7 +5140,7 @@ TEST_F(TransformationTestsF, MaskPropagationMatMulWithSeveralOutputs) {
|
||||
}
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "MaskPropagationMatMulWithSeveralOutputs.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
{
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitMasks>();
|
||||
@ -5174,7 +5171,7 @@ TEST(TransformationTests, CheckReshapeWithNoConstInShape) {
|
||||
|
||||
if (VISUALIZE_TESTS_TREE)
|
||||
ngraph::pass::VisualizeTree(std::string(VISUALIZE_TREE_ROOT) + "CheckReshapeWithNoConstInShape.svg")
|
||||
.run_on_function(function);
|
||||
.run_on_model(function);
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::ShrinkWeights>();
|
||||
|
@ -257,14 +257,6 @@ public:
|
||||
return pass;
|
||||
}
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m, const ov::graph_rewrite_callback& callback);
|
||||
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
|
||||
void set_pass_config(const std::shared_ptr<PassConfig>& pass_config) override;
|
||||
|
@ -76,36 +76,6 @@ public:
|
||||
/// \param new_state Value "true" enables Validate pass run; "false", otherwise
|
||||
void set_per_pass_validation(bool new_state);
|
||||
|
||||
/// \brief Callback is a lambda function that can be used by registered transformations.
|
||||
/// The main purpose of this callback is to provide a way for plugins to disable/enable
|
||||
/// transformations based on some conditions. In some cases plugins may want not to
|
||||
/// execute some
|
||||
/// transformations.
|
||||
/// For example plugin can disable unpleasant decompositions because of performance
|
||||
/// reasons for
|
||||
/// some cases.
|
||||
/// Callback example:
|
||||
/// auto callback = [](const std::shared_ptr<const ov::Node> & node) -> bool {
|
||||
/// return std::dynamic_pointer_cast<const ov::opset3::DepthToSpace>(node) !=
|
||||
/// nullptr;
|
||||
/// };
|
||||
/// This callback returns true in case of DepthToSpace operation. So when execution
|
||||
/// DepthToSpace
|
||||
/// decomposition pass will check is this decomposition needed or plugin can execute
|
||||
/// this
|
||||
/// operation directly. And of course on transformation side we need to have a response
|
||||
/// for this
|
||||
/// callback.
|
||||
/// if (transformation_callback(batch_to_space)) {
|
||||
/// return false;
|
||||
/// }
|
||||
/// \param callback lamda function that returns true in case if node is supported by
|
||||
/// plugin and
|
||||
/// transformation is not needed
|
||||
OPENVINO_DEPRECATED("Please use get_pass_config() to configure transformation pipeline")
|
||||
void set_callback(const param_callback& callback) {
|
||||
m_pass_config->set_callback(callback);
|
||||
}
|
||||
/// \return PassConfig shared object. This object is used for transformations pipeline
|
||||
/// configuration.
|
||||
/// This object allows to disable/enable transformations execution, set callback to
|
||||
|
@ -61,14 +61,6 @@ public:
|
||||
std::shared_ptr<PassConfig> get_pass_config() {
|
||||
return m_pass_config;
|
||||
}
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// This method remains here only for backward compatibility and will be removed
|
||||
/// after all transformations are moved to transformation_callback() method.
|
||||
/// \return result of callback execution for given node
|
||||
OPENVINO_DEPRECATED("Please use transformation_callback method instead")
|
||||
bool m_transformation_callback(const std::shared_ptr<const Node>& node) {
|
||||
return m_pass_config->get_callback(get_type_info())(node);
|
||||
}
|
||||
|
||||
/// \brief Applies callback for given node. By default callback returns false.
|
||||
/// \param node which will be used inside callback
|
||||
@ -99,13 +91,7 @@ class OPENVINO_API ModelPass : public PassBase {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::pass::ModelPass");
|
||||
~ModelPass() override;
|
||||
OPENVINO_DEPRECATED("run_on_function() method is deprecated. Please use run_on_model() instead.")
|
||||
virtual bool run_on_function(std::shared_ptr<ov::Model> m);
|
||||
virtual bool run_on_model(const std::shared_ptr<ov::Model>& m);
|
||||
|
||||
private:
|
||||
bool call_on_function{false};
|
||||
bool call_on_model{false};
|
||||
virtual bool run_on_model(const std::shared_ptr<ov::Model>& m) = 0;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
|
@ -33,10 +33,6 @@ OV_CC_DOMAINS(ov_opset);
|
||||
*/
|
||||
#if defined(SELECTIVE_BUILD_ANALYZER)
|
||||
# define OV_OP_SCOPE(region) OV_SCOPE(ov_op, region)
|
||||
# define OV_PASS_CALLBACK(matcher) \
|
||||
openvino::itt::handle_t m_callback_handle; \
|
||||
m_callback_handle = openvino::itt::handle(matcher->get_name()); \
|
||||
OV_ITT_SCOPED_TASK(SIMPLE_ov_pass, m_callback_handle)
|
||||
# define REGISTER_OP(opset_name, op_name) \
|
||||
OV_ITT_SCOPED_TASK(SIMPLE_ov_opset, openvino::itt::handle(opset_name + "_" + op_name))
|
||||
# define INSERT_OP(opset_name, op_name, op_namespace) opset.insert<op_namespace::op_name>()
|
||||
@ -44,14 +40,12 @@ OV_CC_DOMAINS(ov_opset);
|
||||
# define OV_OP_SCOPE(region) \
|
||||
if (OV_CC_SCOPE_IS_ENABLED(OV_PP_CAT3(ov_op, _, region)) == 0) \
|
||||
throw ngraph::ngraph_error(std::string(OV_PP_TOSTRING(OV_PP_CAT3(ov_op, _, region))) + " is disabled!")
|
||||
# define OV_PASS_CALLBACK(matcher)
|
||||
# define REGISTER_OP(opset_name, op_name)
|
||||
# define INSERT_OP(opset_name, op_name, op_namespace) \
|
||||
if (OV_CC_SCOPE_IS_ENABLED(OV_PP_CAT4(ov_opset_, opset_name, _, op_name)) == 1) \
|
||||
opset.insert<op_namespace::op_name>()
|
||||
#else
|
||||
# define OV_OP_SCOPE(region) OV_ITT_SCOPED_TASK(ov::itt::domains::ov_op, OV_PP_TOSTRING(region))
|
||||
# define OV_PASS_CALLBACK(matcher)
|
||||
# define REGISTER_OP(opset_name, op_name)
|
||||
# define INSERT_OP(opset_name, op_name, op_namespace) opset.insert<op_namespace::op_name>()
|
||||
#endif
|
||||
|
@ -239,37 +239,6 @@ bool ov::pass::GraphRewrite::apply_matcher_passes(std::shared_ptr<Model> f,
|
||||
return rewritten;
|
||||
}
|
||||
|
||||
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property) {
|
||||
m_matchers.push_back(std::make_shared<MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (m->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
OV_PASS_CALLBACK(m);
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
property));
|
||||
}
|
||||
|
||||
void ov::pass::GraphRewrite::add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
// TODO: before deprecate this function, by default expect the
|
||||
// callback require static shape.
|
||||
add_matcher(m, callback, {PassProperty::REQUIRE_STATIC_SHAPE});
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
}
|
||||
|
||||
void ov::pass::GraphRewrite::set_pass_config(const std::shared_ptr<PassConfig>& rhs) {
|
||||
auto pass_config = get_pass_config();
|
||||
// We have to preserve disabled passes because in case when we register matchers inside
|
||||
|
@ -74,23 +74,6 @@ ov::pass::ModelPass::~ModelPass() = default;
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
|
||||
bool ov::pass::ModelPass::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||
RUN_ON_MODEL_SCOPE(ModelPass);
|
||||
RunLocker locked(call_on_model);
|
||||
OPENVINO_ASSERT(!call_on_function,
|
||||
"Cycle detected. run_on_model() or run_on_function() method should be overridden.");
|
||||
bool sts = run_on_function(m);
|
||||
return sts;
|
||||
}
|
||||
|
||||
bool ov::pass::ModelPass::run_on_function(std::shared_ptr<ov::Model> m) {
|
||||
RUN_ON_FUNCTION_SCOPE(ModelPass);
|
||||
RunLocker locked(call_on_function);
|
||||
OPENVINO_ASSERT(!call_on_model, "Cycle detected. run_on_model() or run_on_function() method should be overridden.");
|
||||
bool sts = run_on_model(m);
|
||||
return sts;
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::NodePass, "ngraph::pass::NodePass", 0);
|
||||
|
||||
ngraph::pass::NodePass::~NodePass() = default;
|
||||
|
@ -107,7 +107,7 @@ TEST(GraphRewriteTest, MatcherPassCallback) {
|
||||
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TestPass>()->set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
@ -118,7 +118,7 @@ TEST(GraphRewriteTest, GraphRewriteCallback) {
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TestPass>();
|
||||
anchor.set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
@ -129,7 +129,7 @@ TEST(GraphRewriteTest, ManagerCallbackDeprecated) {
|
||||
pass::Manager manager;
|
||||
auto anchor = manager.register_pass<Anchor>();
|
||||
anchor->add_matcher<TestPass>();
|
||||
manager.set_callback(get_callback());
|
||||
manager.get_pass_config()->set_callback(get_callback());
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
@ -153,7 +153,7 @@ TEST(GraphRewriteTest, ManagerCallback2) {
|
||||
|
||||
pass::Manager manager;
|
||||
auto anchor = manager.register_pass<TestPass>();
|
||||
manager.set_callback(get_callback());
|
||||
manager.get_pass_config()->set_callback(get_callback());
|
||||
manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
@ -179,7 +179,7 @@ TEST(GraphRewriteTest, MatcherPassCallbackDerived) {
|
||||
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TestPass>()->set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
@ -228,7 +228,7 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassCallback) {
|
||||
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
@ -238,7 +238,7 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassCallbackDerived) {
|
||||
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
@ -249,7 +249,7 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassOrder1) {
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
|
||||
anchor.add_matcher<TypeBasedTestPassDerived>()->set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Relu>(f), 1);
|
||||
}
|
||||
@ -260,7 +260,7 @@ TEST(GraphRewriteTest, TypeBasedMatcherPassOrder2) {
|
||||
Anchor anchor;
|
||||
anchor.add_matcher<TypeBasedTestPassDerived>()->set_callback(get_callback());
|
||||
anchor.add_matcher<TypeBasedTestPass>()->set_callback(get_callback());
|
||||
anchor.run_on_function(f);
|
||||
anchor.run_on_model(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<opset3::Tanh>(f), 1);
|
||||
}
|
||||
|
@ -98,7 +98,7 @@ TEST(pattern, matcher_pass) {
|
||||
|
||||
pass::GraphRewrite pass;
|
||||
pass.add_matcher<TestMatcherPass>();
|
||||
pass.run_on_function(f);
|
||||
pass.run_on_model(f);
|
||||
|
||||
// Parameter->Relu->Result
|
||||
ASSERT_TRUE(f->get_ops().size() == 3);
|
||||
|
@ -56,7 +56,7 @@ class TestFunctionPass : public ngraph::pass::FunctionPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
bool run_on_function(std::shared_ptr<Function> f) override {
|
||||
bool run_on_model(const std::shared_ptr<Function>& f) override {
|
||||
pass::Manager manager(get_pass_config());
|
||||
|
||||
manager.register_pass<RenameReLU, false /*disabled by default*/>();
|
||||
|
@ -37,7 +37,7 @@ namespace {
|
||||
class DummyPass : public pass::FunctionPass {
|
||||
public:
|
||||
DummyPass() {}
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> /* f */) override {
|
||||
bool run_on_model(const std::shared_ptr<ngraph::Function>& /* f */) override {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
@ -111,9 +111,23 @@ public:
|
||||
};
|
||||
|
||||
auto m = make_shared<TestMatcher>(make_shared<op::v1::Multiply>(pattern, iconst1));
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
this->add_matcher(match_pass);
|
||||
}
|
||||
|
||||
void construct_add_zero() {
|
||||
@ -156,9 +170,23 @@ public:
|
||||
|
||||
auto add = make_shared<op::v1::Add>(pattern, iconst0);
|
||||
auto m = make_shared<TestMatcher>(add);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto match_pass = std::make_shared<ov::pass::MatcherPass>(
|
||||
m->get_name(),
|
||||
m,
|
||||
[m, callback](const std::shared_ptr<Node>& node) -> bool {
|
||||
NGRAPH_DEBUG << "Running matcher " << m->get_name() << " on " << node;
|
||||
if (std::dynamic_pointer_cast<ov::pass::pattern::Matcher>(m)->match(node->output(0))) {
|
||||
NGRAPH_DEBUG << "Matcher " << m->get_name() << " matched " << node;
|
||||
bool status = callback(*m.get());
|
||||
// explicitly clear Matcher state because it holds pointers to matched nodes
|
||||
m->clear_state();
|
||||
return status;
|
||||
}
|
||||
m->clear_state();
|
||||
return false;
|
||||
},
|
||||
ov::pass::PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
this->add_matcher(match_pass);
|
||||
}
|
||||
|
||||
TestGraphRewrite() : GraphRewrite() {
|
||||
|
@ -39,9 +39,9 @@ public:
|
||||
std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> CalculateRefs() override {
|
||||
// Convert the second input constant precision to i64 to run the reference function
|
||||
if (ngraph::element::Type_t::i8 == secondConstantType) {
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::i64>().run_on_function(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::i64>().run_on_model(functionRefs);
|
||||
} else if (ngraph::element::Type_t::bf16 == secondConstantType) {
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::i64>().run_on_function(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::i64>().run_on_model(functionRefs);
|
||||
}
|
||||
return LayerTestsUtils::LayerTestsCommon::CalculateRefs();
|
||||
}
|
||||
|
@ -203,8 +203,8 @@ public:
|
||||
using ngraph::pass::ConvertPrecision;
|
||||
ConcatConvSumInPlaceTest::SetUp();
|
||||
functionRefs = function->clone();
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::f32>().run_on_function(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::u8, ngraph::element::Type_t::f32>().run_on_function(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::i8, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::u8, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
|
||||
functionRefs->validate_nodes_and_infer_types();
|
||||
}
|
||||
};
|
||||
|
@ -681,7 +681,7 @@ TEST_P(OVExecutableNetworkBaseTest, precisionsAsInOriginalIR) {
|
||||
auto filePrefix = CommonTestUtils::generateTestFilePrefix();
|
||||
const std::string m_out_xml_path_1 = filePrefix + "precisionsAsInOriginalIR.xml";
|
||||
const std::string m_out_bin_path_1 = filePrefix + "precisionsAsInOriginalIR.bin";
|
||||
ov::pass::Serialize(m_out_xml_path_1, m_out_bin_path_1).run_on_function(function);
|
||||
ov::pass::Serialize(m_out_xml_path_1, m_out_bin_path_1).run_on_model(function);
|
||||
|
||||
ov::CompiledModel execNet;
|
||||
EXPECT_NO_THROW(execNet = core->compile_model(m_out_xml_path_1, target_device, configuration));
|
||||
|
@ -62,7 +62,7 @@ void AddTransformation::SetUp() {
|
||||
precision, inputShape, param.broadcast,
|
||||
param.fakeQuantize1, param.fakeQuantize2);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(AddTransformation, CompareWithRefImpl) {
|
||||
|
@ -67,7 +67,7 @@ void ElementwiseBranchSelectionTransformation::SetUp() {
|
||||
param.branch2.fakeQuantizeAfter,
|
||||
param.fakeQuantizeAfter);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
void ElementwiseBranchSelectionTransformation::Run() {
|
||||
|
@ -40,7 +40,7 @@ void FakeQuantizeAndAvgPoolTransformation::SetUp() {
|
||||
inputShape,
|
||||
fakeQuantize);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FakeQuantizeAndAvgPoolTransformation, CompareWithRefImpl) {
|
||||
|
@ -39,7 +39,7 @@ void FakeQuantizeAndMaxPoolTransformation::SetUp() {
|
||||
inputShape,
|
||||
fakeQuantize);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FakeQuantizeAndMaxPoolTransformation, CompareWithRefImpl) {
|
||||
|
@ -44,7 +44,7 @@ void FakeQuantizePrecisionSelectionTransformation::SetUp() {
|
||||
testValues.actual.fakeQuantizeOnWeights
|
||||
});
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FakeQuantizePrecisionSelectionTransformation, CompareWithRefImpl) {
|
||||
|
@ -49,7 +49,7 @@ void FakeQuantizeTransformation::SetUp() {
|
||||
testParams.fakequantize,
|
||||
true);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
void FakeQuantizeTransformation::Run() {
|
||||
|
@ -39,7 +39,7 @@ void FuseFakeQuantizeAndScaleShiftTransformation::SetUp() {
|
||||
inputShape,
|
||||
fakeQuantizeOnData);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FuseFakeQuantizeAndScaleShiftTransformation, CompareWithRefImpl) {
|
||||
|
@ -46,7 +46,7 @@ void FuseFakeQuantizeTransformation::SetUp() {
|
||||
testValues.actual.precisionAfterDequantization,
|
||||
testValues.actual.fakeQuantizeOnData);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FuseFakeQuantizeTransformation, CompareWithRefImpl) {
|
||||
|
@ -35,7 +35,7 @@ void FuseMultiplyToFakeQuantizeTransformation::SetUp() {
|
||||
testValues.actual.fakeQuantizeOnData,
|
||||
testValues.actual.dequantization);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FuseMultiplyToFakeQuantizeTransformation, CompareWithRefImpl) {
|
||||
|
@ -35,7 +35,7 @@ void FuseSubtractToFakeQuantizeTransformation::SetUp() {
|
||||
testValues.actual.fakeQuantizeOnData,
|
||||
testValues.actual.dequantization);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(FuseSubtractToFakeQuantizeTransformation, CompareWithRefImpl) {
|
||||
|
@ -71,7 +71,7 @@ void MatMulTransformation::SetUp() {
|
||||
testValues.inputShape2,
|
||||
testValues.fqOnData2);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
void MatMulTransformation::Run() {
|
||||
|
@ -70,7 +70,7 @@ void MatMulWithConstantTransformation::SetUp() {
|
||||
testValues.fqOnWeights,
|
||||
testValues.deqOnWeights);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
void MatMulWithConstantTransformation::Run() {
|
||||
|
@ -66,7 +66,7 @@ void MultiplyTransformation::SetUp() {
|
||||
param.fakeQuantizeAfter,
|
||||
param.secondInputIsConstant);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
void MultiplyTransformation::Run() {
|
||||
|
@ -54,7 +54,7 @@ void PReluTransformation::SetUp() {
|
||||
|
||||
function = ngraph::builder::subgraph::PReluFunction::getOriginal(inputShape, precision, testValues.fakeQuantize);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(PReluTransformation, CompareWithRefImpl) {
|
||||
|
@ -54,7 +54,7 @@ void ReluTransformation::SetUp() {
|
||||
|
||||
function = ngraph::builder::subgraph::ReluFunction::getOriginal(inputShape, precision, testValues.fakeQuantize);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(ReluTransformation, CompareWithRefImpl) {
|
||||
|
@ -75,7 +75,7 @@ void SqueezeTransformation::SetUp() {
|
||||
squeezeParam.fakeQuantize,
|
||||
squeezeParam.squeezeAxes);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(SqueezeTransformation, CompareWithRefImpl) {
|
||||
|
@ -75,7 +75,7 @@ void UnsqueezeTransformation::SetUp() {
|
||||
unsqueezeParam.fakeQuantize,
|
||||
unsqueezeParam.unsqueezeAxes);
|
||||
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
}
|
||||
|
||||
TEST_P(UnsqueezeTransformation, CompareWithRefImpl) {
|
||||
|
@ -64,8 +64,8 @@ namespace snippets {
|
||||
"CodegenGelu");
|
||||
|
||||
if (useSubgraph) {
|
||||
ov::pass::InitNodeInfo().run_on_function(function);
|
||||
ngraph::pass::ConstantFolding().run_on_function(function);
|
||||
ov::pass::InitNodeInfo().run_on_model(function);
|
||||
ngraph::pass::ConstantFolding().run_on_model(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -417,8 +417,8 @@ void LayerTestsCommon::Infer() {
|
||||
}
|
||||
|
||||
void LayerTestsCommon::ConvertRefsParams() {
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::f16, ngraph::element::Type_t::f32>().run_on_function(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::f32>().run_on_function(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::f16, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::f32>().run_on_model(functionRefs);
|
||||
}
|
||||
|
||||
std::vector<std::pair<ngraph::element::Type, std::vector<std::uint8_t>>> LayerTestsCommon::CalculateRefs() {
|
||||
|
@ -82,7 +82,7 @@ void RandomUniformLayerTest::SetUp() {
|
||||
|
||||
void RandomUniformLayerTest::ConvertRefsParams() {
|
||||
// we shouldn't use default conversion from f16 to f32
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::f32>().run_on_function(
|
||||
ngraph::pass::ConvertPrecision<ngraph::element::Type_t::bf16, ngraph::element::Type_t::f32>().run_on_model(
|
||||
functionRefs);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user