Remove deprecated methods usage from transformation library (#3881)
* Remove deprecated methods usage from transformation library * graph_rewrite_callback -> matcher_pass_callback * Clean-up legacy library from deprecated methods usage * Update func tests
This commit is contained in:
parent
aa73eb2424
commit
99b83e9238
@ -19,15 +19,10 @@ class INFERENCE_ENGINE_API_CLASS(ConvertMulAddToScaleShiftOrPower);
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertMulAddToScaleShiftOrPower: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertMulAddToScaleShiftOrPower: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertMulAddToScaleShiftOrPower() : GraphRewrite() {
|
||||
convert_mul_add_to_scaleshift_or_power();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_mul_add_to_scaleshift_or_power();
|
||||
ConvertMulAddToScaleShiftOrPower();
|
||||
};
|
||||
|
||||
enum class CONVERSION_RESULT {
|
||||
|
@ -64,8 +64,8 @@ bool convert_to_eltwise(std::shared_ptr<T> & node,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ngraph::graph_rewrite_callback get_callback() {
|
||||
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback get_callback() {
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
|
||||
static_assert(std::is_same<T, ngraph::opset1::Add>() || std::is_same<T, ngraph::opset1::Subtract>() || std::is_same<T, ngraph::opset1::Multiply>(),
|
||||
"Unsupported template parameter. Only Add or Multiply allowed!");
|
||||
|
||||
|
@ -15,20 +15,29 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class INFERENCE_ENGINE_API_CLASS(ConvertPriorBox);
|
||||
class INFERENCE_ENGINE_API_CLASS(ConvertPriorBoxToLegacy);
|
||||
class INFERENCE_ENGINE_API_CLASS(ConvertPriorBoxClusteredToLegacy);
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertPriorBoxToLegacy : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertPriorBoxToLegacy();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertPriorBoxClusteredToLegacy : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertPriorBoxClusteredToLegacy();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertPriorBox: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertPriorBox() : GraphRewrite() {
|
||||
convert_prior_box();
|
||||
convert_prior_box_clustered();
|
||||
ConvertPriorBox() {
|
||||
add_matcher<ngraph::pass::ConvertPriorBoxToLegacy>();
|
||||
add_matcher<ngraph::pass::ConvertPriorBoxClusteredToLegacy>();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_prior_box();
|
||||
|
||||
void convert_prior_box_clustered();
|
||||
};
|
@ -25,7 +25,7 @@ class INFERENCE_ENGINE_API_CLASS(Reshape1DMaxPool);
|
||||
class ngraph::pass::Reshape1DOps: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Reshape1DOps() : GraphRewrite() {
|
||||
Reshape1DOps() {
|
||||
add_matcher<ngraph::pass::Reshape1DConvolution>();
|
||||
add_matcher<ngraph::pass::Reshape1DAvgPool>();
|
||||
add_matcher<ngraph::pass::Reshape1DMaxPool>();
|
||||
|
@ -30,20 +30,8 @@ class INFERENCE_ENGINE_API_CLASS(ReshapeFullyConnectedFusion);
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ReshapeFullyConnectedFusion() : GraphRewrite() {
|
||||
construct_reshape_fc();
|
||||
}
|
||||
|
||||
bool run_on_function(std::shared_ptr<ngraph::Function> f) override {
|
||||
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(f)) {
|
||||
return GraphRewrite::run_on_function(f);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private:
|
||||
void construct_reshape_fc();
|
||||
ReshapeFullyConnectedFusion();
|
||||
};
|
||||
|
@ -57,7 +57,7 @@ CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulAddToScaleShiftOrPower, "ConvertMulAddToScaleShiftOrPower", 0);
|
||||
|
||||
void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshift_or_power() {
|
||||
ngraph::pass::ConvertMulAddToScaleShiftOrPower::ConvertMulAddToScaleShiftOrPower() {
|
||||
auto data_batch = std::make_shared<pattern::op::Label>(element::f32, Shape {1});
|
||||
|
||||
auto weights = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape {1}, std::vector<float> {0});
|
||||
@ -66,7 +66,7 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(data_batch, weights);
|
||||
auto add = std::make_shared<ngraph::opset1::Add>(mul, bias);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto add_node = ngraph::as_type_ptr<ngraph::opset1::Add>(m.get_match_root());
|
||||
|
||||
if (!add_node) {
|
||||
@ -209,8 +209,6 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "CPUFusion.MulAddToScaleShiftOrPower");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "MulAddToScaleShiftOrPower");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -6,7 +6,7 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMulOrAddFinally, "ConvertMulOrAddFinally", 0);
|
||||
|
||||
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() : GraphRewrite() {
|
||||
ngraph::pass::ConvertMulOrAddFinally::ConvertMulOrAddFinally() {
|
||||
convert_mul_or_add_finally<ngraph::opset1::Add>();
|
||||
convert_mul_or_add_finally<ngraph::opset1::Subtract>();
|
||||
convert_mul_or_add_finally<ngraph::opset1::Multiply>();
|
||||
|
@ -22,7 +22,7 @@ ngraph::pass::ConvertNormalizeL2WithMulToNormalizeIE::ConvertNormalizeL2WithMulT
|
||||
auto normalize = std::make_shared<ngraph::op::NormalizeL2>(input_0, axis, 0.0f, ngraph::op::EpsMode::ADD);
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply> (normalize, input_1);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto mul = std::dynamic_pointer_cast<ngraph::opset1::Multiply> (m.get_match_root());
|
||||
if (!mul) return false;
|
||||
|
||||
|
@ -140,7 +140,10 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
|
||||
// List of final conversion transformations that must to be executed
|
||||
// after previous group of transformations
|
||||
|
||||
if (!ngraph::op::util::has_op_with_type<ngraph::op::FakeQuantize>(f)) {
|
||||
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
|
||||
}
|
||||
manager.register_pass<ngraph::pass::ConvertNormalizeL2ToLegacyMatcher>();
|
||||
manager.register_pass<ngraph::pass::ConvertMulAddToScaleShiftOrPower>();
|
||||
manager.register_pass<ngraph::pass::ConvertMulOrAddFinally>();
|
||||
|
@ -15,8 +15,10 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPriorBox, "ConvertPriorBox", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPriorBoxToLegacy, "ConvertPriorBoxToLegacy", 0);
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPriorBoxClusteredToLegacy, "ConvertPriorBoxClusteredToLegacy", 0);
|
||||
|
||||
void ngraph::pass::ConvertPriorBox::convert_prior_box() {
|
||||
ngraph::pass::ConvertPriorBoxToLegacy::ConvertPriorBoxToLegacy() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
|
||||
auto axes = ngraph::opset1::Constant::create(element::i64, Shape{1}, {0});
|
||||
auto image = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
|
||||
@ -35,7 +37,7 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box() {
|
||||
auto prior_box = std::make_shared<ngraph::opset1::PriorBox>(data, image, attr);
|
||||
auto unsqueeze = std::make_shared<ngraph::opset1::Unsqueeze> (prior_box, axes);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto unsqueeze = std::dynamic_pointer_cast<ngraph::opset1::Unsqueeze> (m.get_match_root());
|
||||
if (!unsqueeze) {
|
||||
return false;
|
||||
@ -153,13 +155,11 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "CPUFusion.ConvertPriorBoxToPriorBoxIE");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "ConvertPriorBoxToLegacy");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() {
|
||||
ngraph::pass::ConvertPriorBoxClusteredToLegacy::ConvertPriorBoxClusteredToLegacy() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
|
||||
auto axes = ngraph::opset1::Constant::create(element::i64, Shape{1}, {0});
|
||||
auto image = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
|
||||
@ -176,7 +176,7 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() {
|
||||
auto prior_box = std::make_shared<ngraph::opset1::PriorBoxClustered>(data, image, attr);
|
||||
auto unsqueeze = std::make_shared<ngraph::opset1::Unsqueeze> (prior_box, axes);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto unsqueeze = std::dynamic_pointer_cast<ngraph::opset1::Unsqueeze> (m.get_match_root());
|
||||
if (!unsqueeze) {
|
||||
return false;
|
||||
@ -293,8 +293,6 @@ void ngraph::pass::ConvertPriorBox::convert_prior_box_clustered() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "CPUFusion.ConvertPriorBoxClusteredToPriorBoxClusteredIE");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(unsqueeze, "ConvertPriorBoxClusteredToLegacy");
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -15,13 +15,13 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);
|
||||
|
||||
void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
|
||||
ngraph::pass::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() {
|
||||
auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
|
||||
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
|
||||
pattern::any_input(),
|
||||
pattern::any_input()});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
|
||||
auto & pattern_to_output = m.get_pattern_value_map();
|
||||
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
|
||||
auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
|
||||
@ -52,7 +52,5 @@ void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
|
||||
pattern::any_input()},
|
||||
pattern::has_static_shape());
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected> (m.get_match_root());
|
||||
if (!fc || transformation_callback(fc)) {
|
||||
return false;
|
||||
|
@ -34,16 +34,6 @@ class TRANSFORMATIONS_API DeconvAddFusion;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvFusion() : GraphRewrite() {
|
||||
add_matcher<ngraph::pass::ConvAddFusion>();
|
||||
add_matcher<ngraph::pass::ConvMultiplyFusion>();
|
||||
add_matcher<ngraph::pass::DeconvAddFusion>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvAddFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -61,3 +51,13 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
DeconvAddFusion();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvFusion: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvFusion() {
|
||||
add_matcher<ngraph::pass::ConvAddFusion>();
|
||||
add_matcher<ngraph::pass::ConvMultiplyFusion>();
|
||||
add_matcher<ngraph::pass::DeconvAddFusion>();
|
||||
}
|
||||
};
|
@ -41,13 +41,8 @@ namespace pass {
|
||||
*
|
||||
*/
|
||||
|
||||
class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::DepthToSpaceFusion: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
DepthToSpaceFusion() : GraphRewrite() {
|
||||
depth_to_space_fusion();
|
||||
}
|
||||
|
||||
private:
|
||||
void depth_to_space_fusion();
|
||||
DepthToSpaceFusion();
|
||||
};
|
||||
|
@ -19,13 +19,8 @@ class TRANSFORMATIONS_API RemoveFilteringBoxesBySize;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::RemoveFilteringBoxesBySize: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::RemoveFilteringBoxesBySize: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
RemoveFilteringBoxesBySize() : GraphRewrite() {
|
||||
remove_filtering_boxes_by_size();
|
||||
}
|
||||
|
||||
private:
|
||||
void remove_filtering_boxes_by_size();
|
||||
RemoveFilteringBoxesBySize();
|
||||
};
|
||||
|
@ -23,13 +23,8 @@ class TRANSFORMATIONS_API ConvertScatterElementsToScatter;
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ConvertScatterElementsToScatter convert opset3::ScatterElementsUpdate to opset3::ScatterUpdate.
|
||||
*/
|
||||
class ngraph::pass::ConvertScatterElementsToScatter: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertScatterElementsToScatter: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertScatterElementsToScatter() : GraphRewrite() {
|
||||
convert_scatter_elements_to_scatter();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_scatter_elements_to_scatter();
|
||||
ConvertScatterElementsToScatter();
|
||||
};
|
||||
|
@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertShapeOf3;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertShapeOf3: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertShapeOf3: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertShapeOf3() : GraphRewrite() {
|
||||
convert_shapeof3();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_shapeof3();
|
||||
ConvertShapeOf3();
|
||||
};
|
||||
|
@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertShuffleChannels3;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertShuffleChannels3: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertShuffleChannels3() : GraphRewrite() {
|
||||
convert_shuffle_channels3();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_shuffle_channels3();
|
||||
ConvertShuffleChannels3();
|
||||
};
|
||||
|
@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertTopK3;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertTopK3: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertTopK3: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertTopK3() : GraphRewrite() {
|
||||
convert_topk3();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_topk3();
|
||||
ConvertTopK3();
|
||||
};
|
||||
|
@ -65,8 +65,8 @@ bool IsConvInLowPrecision(const std::shared_ptr<Conv>& conv) {
|
||||
}
|
||||
|
||||
template <class Conv>
|
||||
ngraph::graph_rewrite_callback get_callback() {
|
||||
ngraph::graph_rewrite_callback callback = [](ngraph::pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback get_callback() {
|
||||
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
|
||||
auto eltwise = m.get_match_root();
|
||||
|
||||
std::shared_ptr<ngraph::opset1::Constant> m_const;
|
||||
|
@ -83,7 +83,7 @@ bool check_depth_first(const ngraph::Shape& shape_input, const ngraph::Shape& sh
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::DepthToSpaceFusion, "DepthToSpaceFusion", 0);
|
||||
|
||||
void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
|
||||
ngraph::pass::DepthToSpaceFusion::DepthToSpaceFusion() {
|
||||
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
|
||||
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
|
||||
@ -92,7 +92,7 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
|
||||
auto permute = std::make_shared<ngraph::opset3::Transpose> (reshape_before, input2);
|
||||
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, input3, false);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto reshape_after = std::dynamic_pointer_cast<ngraph::opset3::Reshape>(m.get_match_root());
|
||||
if (!reshape_after) {
|
||||
return false;
|
||||
@ -160,7 +160,5 @@ void ngraph::pass::DepthToSpaceFusion::depth_to_space_fusion() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_after, "DepthToSpaceFusion");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ ngraph::pass::MishFusion::MishFusion() {
|
||||
auto tanh = std::make_shared<ngraph::opset4::Tanh>(log);
|
||||
auto mul = std::make_shared<ngraph::opset4::Multiply>(input, tanh);
|
||||
|
||||
ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto & pattern_to_output = m.get_pattern_value_map();
|
||||
auto exp_input = pattern_to_output.at(input);
|
||||
|
||||
|
@ -28,7 +28,7 @@ ngraph::pass::NormalizeL2FusionWithMax::NormalizeL2FusionWithMax() {
|
||||
auto sqrt_max_eps = std::make_shared<ngraph::opset4::Maximum>(sqrt, eps_const);
|
||||
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_max_eps);
|
||||
|
||||
ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
const auto data_input = pattern_to_output.at(input);
|
||||
@ -81,7 +81,7 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
|
||||
auto sqrt_add_eps = std::make_shared<ngraph::opset4::Add>(sqrt, eps_const);
|
||||
auto divide = std::make_shared<ngraph::opset4::Divide>(input, sqrt_add_eps);
|
||||
|
||||
ngraph::graph_rewrite_callback matcher_pass_callback = [=](ngraph::pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
const auto data_input = pattern_to_output.at(input);
|
||||
@ -117,5 +117,5 @@ ngraph::pass::NormalizeL2FusionWithAdd::NormalizeL2FusionWithAdd() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(divide, "NormalizeL2FusionWithMax");
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::RemoveFilteringBoxesBySize, "RemoveFilteringBoxesBySize", 0);
|
||||
|
||||
void ngraph::pass::RemoveFilteringBoxesBySize::remove_filtering_boxes_by_size() {
|
||||
ngraph::pass::RemoveFilteringBoxesBySize::RemoveFilteringBoxesBySize() {
|
||||
// variadic split
|
||||
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1000, 4});
|
||||
auto sizes = opset3::Constant::create(element::i64, Shape{4}, std::vector<int64_t >({1, 1, 1, 1}));
|
||||
@ -79,7 +79,7 @@ void ngraph::pass::RemoveFilteringBoxesBySize::remove_filtering_boxes_by_size()
|
||||
|
||||
auto cast = std::make_shared<ngraph::opset3::Convert>(squeeze_3, ngraph::element::i64);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [data](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [data](pattern::Matcher& m) {
|
||||
auto start = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({0}));
|
||||
auto step = opset3::Constant::create(element::i64, Shape{}, std::vector<int64_t >({1}));
|
||||
|
||||
@ -104,7 +104,5 @@ void ngraph::pass::RemoveFilteringBoxesBySize::remove_filtering_boxes_by_size()
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(cast, "RemoveFilteringBoxesBySize");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBatchToSpace, "ConvertBatchToSpace",
|
||||
|
||||
void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
|
||||
auto batch_to_space = ngraph::pattern::wrap_type<ngraph::opset3::BatchToSpace>();
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto batch_to_space = std::dynamic_pointer_cast<ngraph::opset3::BatchToSpace> (m.get_match_root());
|
||||
if (!batch_to_space) {
|
||||
return false;
|
||||
@ -127,7 +127,7 @@ void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space() {
|
||||
|
||||
void ngraph::pass::ConvertBatchToSpace::convert_batch_to_space_by_elements() {
|
||||
auto batch_to_space = ngraph::pattern::wrap_type<ngraph::opset3::BatchToSpace>();
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto batch_to_space = std::dynamic_pointer_cast<ngraph::opset3::BatchToSpace> (m.get_match_root());
|
||||
if (!batch_to_space) {
|
||||
return false;
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDepthToSpace, "ConvertDepthToSpace",
|
||||
ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
|
||||
auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>({pattern::any_input(pattern::has_static_shape())});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
|
||||
if (!dts_node || transformation_callback(dts_node)) {
|
||||
return false;
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDivide, "ConvertDivide", 0);
|
||||
ngraph::pass::ConvertDivide::ConvertDivide() {
|
||||
auto div = ngraph::pattern::wrap_type<ngraph::opset1::Divide>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto div = std::dynamic_pointer_cast<ngraph::opset1::Divide> (m.get_match_root());
|
||||
// We can not apply this transformation in case with integer input data type
|
||||
if (!div || div->input(0).get_element_type().is_integral()) {
|
||||
|
@ -16,7 +16,7 @@ ngraph::pass::ConvertGELU::ConvertGELU() {
|
||||
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
|
||||
auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());
|
||||
if (!gelu || transformation_callback(gelu))
|
||||
return false;
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMinimum, "ConvertMinimum", 0);
|
||||
ngraph::pass::ConvertMinimum::ConvertMinimum() {
|
||||
auto minimum = ngraph::pattern::wrap_type<opset1::Minimum>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto minimum = std::dynamic_pointer_cast<ngraph::opset1::Minimum> (m.get_match_root());
|
||||
if (!minimum || transformation_callback(minimum)) {
|
||||
return false;
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertNegative, "ConvertNegative", 0);
|
||||
ngraph::pass::ConvertNegative::ConvertNegative() {
|
||||
auto neg = ngraph::pattern::wrap_type<ngraph::opset1::Negative>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto neg = std::dynamic_pointer_cast<ngraph::opset1::Negative> (m.get_match_root());
|
||||
if (!neg) {
|
||||
return false;
|
||||
|
@ -14,7 +14,7 @@
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertScatterElementsToScatter, "ConvertScatterElementsToScatter", 0);
|
||||
|
||||
void ngraph::pass::ConvertScatterElementsToScatter::convert_scatter_elements_to_scatter() {
|
||||
ngraph::pass::ConvertScatterElementsToScatter::ConvertScatterElementsToScatter() {
|
||||
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
|
||||
auto indices = std::make_shared<pattern::op::Label>(element::i64, Shape{1});
|
||||
auto updates = std::make_shared<pattern::op::Label>(element::f32, Shape{1});
|
||||
@ -25,7 +25,7 @@ void ngraph::pass::ConvertScatterElementsToScatter::convert_scatter_elements_to_
|
||||
|
||||
auto scatter = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(data, broadcast, updates, axis);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto scatter = m.get_match_root();
|
||||
auto broadcast = scatter->input_value(1).get_node_shared_ptr();
|
||||
auto axis_const = std::dynamic_pointer_cast<ngraph::opset3::Constant>(scatter->input_value(3).get_node_shared_ptr());
|
||||
@ -209,8 +209,6 @@ void ngraph::pass::ConvertScatterElementsToScatter::convert_scatter_elements_to_
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(scatter, "ConvertScatterElementsToScatter");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
|
@ -10,14 +10,14 @@
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShapeOf3, "ConvertShapeOf3", 0);
|
||||
|
||||
void ngraph::pass::ConvertShapeOf3::convert_shapeof3() {
|
||||
auto input = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1, 1, 1});
|
||||
auto shapeof = std::make_shared<ngraph::opset3::ShapeOf>(input);
|
||||
ngraph::pass::ConvertShapeOf3::ConvertShapeOf3() {
|
||||
auto shapeof = pattern::wrap_type<ngraph::opset3::ShapeOf>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto shapeof = std::dynamic_pointer_cast<ngraph::opset3::ShapeOf> (m.get_match_root());
|
||||
if (!shapeof) {
|
||||
return false;
|
||||
@ -43,7 +43,5 @@ void ngraph::pass::ConvertShapeOf3::convert_shapeof3() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(shapeof, "ConvertShapeOf3");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -15,11 +15,11 @@ using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShuffleChannels3, "ConvertShuffleChannels3", 0);
|
||||
|
||||
void ngraph::pass::ConvertShuffleChannels3::convert_shuffle_channels3() {
|
||||
ngraph::pass::ConvertShuffleChannels3::ConvertShuffleChannels3() {
|
||||
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||
auto shuffle_channels = std::make_shared<::opset3::ShuffleChannels>(input);
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher &m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());
|
||||
if (!shuffle_channels || transformation_callback(shuffle_channels)) {
|
||||
return false;
|
||||
@ -98,7 +98,5 @@ void ngraph::pass::ConvertShuffleChannels3::convert_shuffle_channels3() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(shuffle_channels, "ConvertShuffleChannels3");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSpaceToDepth, "ConvertSpaceToDepth",
|
||||
ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
|
||||
auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>({pattern::any_input(pattern::has_static_shape())});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
|
||||
if (!std_node || transformation_callback(std_node)) {
|
||||
return false;
|
||||
|
@ -16,7 +16,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSubtract, "ConvertSubtract", 0);
|
||||
ngraph::pass::ConvertSubtract::ConvertSubtract() {
|
||||
auto sub = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto sub = std::dynamic_pointer_cast<ngraph::opset1::Subtract> (m.get_match_root());
|
||||
if (!sub) {
|
||||
return false;
|
||||
|
@ -12,12 +12,14 @@
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTopK3, "ConvertTopK3", 0);
|
||||
|
||||
void ngraph::pass::ConvertTopK3::convert_topk3() {
|
||||
auto topk = std::make_shared<pattern::op::Label>(element::f32, Shape{}, pattern::has_class<opset3::TopK>());
|
||||
ngraph::pass::ConvertTopK3::ConvertTopK3() {
|
||||
auto topk = pattern::wrap_type<opset3::TopK>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto topk = std::dynamic_pointer_cast<ngraph::opset3::TopK> (m.get_match_root());
|
||||
if (!topk) {
|
||||
return false;
|
||||
@ -60,7 +62,5 @@ void ngraph::pass::ConvertTopK3::convert_topk3() {
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(topk, "ConvertTopK3");
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -18,6 +18,7 @@
|
||||
#include <transformations/op_conversions/convert_scatter_elements_to_scatter.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -75,8 +76,10 @@ std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialSh
|
||||
}
|
||||
|
||||
void test(std::shared_ptr<ngraph::Function> f, std::shared_ptr<ngraph::Function> f_ref) {
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertScatterElementsToScatter().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertScatterElementsToScatter>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <transformations/op_conversions/convert_shapeof3.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -26,8 +27,10 @@ TEST(TransformationTests, ConvertShapeOf3WithI64) {
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{shapeof}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertShapeOf3().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertShapeOf3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -56,8 +59,10 @@ TEST(TransformationTests, ConvertShapeOf3WithI32) {
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{shapeof}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertShapeOf3().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertShapeOf3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -25,8 +26,10 @@ std::shared_ptr<ngraph::Function> buildInputGraph(int64_t axis, int64_t group, c
|
||||
|
||||
auto f = std::make_shared<::Function>(::NodeVector{shuffle_channels}, ::ParameterVector{input});
|
||||
|
||||
::pass::InitNodeInfo().run_on_function(f);
|
||||
::pass::ConvertShuffleChannels3().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertShuffleChannels3>();
|
||||
manager.run_passes(f);
|
||||
f->validate_nodes_and_infer_types();
|
||||
return f;
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
#include <transformations/op_conversions/convert_topk3.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -32,8 +33,11 @@ TEST(TransformationTests, ConvertTopK3I32Output0) {
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopK3().run_on_function(f);
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertTopK3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -67,8 +71,10 @@ TEST(TransformationTests, ConvertTopK3I32Output1) {
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(1)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopK3().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertTopK3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -102,8 +108,10 @@ TEST(TransformationTests, ConvertTopK3I64Output0) {
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(0)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopK3().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertTopK3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -137,8 +145,10 @@ TEST(TransformationTests, ConvertTopK3I64Output1) {
|
||||
// due to the 'compare_functions' limitation we will check only one output
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{topk->output(1)}, ngraph::ParameterVector{input});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertTopK3().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertTopK3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include <transformations/common_optimizations/depth_to_space_fusion.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
@ -37,14 +38,19 @@ TEST(TransformationTests, DepthToSpaceFusionDepthFirst) {
|
||||
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
|
||||
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||
};
|
||||
|
||||
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||
depth_to_space_transform.set_callback(callback);
|
||||
depth_to_space_transform.run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
|
||||
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -71,14 +77,19 @@ TEST(TransformationTests, DepthToSpaceFusionBlockFirst) {
|
||||
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
|
||||
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||
};
|
||||
|
||||
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||
depth_to_space_transform.set_callback(callback);
|
||||
depth_to_space_transform.run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
|
||||
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -105,15 +116,19 @@ TEST(TransformationTests, DepthToSpaceFusionDynamicShape) {
|
||||
auto reshape_after = std::make_shared<ngraph::opset3::Reshape> (permute, shape_reshape_after, false);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0, shape_reshape_before});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
|
||||
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||
};
|
||||
|
||||
// transformation won't be applied because of shape_reshape_before is dynamic, the graph will remain the same
|
||||
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||
depth_to_space_transform.set_callback(callback);
|
||||
depth_to_space_transform.run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
|
||||
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -150,15 +165,19 @@ TEST(TransformationTests, DepthToSpaceFusionSeveralConsumers) {
|
||||
auto result = std::make_shared<ngraph::opset3::Result> (reshape_before);
|
||||
auto result_2 = std::make_shared<ngraph::opset3::Result> (permute);
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{reshape_after}, ngraph::ParameterVector{input0});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
|
||||
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ngraph::opset3::DepthToSpace>(node) != nullptr;
|
||||
};
|
||||
|
||||
// transformation won't be applied because of reshape_before has several consumers, the graph will remain the same
|
||||
auto depth_to_space_transform = ngraph::pass::DepthToSpaceFusion();
|
||||
depth_to_space_transform.set_callback(callback);
|
||||
depth_to_space_transform.run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
|
||||
auto pass_config = manager.get_pass_config();
|
||||
pass_config->set_callback<ngraph::pass::DepthToSpaceFusion>(callback);
|
||||
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::DepthToSpaceFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
|
@ -132,8 +132,10 @@ public:
|
||||
class MulOrAddConversionTests: public MulAddConversionTests {};
|
||||
|
||||
TEST_P(MulAddConversionTests, CompareFunctions) {
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertMulAddToScaleShiftOrPower().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertMulAddToScaleShiftOrPower>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
f->validate_nodes_and_infer_types();
|
||||
|
@ -34,8 +34,10 @@ TEST(TransformationTests, ReshapeFCFusiuonTest1) {
|
||||
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{1, 6});
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ReshapeFullyConnectedFusion().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
ASSERT_EQ(f->get_ops().size(), 5);
|
||||
@ -53,8 +55,10 @@ TEST(TransformationTests, ReshapeFCFusiuonTest2) {
|
||||
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{1, 6});
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ReshapeFullyConnectedFusion().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
ASSERT_EQ(f->get_ops().size(), 5);
|
||||
@ -72,8 +76,10 @@ TEST(TransformationTests, ReshapeFCFusiuonTest3) {
|
||||
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{2, 6});
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ReshapeFullyConnectedFusion().run_on_function(f);
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
ASSERT_EQ(f->get_ops().size(), 7);
|
||||
|
Loading…
Reference in New Issue
Block a user