WrapType Improvements (#4040)
* Extended WrapType to consume multiple types; Added variadic wrap_type support * Updated transformations to use wrap_type * Fix BatchNormDecomposition * Added tests
This commit is contained in:
@@ -19,7 +19,6 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API BatchNormDecomposition;
|
||||
class TRANSFORMATIONS_API BatchNormV5Decomposition;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@@ -29,9 +28,3 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
BatchNormDecomposition();
|
||||
};
|
||||
|
||||
class ngraph::pass::BatchNormV5Decomposition: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
BatchNormV5Decomposition();
|
||||
};
|
||||
|
||||
@@ -83,7 +83,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
common_fusions->add_matcher<ngraph::pass::ConvertScatterElementsToScatter>();
|
||||
common_fusions->add_matcher<ngraph::pass::DepthToSpaceFusion>();
|
||||
//common_fusions->add_matcher<ngraph::pass::MishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
|
||||
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
|
||||
@@ -115,7 +114,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
|
||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::BatchNormV5Decomposition>();
|
||||
decomp->set_name("ngraph::pass::CommonDecompositions");
|
||||
|
||||
// CF is required after all decompositions
|
||||
|
||||
@@ -164,7 +164,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvAddFusion, "ConvAddFusion", 0);
|
||||
ngraph::pass::ConvAddFusion::ConvAddFusion() {
|
||||
MATCHER_SCOPE(ConvAddFusion);
|
||||
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::consumers_count(1));
|
||||
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
|
||||
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, pattern::any_input()});
|
||||
|
||||
matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
|
||||
return conv_callback<op::ConvolutionIE>(m);
|
||||
@@ -179,7 +179,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvMultiplyFusion, "ConvMultiplyFusion", 0
|
||||
ngraph::pass::ConvMultiplyFusion::ConvMultiplyFusion() {
|
||||
MATCHER_SCOPE(ConvMultiplyFusion);
|
||||
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::consumers_count(1));
|
||||
auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, std::make_shared<pattern::op::Label>()});
|
||||
auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, pattern::any_input()});
|
||||
|
||||
matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
|
||||
return conv_callback<op::ConvolutionIE>(m);
|
||||
@@ -194,7 +194,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::DeconvAddFusion, "DeconvAddFusion", 0);
|
||||
ngraph::pass::DeconvAddFusion::DeconvAddFusion() {
|
||||
MATCHER_SCOPE(DeconvAddFusion);
|
||||
auto conv = ngraph::pattern::wrap_type<op::DeconvolutionIE>(pattern::consumers_count(1));
|
||||
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
|
||||
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, pattern::any_input()});
|
||||
|
||||
matcher_pass_callback callback = [](ngraph::pattern::Matcher &m){
|
||||
return conv_callback<op::DeconvolutionIE>(m);
|
||||
|
||||
@@ -19,7 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposi
|
||||
|
||||
ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
||||
MATCHER_SCOPE(BatchNormDecomposition);
|
||||
auto bn = pattern::wrap_type<opset1::BatchNormInference>({
|
||||
auto bn = pattern::wrap_type<opset1::BatchNormInference, opset5::BatchNormInference>({
|
||||
pattern::any_input(pattern::has_static_rank()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
@@ -28,20 +28,30 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
||||
});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
|
||||
auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
|
||||
if (!m_bn) {
|
||||
auto m_bn = m.get_match_root();
|
||||
Output<Node> m_input, m_gamma, m_beta, m_mean, m_var;
|
||||
double eps;
|
||||
if (auto m_bn_v1 = dynamic_pointer_cast<opset1::BatchNormInference>(m_bn)) {
|
||||
m_gamma = m_bn_v1->input_value(0);
|
||||
m_beta = m_bn_v1->input_value(1);
|
||||
m_input = m_bn_v1->input_value(2);
|
||||
m_mean = m_bn_v1->input_value(3);
|
||||
m_var = m_bn_v1->input_value(4);
|
||||
eps = m_bn_v1->get_eps_value();
|
||||
} else if (auto m_bn_v5 = dynamic_pointer_cast<opset5::BatchNormInference>(m_bn)) {
|
||||
m_input = m_bn_v5->input_value(0);
|
||||
m_gamma = m_bn_v5->input_value(1);
|
||||
m_beta = m_bn_v5->input_value(2);
|
||||
m_mean = m_bn_v5->input_value(3);
|
||||
m_var = m_bn_v5->input_value(4);
|
||||
eps = m_bn_v5->get_eps_value();
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto m_gamma = m_bn->input_value(0);
|
||||
auto m_beta = m_bn->input_value(1);
|
||||
auto m_input = m_bn->input_value(2);
|
||||
auto m_mean = m_bn->input_value(3);
|
||||
auto m_var = m_bn->input_value(4);
|
||||
|
||||
const auto& input_type = m_input.get_element_type();
|
||||
// scale_add = variance + eps
|
||||
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
|
||||
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {eps}));
|
||||
// scale = sqrt(variance + eps)
|
||||
auto scale = make_shared<opset5::Sqrt>(scale_add);
|
||||
// Divide `gamma` by `sqrt(variance + eps)`
|
||||
@@ -79,67 +89,3 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
|
||||
|
||||
// TODO: this pass will be unified with BatchNormDecomposition pass
|
||||
ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
|
||||
MATCHER_SCOPE(BatchNormV5Decomposition);
|
||||
auto bn = pattern::wrap_type<opset5::BatchNormInference>({
|
||||
pattern::any_input(pattern::has_static_rank()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
pattern::any_input(pattern::has_static_shape())
|
||||
});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
|
||||
auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
|
||||
if (!m_bn) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto m_input = m_bn->input_value(0);
|
||||
auto m_gamma = m_bn->input_value(1);
|
||||
auto m_beta = m_bn->input_value(2);
|
||||
auto m_mean = m_bn->input_value(3);
|
||||
auto m_var = m_bn->input_value(4);
|
||||
|
||||
const auto& input_type = m_input.get_element_type();
|
||||
// scale_add = variance + eps
|
||||
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
|
||||
// scale = sqrt(variance + eps)
|
||||
auto scale = make_shared<opset5::Sqrt>(scale_add);
|
||||
// Divide `gamma` by `sqrt(variance + eps)`
|
||||
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
|
||||
|
||||
int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
|
||||
|
||||
// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
|
||||
Shape input_aligned_shape = m_gamma.get_shape();
|
||||
for (int64_t i = 0; i < dims_to_add; ++i)
|
||||
input_aligned_shape.push_back(1);
|
||||
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
|
||||
|
||||
auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
|
||||
auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);
|
||||
auto mean_aligned = make_shared<opset5::Reshape>(m_mean, new_shape, true);
|
||||
|
||||
// input_sub_mean = input - mean
|
||||
auto input_sub_mean = register_new_node<opset5::Subtract>(m_input, mean_aligned);
|
||||
// Multiply `input - mean` and `gamma / sqrt(variance + eps)`
|
||||
auto mul = std::make_shared<opset5::Multiply>(input_sub_mean, gamma_div_scale_aligned);
|
||||
// Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
|
||||
auto add = std::make_shared<opset5::Add>(mul, beta_aligned);
|
||||
|
||||
add->set_friendly_name(m_bn->get_friendly_name());
|
||||
|
||||
copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned,
|
||||
beta_aligned, input_sub_mean, mul, add});
|
||||
|
||||
replace_node(m_bn, add);
|
||||
|
||||
return true;
|
||||
};
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
@@ -10,13 +10,13 @@
|
||||
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGELU, "ConvertGELU", 0);
|
||||
|
||||
ngraph::pass::ConvertGELU::ConvertGELU() {
|
||||
MATCHER_SCOPE(ConvertGELU);
|
||||
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
|
||||
auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
|
||||
auto gelu = pattern::wrap_type<ngraph::opset2::Gelu>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
@@ -18,8 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShuffleChannels3, "ConvertShuffleCha
|
||||
|
||||
ngraph::pass::ConvertShuffleChannels3::ConvertShuffleChannels3() {
|
||||
MATCHER_SCOPE(ConvertShuffleChannels3);
|
||||
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
||||
auto shuffle_channels = std::make_shared<::opset3::ShuffleChannels>(input);
|
||||
auto shuffle_channels = pattern::wrap_type<opset3::ShuffleChannels>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());
|
||||
|
||||
@@ -24,8 +24,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToGRUSequence, "Conver
|
||||
|
||||
ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
|
||||
MATCHER_SCOPE(ConvertTensorIteratorToLSTMSequence);
|
||||
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
|
||||
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
|
||||
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
|
||||
if (!ti || transformation_callback(ti))
|
||||
@@ -201,8 +201,8 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe
|
||||
|
||||
ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
|
||||
MATCHER_SCOPE(ConvertTensorIteratorToRNNSequence);
|
||||
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
|
||||
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
|
||||
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
|
||||
if (!ti || transformation_callback(ti))
|
||||
@@ -357,8 +357,8 @@ ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequ
|
||||
|
||||
ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
|
||||
MATCHER_SCOPE(ConvertTensorIteratorToGRUSequence);
|
||||
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
|
||||
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
|
||||
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
|
||||
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
|
||||
if (!ti || transformation_callback(ti))
|
||||
|
||||
@@ -18,10 +18,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::LSTMCellDecomposition, "LSTMCellDecompositi
|
||||
|
||||
ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() {
|
||||
MATCHER_SCOPE(LSTMCellDecomposition);
|
||||
auto is_supported_lstm_cell = [](const std::shared_ptr<Node>& n) {
|
||||
return pattern::has_class<ngraph::opset1::LSTMCell>()(n) || pattern::has_class<ngraph::opset4::LSTMCell>()(n);
|
||||
};
|
||||
auto any_lstm = std::make_shared<pattern::op::Label>(element::f32, Shape{}, is_supported_lstm_cell);
|
||||
auto any_lstm = pattern::wrap_type<opset1::LSTMCell, opset4::LSTMCell>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
|
||||
auto lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(m.get_match_root());
|
||||
if (!lstm_cell || transformation_callback(lstm_cell)) {
|
||||
|
||||
@@ -36,7 +36,17 @@ namespace ngraph
|
||||
[](const Output<Node>& output) { return true; },
|
||||
const OutputVector& input_values = {})
|
||||
: Pattern(input_values, pred)
|
||||
, m_wrapped_type(wrapped_type)
|
||||
, m_wrapped_types({wrapped_type})
|
||||
{
|
||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||
}
|
||||
|
||||
explicit WrapType(std::vector<NodeTypeInfo> wrapped_types,
|
||||
const ValuePredicate& pred =
|
||||
[](const Output<Node>& output) { return true; },
|
||||
const OutputVector& input_values = {})
|
||||
: Pattern(input_values, pred)
|
||||
, m_wrapped_types(std::move(wrapped_types))
|
||||
{
|
||||
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
|
||||
}
|
||||
@@ -45,30 +55,33 @@ namespace ngraph
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value) override;
|
||||
|
||||
NodeTypeInfo get_wrapped_type() const { return m_wrapped_type; }
|
||||
NodeTypeInfo get_wrapped_type() const;
|
||||
|
||||
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
|
||||
|
||||
private:
|
||||
NodeTypeInfo m_wrapped_type;
|
||||
std::vector<NodeTypeInfo> m_wrapped_types;
|
||||
};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs,
|
||||
const pattern::op::ValuePredicate& pred)
|
||||
{
|
||||
static_assert(std::is_base_of<Node, T>::value, "Unexpected template type");
|
||||
return std::make_shared<op::WrapType>(T::type_info, pred, inputs);
|
||||
std::vector<DiscreteTypeInfo> info{Args::type_info...};
|
||||
return std::make_shared<op::WrapType>(info, pred, inputs);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {})
|
||||
{
|
||||
return wrap_type<T>(inputs, [](const Output<Node>& output) { return true; });
|
||||
return wrap_type<Args...>(inputs, [](const Output<Node>& output) { return true; });
|
||||
}
|
||||
|
||||
template <class T>
|
||||
template <class... Args>
|
||||
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred)
|
||||
{
|
||||
return wrap_type<T>({}, pred);
|
||||
return wrap_type<Args...>({}, pred);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -109,12 +109,14 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
// it's type
|
||||
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
|
||||
// and default algorithm is used.
|
||||
NodeTypeInfo root_type_info = root->get_type_info();
|
||||
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root))
|
||||
{
|
||||
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p))
|
||||
{
|
||||
root_type_info = any_type->get_wrapped_type();
|
||||
for (const auto& root_type_info : any_type->get_wrapped_types())
|
||||
{
|
||||
type_to_matcher[root_type_info].push_back(matcher_index);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -122,7 +124,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
|
||||
break;
|
||||
}
|
||||
}
|
||||
type_to_matcher[root_type_info].push_back(matcher_index);
|
||||
else
|
||||
{
|
||||
type_to_matcher[root->get_type_info()].push_back(matcher_index);
|
||||
}
|
||||
|
||||
// TODO: traverse parents for root_type_info in order to register complete list of matchers
|
||||
// including ones triggered by parent type info.
|
||||
|
||||
@@ -31,7 +31,12 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value)
|
||||
{
|
||||
if (graph_value.get_node_shared_ptr()->get_type_info().is_castable(get_wrapped_type()) &&
|
||||
if (std::any_of(m_wrapped_types.begin(),
|
||||
m_wrapped_types.end(),
|
||||
[&](const NodeTypeInfo& type_info) {
|
||||
return graph_value.get_node_shared_ptr()->get_type_info().is_castable(
|
||||
type_info);
|
||||
}) &&
|
||||
m_predicate(graph_value))
|
||||
{
|
||||
auto& pattern_map = matcher->get_pattern_value_map();
|
||||
@@ -44,3 +49,17 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
NodeTypeInfo pattern::op::WrapType::get_wrapped_type() const
|
||||
{
|
||||
if (m_wrapped_types.size() > 1)
|
||||
{
|
||||
throw ngraph::ngraph_error("get_wrapped_type() called on WrapType with more than one type");
|
||||
}
|
||||
return m_wrapped_types.at(0);
|
||||
}
|
||||
|
||||
const std::vector<NodeTypeInfo>& pattern::op::WrapType::get_wrapped_types() const
|
||||
{
|
||||
return m_wrapped_types;
|
||||
}
|
||||
@@ -810,7 +810,7 @@ TEST(pattern, is_contained_match)
|
||||
ASSERT_FALSE(n.is_contained_match());
|
||||
}
|
||||
|
||||
TEST(pattern, wrap_type)
|
||||
TEST(pattern, wrap_type_single_op)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
|
||||
auto b = make_shared<op::Abs>(a);
|
||||
@@ -852,3 +852,47 @@ TEST(pattern, wrap_type)
|
||||
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(pattern, wrap_type_multi_op)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
|
||||
auto b = make_shared<op::Abs>(a);
|
||||
auto c = make_shared<op::Relu>(a);
|
||||
auto mul = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
|
||||
auto add = make_shared<op::v1::Add>(op::Constant::create(element::f32, Shape{}, {1}), a);
|
||||
|
||||
{
|
||||
auto m = pattern::wrap_type<op::v1::Multiply, op::v1::Add>();
|
||||
auto matcher = std::make_shared<pattern::Matcher>(m, "MulAddMatcher");
|
||||
ASSERT_TRUE(matcher->match(mul->output(0)));
|
||||
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
|
||||
ASSERT_EQ(matcher->get_matched_nodes()[0], mul);
|
||||
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
|
||||
|
||||
ASSERT_TRUE(matcher->match(add->output(0)));
|
||||
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
|
||||
ASSERT_EQ(matcher->get_matched_nodes()[0], add);
|
||||
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
|
||||
|
||||
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(a)));
|
||||
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(b)));
|
||||
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
|
||||
}
|
||||
{
|
||||
auto m = pattern::wrap_type<op::util::BinaryElementwiseArithmetic>();
|
||||
auto matcher = std::make_shared<pattern::Matcher>(m, "ElementwiseMatcher");
|
||||
ASSERT_TRUE(matcher->match(mul->output(0)));
|
||||
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
|
||||
ASSERT_EQ(matcher->get_matched_nodes()[0], mul);
|
||||
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
|
||||
|
||||
ASSERT_TRUE(matcher->match(add->output(0)));
|
||||
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
|
||||
ASSERT_EQ(matcher->get_matched_nodes()[0], add);
|
||||
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
|
||||
|
||||
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(a)));
|
||||
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(b)));
|
||||
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user