Remove deprecated methods usage from transformation library (#3881)

* Remove deprecated methods usage from transformation library

* graph_rewrite_callback -> matcher_pass_callback

* Clean-up legacy library from deprecated methods usage

* Update func tests
This commit is contained in:
Gleb Kazantaev 2021-01-18 06:25:58 +03:00 committed by GitHub
parent aa73eb2424
commit 99b83e9238
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
43 changed files with 199 additions and 202 deletions

View File

@ -19,15 +19,10 @@ class INFERENCE_ENGINE_API_CLASS(ConvertMulAddToScaleShiftOrPower);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertMulAddToScaleShiftOrPower: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertMulAddToScaleShiftOrPower: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertMulAddToScaleShiftOrPower() : GraphRewrite() {
convert_mul_add_to_scaleshift_or_power();
}
private:
void convert_mul_add_to_scaleshift_or_power();
ConvertMulAddToScaleShiftOrPower();
};
enum class CONVERSION_RESULT {

View File

@ -64,8 +64,8 @@ bool convert_to_eltwise(std::shared_ptr<T> & node,
}
template <typename T>
ngraph::graph_rewrite_callback get_callback() {
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
ngraph::matcher_pass_callback get_callback() {
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
static_assert(std::is_same<T, ngraph::opset1::Add>() || std::is_same<T, ngraph::opset1::Subtract>() || std::is_same<T, ngraph::opset1::Multiply>(),
"Unsupported template parameter. Only Add or Multiply allowed!");

View File

@ -15,20 +15,29 @@ namespace ngraph {
namespace pass {
class INFERENCE_ENGINE_API_CLASS(ConvertPriorBox);
class INFERENCE_ENGINE_API_CLASS(ConvertPriorBoxToLegacy);
class INFERENCE_ENGINE_API_CLASS(ConvertPriorBoxClusteredToLegacy);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertPriorBoxToLegacy : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertPriorBoxToLegacy();
};
class ngraph::pass::ConvertPriorBoxClusteredToLegacy : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertPriorBoxClusteredToLegacy();
};
class ngraph::pass::ConvertPriorBox: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvertPriorBox() : GraphRewrite() {
convert_prior_box();
convert_prior_box_clustered();
ConvertPriorBox() {
add_matcher<ngraph::pass::ConvertPriorBoxToLegacy>();
add_matcher<ngraph::pass::ConvertPriorBoxClusteredToLegacy>();
}
private:
void convert_prior_box();
void convert_prior_box_clustered();
};
};

View File

@ -25,7 +25,7 @@ class INFERENCE_ENGINE_API_CLASS(Reshape1DMaxPool);
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
Reshape1DOps() : GraphRewrite() {
Reshape1DOps() {
add_matcher<ngraph::pass::Reshape1DConvolution>();
add_matcher<ngraph::pass::Reshape1DAvgPool>();
add_matcher<ngraph::pass::Reshape1DMaxPool>();

View File

@ -30,20 +30,8 @@ class INFERENCE_ENGINE_API_CLASS(ReshapeFullyConnectedFusion);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::GraphRewrite {
class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ReshapeFullyConnectedFusion() : GraphRewrite() {
construct_reshape_fc();
}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(f)) {
return GraphRewrite::run_on_function(f);
}
return false;
}
private:
void construct_reshape_fc();
ReshapeFullyConnectedFusion();
};

View File

@ -57,7 +57,7 @@ CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulAddToScaleShiftOrPower, "ConvertMulAddToScaleShiftOrPower", 0);
void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshift_or_power() {
ngraph::pass::ConvertMulAddToScaleShiftOrPower::ConvertMulAddToScaleShiftOrPower() {
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, Shape {1});
auto weights = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape {1}, std::vector<float> {0});
@ -66,7 +66,7 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
auto mul = std::make_shared<ngraph::opset1::Multiply>(data_batch, weights);
auto add = std::make_shared<ngraph::opset1::Add>(mul, bias);
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto add_node = ngraph::as_type_ptr<ngraph::opset1::Add>(m.get_match_root());
if (!add_node) {
@ -209,8 +209,6 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "CPUFusion.MulAddToScaleShiftOrPower");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "MulAddToScaleShiftOrPower");
register_matcher(m, callback);
}

View File

@ -6,7 +6,7 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulOrAddFinally, "ConvertMulOrAddFinally", 0);
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() : GraphRewrite() {
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() {
convert_mul_or_add_finally<ngraph::opset1::Add>();
convert_mul_or_add_finally<ngraph::opset1::Subtract>();
convert_mul_or_add_finally<ngraph::opset1::Multiply>();

View File

@ -22,7 +22,7 @@ ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::ConvertNormalizeL2WithMulT
auto normalize = std::make_shared<ngraph::op::NormalizeL2>(input_0, axis, 0.0f, ngraph::op::EpsMode::ADD);
auto mul = std::make_shared<ngraph::opset1::Multiply> (normalize, input_1);
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto mul = std::dynamic_pointer_cast<ngraph::opset1::Multiply> (m.get_match_root());
if (!mul) return false;

View File

@ -140,7 +140,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
// List of final conversion transformations that must to be executed
// after previous group of transformations
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(f)) {
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
}
manager.register_pass<ngraph::pass::ConvertNormalizeL2ToLegacyMatcher>();
manager.register_pass<ngraph::pass::ConvertMulAddToScaleShiftOrPower>();
manager.register_pass<ngraph::pass::ConvertMulOrAddFinally>();

View File

@ -15,8 +15,10 @@
#include <ngraph/rt_info.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPriorBox, "ConvertPriorBox", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPriorBoxToLegacy, "ConvertPriorBoxToLegacy", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPriorBoxClusteredToLegacy, "ConvertPriorBoxClusteredToLegacy", 0);
void ngraph::pass::ConvertPriorBox::convert_prior_box() {
ngraph::pass::ConvertPriorBoxToLegacy::ConvertPriorBoxToLegacy() {
auto data = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
auto axes = ngraph::opset1::Constant::create(element::i64, Shape{1}, {0});
auto image = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
@ -35,7 +37,7 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box() {
auto prior_box = std::make_shared<ngraph::opset1::PriorBox>(data, image, attr);
auto unsqueeze = std::make_shared<ngraph::opset1::Unsqueeze> (prior_box, axes);
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto unsqueeze = std::dynamic_pointer_cast<ngraph::opset1::Unsqueeze> (m.get_match_root());
if (!unsqueeze) {
return false;
@ -153,13 +155,11 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box() {
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "CPUFusion.ConvertPriorBoxToPriorBoxIE");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "ConvertPriorBoxToLegacy");
register_matcher(m, callback);
}
void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() {
ngraph::pass::ConvertPriorBoxClusteredToLegacy::ConvertPriorBoxClusteredToLegacy() {
auto data = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
auto axes = ngraph::opset1::Constant::create(element::i64, Shape{1}, {0});
auto image = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
@ -176,7 +176,7 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() {
auto prior_box = std::make_shared<ngraph::opset1::PriorBoxClustered>(data, image, attr);
auto unsqueeze = std::make_shared<ngraph::opset1::Unsqueeze> (prior_box, axes);
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto unsqueeze = std::dynamic_pointer_cast<ngraph::opset1::Unsqueeze> (m.get_match_root());
if (!unsqueeze) {
return false;
@ -293,8 +293,6 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() {
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "CPUFusion.ConvertPriorBoxClusteredToPriorBoxClusteredIE");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "ConvertPriorBoxClusteredToLegacy");
register_matcher(m, callback);
}

View File

@ -15,13 +15,13 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);
void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
ngraph::pass::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() {
auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
pattern::any_input(),
pattern::any_input()});
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
@ -52,7 +52,5 @@ void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -22,7 +22,7 @@ ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
pattern::any_input()},
pattern::has_static_shape());
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected> (m.get_match_root());
if (!fc || transformation_callback(fc)) {
return false;

View File

@ -34,16 +34,6 @@ class TRANSFORMATIONS_API DeconvAddFusion;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvFusion() : GraphRewrite() {
add_matcher<ngraph::pass::ConvAddFusion>();
add_matcher<ngraph::pass::ConvMultiplyFusion>();
add_matcher<ngraph::pass::DeconvAddFusion>();
}
};
class ngraph::pass::ConvAddFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
@ -60,4 +50,14 @@ class ngraph::pass::DeconvAddFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
DeconvAddFusion();
};
class ngraph::pass::ConvFusion: public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ConvFusion() {
add_matcher<ngraph::pass::ConvAddFusion>();
add_matcher<ngraph::pass::ConvMultiplyFusion>();
add_matcher<ngraph::pass::DeconvAddFusion>();
}
};

View File

@ -41,13 +41,8 @@ namespace pass {
*
*/
class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite {
class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
DepthToSpaceFusion() : GraphRewrite() {
depth_to_space_fusion();
}
private:
void depth_to_space_fusion();
DepthToSpaceFusion();
};

View File

@ -19,13 +19,8 @@ class TRANSFORMATIONS_API RemoveFilteringBoxesBySize;
} // namespace pass
} // namespace ngraph
class ngraph::pass::RemoveFilteringBoxesBySize: public ngraph::pass::GraphRewrite {
class ngraph::pass::RemoveFilteringBoxesBySize: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
RemoveFilteringBoxesBySize() : GraphRewrite() {
remove_filtering_boxes_by_size();
}
private:
void remove_filtering_boxes_by_size();
RemoveFilteringBoxesBySize();
};

View File

@ -23,13 +23,8 @@ class TRANSFORMATIONS_API ConvertScatterElementsToScatter;
* @ingroup ie_transformation_common_api
* @brief ConvertScatterElementsToScatter convert opset3::ScatterElementsUpdate to opset3::ScatterUpdate.
*/
class ngraph::pass::ConvertScatterElementsToScatter: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertScatterElementsToScatter: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertScatterElementsToScatter() : GraphRewrite() {
convert_scatter_elements_to_scatter();
}
private:
void convert_scatter_elements_to_scatter();
ConvertScatterElementsToScatter();
};

View File

@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertShapeOf3;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertShapeOf3: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertShapeOf3: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertShapeOf3() : GraphRewrite() {
convert_shapeof3();
}
private:
void convert_shapeof3();
ConvertShapeOf3();
};

View File

@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertShuffleChannels3;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertShuffleChannels3() : GraphRewrite() {
convert_shuffle_channels3();
}
private:
void convert_shuffle_channels3();
ConvertShuffleChannels3();
};

View File

@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertTopK3;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertTopK3: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertTopK3: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertTopK3() : GraphRewrite() {
convert_topk3();
}
private:
void convert_topk3();
ConvertTopK3();
};

View File

@ -65,8 +65,8 @@ bool IsConvInLowPrecision(const std::shared_ptr<Conv>& conv) {
}
template <class Conv>
ngraph::graph_rewrite_callback get_callback() {
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
ngraph::matcher_pass_callback get_callback() {
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
auto eltwise = m.get_match_root();
std::shared_ptr<ngraph::opset1::Constant> m_const;

View File

@ -83,7 +83,7 @@ bool check_depth_first(const ngraph::Shape& shape_input, const ngraph::Shape& sh
NGRAPH_RTTI_DEFINITION(ngraph::pass::DepthToSpaceFusion, "DepthToSpaceFusion", 0);
void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
ngraph::pass::DepthToSpaceFusion::DepthToSpaceFusion() {
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
@ -92,7 +92,7 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, input2);
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, input3, false);
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto reshape_after = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(m.get_match_root());
if (!reshape_after) {
return false;
@ -160,7 +160,5 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_after, "DepthToSpaceFusion");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -21,7 +21,7 @@ ngraph::pass::MishFusion::MishFusion() {
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);
ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto exp_input = pattern_to_output.at(input);

View File

@ -28,7 +28,7 @@ ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
const auto data_input = pattern_to_output.at(input);
@ -81,7 +81,7 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
const auto data_input = pattern_to_output.at(input);
@ -117,5 +117,5 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "NormalizeL2FusionWithMax");
register_matcher(m, matcher_pass_callback);
register_matcher(m, callback);
}

View File

@ -12,7 +12,7 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::RemoveFilteringBoxesBySize, "RemoveFilteringBoxesBySize", 0);
void ngraph::pass::RemoveFilteringBoxesBySize::remove_filtering_boxes_by_size() {
ngraph::pass::RemoveFilteringBoxesBySize::RemoveFilteringBoxesBySize() {
// variadic split
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1000, 4});
auto sizes = opset3::Constant::create(element::i64, Shape{4}, std::vector<int64_t >({1, 1, 1, 1}));
@ -79,7 +79,7 @@ void ngraph::pass::RemoveFilteringBoxesBySize::remove_filtering_boxes_by_size()
auto cast = std::make_shared<ngraph::opset3::Convert>(squeeze_3, ngraph::element::i64);
ngraph::graph_rewrite_callback callback = [data](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [data](pattern::Matcher& m) {
auto start = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({0}));
auto step = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({1}));
@ -104,7 +104,5 @@ void ngraph::pass::RemoveFilteringBoxesBySize::remove_filtering_boxes_by_size()
};
auto m = std::make_shared<ngraph::pattern::Matcher>(cast, "RemoveFilteringBoxesBySize");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -15,7 +15,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBatchToSpace, "ConvertBatchToSpace",
void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
auto batch_to_space = ngraph::pattern::wrap_type<ngraph::opset3::BatchToSpace>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto batch_to_space = std::dynamic_pointer_cast<ngraph::opset3::BatchToSpace> (m.get_match_root());
if (!batch_to_space) {
return false;
@ -127,7 +127,7 @@ void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() {
auto batch_to_space = ngraph::pattern::wrap_type<ngraph::opset3::BatchToSpace>();
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto batch_to_space = std::dynamic_pointer_cast<ngraph::opset3::BatchToSpace> (m.get_match_root());
if (!batch_to_space) {
return false;

View File

@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDepthToSpace, "ConvertDepthToSpace",
ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>({pattern::any_input(pattern::has_static_shape())});
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
if (!dts_node || transformation_callback(dts_node)) {
return false;

View File

@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDivide, "ConvertDivide", 0);
ngraph::pass::ConvertDivide::ConvertDivide() {
auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide> (m.get_match_root());
// We can not apply this transformation in case with integer input data type
if (!div || div->input(0).get_element_type().is_integral()) {

View File

@ -16,7 +16,7 @@ ngraph::pass::ConvertGELU::ConvertGELU() {
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());
if (!gelu || transformation_callback(gelu))
return false;

View File

@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMinimum, "ConvertMinimum", 0);
ngraph::pass::ConvertMinimum::ConvertMinimum() {
auto minimum = ngraph::pattern::wrap_type<opset1::Minimum>();
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto minimum = std::dynamic_pointer_cast<ngraph::opset1::Minimum> (m.get_match_root());
if (!minimum || transformation_callback(minimum)) {
return false;

View File

@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNegative, "ConvertNegative", 0);
ngraph::pass::ConvertNegative::ConvertNegative() {
auto neg = ngraph::pattern::wrap_type<ngraph::opset1::Negative>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto neg = std::dynamic_pointer_cast<ngraph::opset1::Negative> (m.get_match_root());
if (!neg) {
return false;

View File

@ -14,7 +14,7 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertScatterElementsToScatter, "ConvertScatterElementsToScatter", 0);
void ngraph::pass::ConvertScatterElementsToScatter::convert_scatter_elements_to_scatter() {
ngraph::pass::ConvertScatterElementsToScatter::ConvertScatterElementsToScatter() {
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
auto indices = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
auto updates = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
@ -25,7 +25,7 @@ void ngraph::pass::ConvertScatterElementsToScatter::convert_scatter_elements_to_
auto scatter = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, broadcast, updates, axis);
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto scatter = m.get_match_root();
auto broadcast = scatter->input_value(1).get_node_shared_ptr();
auto axis_const = std::dynamic_pointer_cast<ngraph::opset3::Constant>(scatter->input_value(3).get_node_shared_ptr());
@ -209,8 +209,6 @@ void ngraph::pass::ConvertScatterElementsToScatter::convert_scatter_elements_to_
};
auto m = std::make_shared<ngraph::pattern::Matcher>(scatter, "ConvertScatterElementsToScatter");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -10,14 +10,14 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShapeOf3, "ConvertShapeOf3", 0);
void ngraph::pass::ConvertShapeOf3::convert_shapeof3() {
auto input = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
auto shapeof = std::make_shared<ngraph::opset3::ShapeOf>(input);
ngraph::pass::ConvertShapeOf3::ConvertShapeOf3() {
auto shapeof = pattern::wrap_type<ngraph::opset3::ShapeOf>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto shapeof = std::dynamic_pointer_cast<ngraph::opset3::ShapeOf> (m.get_match_root());
if (!shapeof) {
return false;
@ -43,7 +43,5 @@ void ngraph::pass::ConvertShapeOf3::convert_shapeof3() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(shapeof, "ConvertShapeOf3");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -15,11 +15,11 @@ using namespace ngraph;
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShuffleChannels3, "ConvertShuffleChannels3", 0);
void ngraph::pass::ConvertShuffleChannels3::convert_shuffle_channels3() {
ngraph::pass::ConvertShuffleChannels3::ConvertShuffleChannels3() {
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto shuffle_channels = std::make_shared<::opset3::ShuffleChannels>(input);
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher &m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());
if (!shuffle_channels || transformation_callback(shuffle_channels)) {
return false;
@ -98,7 +98,5 @@ void ngraph::pass::ConvertShuffleChannels3::convert_shuffle_channels3() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(shuffle_channels, "ConvertShuffleChannels3");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSpaceToDepth, "ConvertSpaceToDepth",
ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>({pattern::any_input(pattern::has_static_shape())});
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
if (!std_node || transformation_callback(std_node)) {
return false;

View File

@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSubtract, "ConvertSubtract", 0);
ngraph::pass::ConvertSubtract::ConvertSubtract() {
auto sub = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto sub = std::dynamic_pointer_cast<ngraph::opset1::Subtract> (m.get_match_root());
if (!sub) {
return false;

View File

@ -12,12 +12,14 @@
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTopK3, "ConvertTopK3", 0);
void ngraph::pass::ConvertTopK3::convert_topk3() {
auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset3::TopK>());
ngraph::pass::ConvertTopK3::ConvertTopK3() {
auto topk = pattern::wrap_type<opset3::TopK>();
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto topk = std::dynamic_pointer_cast<ngraph::opset3::TopK> (m.get_match_root());
if (!topk) {
return false;
@ -60,7 +62,5 @@ void ngraph::pass::ConvertTopK3::convert_topk3() {
};
auto m = std::make_shared<ngraph::pattern::Matcher>(topk, "ConvertTopK3");
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
NGRAPH_SUPPRESS_DEPRECATED_END
register_matcher(m, callback);
}

View File

@ -18,6 +18,7 @@
#include <transformations/op_conversions/convert_scatter_elements_to_scatter.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -75,8 +76,10 @@ std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialSh
}
void test(std::shared_ptr<ngraph::Function> f, std::shared_ptr<ngraph::Function> f_ref) {
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertScatterElementsToScatter().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);

View File

@ -12,6 +12,7 @@
#include <ngraph/opsets/opset3.hpp>
#include <transformations/op_conversions/convert_shapeof3.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -26,8 +27,10 @@ TEST(TransformationTests, ConvertShapeOf3WithI64) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{shapeof}, ngraph::ParameterVector{input});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertShapeOf3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertShapeOf3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -56,8 +59,10 @@ TEST(TransformationTests, ConvertShapeOf3WithI32) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{shapeof}, ngraph::ParameterVector{input});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertShapeOf3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertShapeOf3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

View File

@ -12,6 +12,7 @@
#include <ngraph/opsets/opset3.hpp>
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -25,8 +26,10 @@ std::shared_ptr<ngraph::Function> buildInputGraph(int64_t axis, int64_t group, c
auto f = std::make_shared<::Function>(::NodeVector{shuffle_channels}, ::ParameterVector{input});
::pass::InitNodeInfo().run_on_function(f);
::pass::ConvertShuffleChannels3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertShuffleChannels3>();
manager.run_passes(f);
f->validate_nodes_and_infer_types();
return f;
}

View File

@ -15,6 +15,7 @@
#include <transformations/op_conversions/convert_topk3.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -32,8 +33,11 @@ TEST(TransformationTests, ConvertTopK3I32Output0) {
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertTopK3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertTopK3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -67,8 +71,10 @@ TEST(TransformationTests, ConvertTopK3I32Output1) {
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(1)}, ngraph::ParameterVector{input});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertTopK3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertTopK3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -102,8 +108,10 @@ TEST(TransformationTests, ConvertTopK3I64Output0) {
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertTopK3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertTopK3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -137,8 +145,10 @@ TEST(TransformationTests, ConvertTopK3I64Output1) {
// due to the 'compare_functions' limitation we will check only one output
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(1)}, ngraph::ParameterVector{input});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertTopK3().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertTopK3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

View File

@ -19,6 +19,7 @@
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -37,14 +38,19 @@ TEST(TransformationTests, DepthToSpaceFusionDepthFirst) {
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ngraph::pass::Manager manager;
auto pass_config = manager.get_pass_config();
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -71,14 +77,19 @@ TEST(TransformationTests, DepthToSpaceFusionBlockFirst) {
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ngraph::pass::Manager manager;
auto pass_config = manager.get_pass_config();
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -105,15 +116,19 @@ TEST(TransformationTests, DepthToSpaceFusionDynamicShape) {
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
// transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ngraph::pass::Manager manager;
auto pass_config = manager.get_pass_config();
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -150,15 +165,19 @@ TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) {
auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
ngraph::pass::InitNodeInfo().run_on_function(f);
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
};
// transformation won't be applied because of reshape_before has several consumers, the graph will remain the same
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
depth_to_space_transform.set_callback(callback);
depth_to_space_transform.run_on_function(f);
ngraph::pass::Manager manager;
auto pass_config = manager.get_pass_config();
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}

View File

@ -132,8 +132,10 @@ public:
class MulOrAddConversionTests: public MulAddConversionTests {};
TEST_P(MulAddConversionTests, CompareFunctions) {
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertMulAddToScaleShiftOrPower().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertMulAddToScaleShiftOrPower>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
f->validate_nodes_and_infer_types();

View File

@ -34,8 +34,10 @@ TEST(TransformationTests, ReshapeFCFusiuonTest1) {
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{1, 6});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ReshapeFullyConnectedFusion().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
ASSERT_EQ(f->get_ops().size(), 5);
@ -53,8 +55,10 @@ TEST(TransformationTests, ReshapeFCFusiuonTest2) {
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{1, 6});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ReshapeFullyConnectedFusion().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
ASSERT_EQ(f->get_ops().size(), 5);
@ -72,8 +76,10 @@ TEST(TransformationTests, ReshapeFCFusiuonTest3) {
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{2, 6});
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ReshapeFullyConnectedFusion().run_on_function(f);
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
ASSERT_EQ(f->get_ops().size(), 7);