openvino/ngraph/test/pattern.cpp
Mikhail Treskin 6467c64000 Remove opset0 support and undesired passes from Interpreter backend (#1469)
* Move evaluate() interface from some OPs to Interpreter

* commit

* Move shuffle channels reference to OP's evaluate

* Add some operations missed in evaluate_node

* Fix select references invocation from evaluate_node()

* Activation refs (#2)

* HardSigmoid

* Elu

* Selu

* Gelu

* Move to test runtime

* Rollback downgrade passes deletion

* Initial batch to space refs

* Return opset1_upgrade

* WIP: Add space to batch evaluate

* Fix space to batch

* add evaluates function in evaluates_map (#4)

* Add space to batch evaluate

* Fix crop in batch to space references

* Remove vectors reallocation in evaluates for b2s and s2b

* .

* Add SpaceToDepth evaluate

* Add depth to space evaluate

* Remove code duplication depth to space evaluate

* Fix some failed layer tests

* Ngraph test (#3)

* Remove some v0 ops & fix some tests

* Fixes BatchNorm

* Next

* dd

* s

* Add dot & replace slice refs

* d

* dkj

* Review fixes part 1

* Fixes. Part 2

* Fixes. Part 3

* Enable cells refs in evaluate map

* Fix some failed layer tests

* Some more fixes

* Fix code style (#6)

* Tests (#7)

* PriorBox

* Mod

* NormalizeL2

* Update prior_box.hpp

* Fix one hot ref call

* .

* Select (#8)

* Select

* Fix code style

* Fix select messages

* ReverseSeq (#9)

* ReverseSeq

* Select

* ExtractImagePatches, Sequence

* Fix Code Style

* remove extra

* Remove extra line

* Add fake quantize reference

* Align convolution layer tests instantiations with updated definition

* Disabled some failed LPT tests

* Disabled some failed LPT tests

* Remove undesired changes

* Update unit-test manifests + some code cleanup

* Fix code style (#10)

* Normalize L2 refs support (from PR #2327)

* Fix code style

* Apply review comments. Part 1 (#11)

* Apply first part of review comments

* Update onnx_import.in.cpp

* Remove redundant reshape from shuffle_channels evaluate

* Decompose GroupConvolution

* [IE Ngraph] Fix some operation inheritance  (#13)

* [IE TESTS] Depth2Space

* Space2Depth

* ShuffleChannels

* Fix code style

* Fix code style

* [IE NGraph] Remove decompose op (#14)

* .

* Fix losing control dependency in replace_node

* Fix losing control dependency in replace_node

* Fix code style

* Fix FQ references build on windows

* Fix code style

* Apply comments (#15)

* [Ie Ngraph] Remove using v1::Add

* [Ie Ngraph] Remove using v1::Multiply

* [Ie Ngraph] Remove using v1::Subtract

* [Ie Ngraph] Remove using v1::Divide

* [Ie Ngraph] Remove using v1::Equal

* [Ie Ngraph] Remove using v1::Greater

* [Ie Ngraph] Remove using v1::Greater_eq

* [Ie Ngraph] Remove using v1::Less

* [Ie Ngraph] Remove using v1::LessEq

* [Ie Ngraph] Remove using operator+

* [Ie Ngraph] Remove using operator/

* [Ie Ngraph] Remove using operator*

* [Ie Ngraph] Remove using operator-

* Fix code style

* Ci (#16)

* Fix CentOS compilation

* Revert ngraph::op::vo::Multiply removing due to OpenCV

* Android fix (#17)

* fix failures

* Fix code style

* Add (#18)

* Android fix

* Add

* Add in opset1 upgrade pass

* Add in opset1 upgrade pass

* Remove v0::Add, Reverted removing v0::Multiply (#19)

* Remove overloaded math operators from PyNgraph

* Remove overloaded math operators from PyNgraph

* Fix gna tests (#20)

* Fix gna tests

* Squashed commit of the following:

commit 565b504c1c
Author: Alexander Zhogov <alexander.zhogov@intel.com>
Date:   Tue Oct 13 13:27:34 2020 +0300

    GitHub CI: Add files_size.yml (#2570)

    * GitHub CI: Add files_size.yml

    * Update job name

commit ab0fb29853
Author: Vladislav Vinogradov <vlad.vinogradov@intel.com>
Date:   Tue Oct 13 11:37:30 2020 +0300

    [IE][BUILD] Fix C5208 warning under Windows (#2628)

    * C++ feature in C `typedef struct` code.
    * The warning can be promoted to error in dependent projects.

    C5208: unnamed class used in typedef name cannot declare members other than
    non-static data members, member enumerations, or member classes

commit 15a338e89b
Author: helmutg <helmut@subdivi.de>
Date:   Mon Oct 12 22:24:24 2020 +0200

    add build option USE_SYSTEM_PUGIXML (#2502)

    It allows skipping inference-engine/thirdparty/pugixml and using the
    system copy instead.

    Thanks to @Osse for helping understand cmake scoping rules.

    Co-authored-by: Helmut Grohne <helmut.grohne@intenta.de>

commit 7ac8cd8586
Author: Alexander Zhogov <alexander.zhogov@intel.com>
Date:   Mon Oct 12 19:23:00 2020 +0300

    Azure CI: Fix nGraph ONNX

commit 3a2e33962c
Author: Alexander Zhogov <alexander.zhogov@intel.com>
Date:   Mon Oct 12 19:20:28 2020 +0300

    Azure CI: Disable steps in nGraph ONNX

commit 5835974fad
Author: azhogov <alexander.zhogov@intel.com>
Date:   Mon Oct 12 18:46:14 2020 +0300

    Azure CI: Add linux_ngraph_onnx.yml

* LRN Reference (#21)

* Disable failed tests on ia32

* Remove redundant broadcast from MVN ref

* Fix missed GatherND in opset_int_tbl + code style

* Remove one extra temporary buffer from MVN ref

* Merge master (#22)

* Leaky relu transformation refactor (#2640)

* Refactored LeakyRelu transformation

* Added unit test for LeakyRelu transformation + removed duplicate test function valued_const

* nGraph implementation of NMS-5 (without `evaluate()`) (#2651)

* Written nGraph NMS-5 without evaluate().

* Used NGRAPH_RTTI_DECLARATION.

* setupvars.sh: Updated setting pyenv error to warning. (#2663)

* Fix itt build (#2662)

* Loop-5 operation specification (#2291)

The Loop-5 operation specification

* Time tests improvements (#2642)

* Remove extra functions from run_timetest.py

* Add `log.debug` of raw and aggregated statistics in run_timetest.py

* Implement storing of models locally for test_timetest.py

* Fixed CVS-35316 (#2072)

* Extend MO for operation GatherND (#2540)

* Extend MO for operation GatherND

* Update documentation

* Rename GatherNd.py to gathernd.py

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Add hsigmoid op to ngraph (#2647)

* [IE CLDNN] Fixes for GatherTree and ReverseSequence  (#2660)

* ReorgYolo reference implementation (#2384)

* Align ReorgYolo to the spec (vector strides -> int stride)

* ReorgYolo ref impl

* ReorgYolo evaluate method

* ReorgYolo tests

* Tests update

* Style apply

* Add some comments

* Code refactor

* Comment update

* Style apply

* Build fix, mark evaluate as override

* Revert "Align ReorgYolo to the spec (vector strides -> int stride)"

* Use int_executable instead of evaluate

* Use char* instead of templates

* Code refactor

* Comment update

* Code review comment

* Add constructor aligned with spec

* Update shape validation

* Update attributes tests

* Add type_prop tests

* Update backend tests

* Add single layer tests

* Update the spec

* Remove wrong transformation test

* Add ReorgYolo to evaluates_map

* code style

Co-authored-by: Evgeny Lazarev <evgeny.lazarev@intel.com>
Co-authored-by: Vladimir Gavrilov <vladimir.gavrilov@intel.com>
Co-authored-by: Artyom Anokhov <artyom.anokhov@intel.com>
Co-authored-by: Andrey Somsikov <andrey.somsikov@intel.com>
Co-authored-by: Vitaliy Urusovskij <vitaliy.urusovskij@intel.com>
Co-authored-by: Anastasiya Ageeva <anastasiya.ageeva@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: iliya mironov <iliya.mironov@intel.com>
Co-authored-by: Vladimir Paramuzov <vladimir.paramuzov@intel.com>
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>

* RegionYolo

* Apply review comments

* Merge remote-tracking branch 'upstream/master' into update_evaluates

# Conflicts:
#	ngraph/core/src/op/mvn.cpp
#	ngraph/test/backend/fused_op.in.cpp
#	ngraph/test/runtime/ie/unit_test.manifest
#	ngraph/test/runtime/interpreter/int_executable.hpp
#	ngraph/test/runtime/interpreter/opset_int_tbl.hpp
#	ngraph/test/runtime/interpreter/unit_test.manifest
#	ngraph/test/runtime/opset0_tbl.hpp

* Apply code style

* Apply comments

* Apply code style

* Fix RegionYolo evaluate redefinition

* Removed defines from evaluates map

* Apply code style

* Fix MVN ref

* rename select reference argument

* Fix code style

* Fix Fake Quantize references calculation (#24)

* Fix MVN ref

* Fix MVN & adding NMS

* Fix TI

* Temporary relax comparison threshold for FQ SLT

* Fix GPU LPT Tests

* Add explicit rounding mode setting in FQ references

* Apply code style

* Rollback op_is test deletion

* Apply code style

* Fix merge conflict resolving issues

* Apply code style

Co-authored-by: Irina Efode <irina.efode@intel.com>
Co-authored-by: Anton Zaytsev <anton.zaytsev@intel.com>
Co-authored-by: Evgeny Lazarev <evgeny.lazarev@intel.com>
Co-authored-by: Vladimir Gavrilov <vladimir.gavrilov@intel.com>
Co-authored-by: Artyom Anokhov <artyom.anokhov@intel.com>
Co-authored-by: Andrey Somsikov <andrey.somsikov@intel.com>
Co-authored-by: Vitaliy Urusovskij <vitaliy.urusovskij@intel.com>
Co-authored-by: Anastasiya Ageeva <anastasiya.ageeva@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
Co-authored-by: iliya mironov <iliya.mironov@intel.com>
Co-authored-by: Vladimir Paramuzov <vladimir.paramuzov@intel.com>
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
2020-12-03 12:36:34 +03:00

857 lines
37 KiB
C++

//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/branch.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/or.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/pattern/op/true.hpp"
#include "util/matcher.hpp"
#include "util/test_tools.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace ngraph;
using namespace std;
static std::shared_ptr<Node> construct_constant_node(int n)
{
return op::Constant::create(element::Type_t::i32, Shape{}, {n});
}
static std::shared_ptr<pattern::op::Label> construct_variance_graph()
{
// construct variance
auto N = op::Constant::create(element::Type_t::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::Type_t::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::v1::Multiply>(input, input);
auto sum_input = std::make_shared<op::v1::ReduceSum>(
input, op::Constant::create(element::Type_t::i64, {1}, {0}));
auto square_sumed_input = std::make_shared<op::v1::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::v1::ReduceSum>(
input_sq, op::Constant::create(element::Type_t::i64, {1}, {0}));
auto avg_input_sum_sq = std::make_shared<op::v1::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::v1::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::v1::Divide>(xmu, N);
auto variance_label =
std::make_shared<pattern::op::Label>(variance, nullptr, NodeVector{variance});
return variance_label;
}
static std::shared_ptr<pattern::op::Label> construct_mean_graph()
{
// construct mean
auto input = std::make_shared<pattern::op::Label>(element::Type_t::f32, Shape{2, 3});
auto N = op::Constant::create(element::Type_t::f32, Shape{3}, {2, 2, 2});
auto sum_input1 = std::make_shared<op::v1::ReduceSum>(
input, op::Constant::create(element::Type_t::i64, {1}, {0}));
auto mean = std::make_shared<op::v1::Divide>(sum_input1, N);
auto mean_label = std::make_shared<pattern::op::Label>(mean, nullptr, NodeVector{mean});
return mean_label;
}
class TestGraphRewrite : public ngraph::pass::GraphRewrite
{
public:
void construct_multiply_by_one()
{
// pattern #1 : a * 1 = a
auto iconst1 = construct_constant_node(1);
auto pattern = std::make_shared<pattern::op::Label>(iconst1);
auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_multiply_by_one against "
<< m.get_match_root()->get_name();
NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
auto pattern_map = m.get_pattern_map();
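// const_node_index is 1 when input 0 is the labeled operand (so the constant is input 1),
// and 0 otherwise.
size_t const_node_index =
m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];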
auto const_node = as_type_ptr<op::Constant>(
m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
auto second_node =
m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return false;
}
auto const_values = const_node->get_vector<int32_t>();
bool all_ones =
std::all_of(begin(const_values), end(const_values), [](int e) { return e == 1; });
if (!all_ones)
{
NGRAPH_DEBUG << "Constant vector's values aren't equal to 1";
return false;
}
ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
return true;
};
auto m = make_shared<TestMatcher>(make_shared<op::v1::Multiply>(pattern, iconst1));
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback);
NGRAPH_SUPPRESS_DEPRECATED_END
}
void construct_add_zero()
{
// pattern #2 : a + 0 = a
auto iconst0 = construct_constant_node(0);
auto pattern = std::make_shared<pattern::op::Label>(iconst0);
auto callback = [pattern](pattern::Matcher& m) {
NGRAPH_DEBUG << "In a callback for construct_add_zero against "
<< m.get_match_root()->get_name();
NGRAPH_CHECK(m.get_match_root()->input_values().size() == 2);
auto pattern_map = m.get_pattern_map();
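// const_node_index is 1 when input 0 is the labeled operand (so the constant is input 1),
// and 0 otherwise.
size_t const_node_index =
m.get_match_root()->input_value(0).get_node_shared_ptr() == pattern_map[pattern];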
auto const_node = as_type_ptr<op::Constant>(
m.get_match_root()->input_value(const_node_index).get_node_shared_ptr());
auto second_node =
m.get_match_root()->input_value(const_node_index).get_node_shared_ptr();
NGRAPH_DEBUG << "second_node = " << second_node->get_name()
<< " , pattern = " << pattern_map[pattern]->get_name();
if (pattern_map[pattern]->get_element_type() != const_node->get_element_type() ||
pattern_map[pattern]->get_shape() != const_node->get_shape())
{
NGRAPH_DEBUG << "Operands' types and/or shape don't match";
return false;
}
auto const_values = const_node->get_vector<int>();
bool all_zeros =
std::all_of(begin(const_values), end(const_values), [](int e) { return e == 0; });
if (!all_zeros)
{
NGRAPH_DEBUG << "Constant vector's values aren't equal to 0";
return false;
}
ngraph::replace_node(m.get_match_root(), pattern_map[pattern]);
return true;
};
auto add = make_shared<op::v1::Add>(pattern, iconst0);
auto m = make_shared<TestMatcher>(add);
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(m, callback);
NGRAPH_SUPPRESS_DEPRECATED_END
}
TestGraphRewrite()
: GraphRewrite()
{
construct_multiply_by_one();
construct_add_zero();
}
};
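// Illustrative aside (not part of the original test file): a minimal, self-contained sketch of
// the two algebraic identities that TestGraphRewrite registers above (a * 1 = a and a + 0 = a),
// applied to a toy expression tree. The Expr type and the param/constant/binary/simplify helpers
// are hypothetical names invented for this sketch. It uses only the standard library and does
// not call the ngraph Matcher or GraphRewrite APIs.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

namespace rewrite_sketch
{
    // Hypothetical toy expression node; this is not an ngraph type.
    struct Expr
    {
        enum class Kind
        {
            Param,
            Const,
            Add,
            Mul
        };
        Kind kind = Kind::Param;
        std::string name;          // used by Param
        int value = 0;             // used by Const
        std::shared_ptr<Expr> lhs; // used by Add / Mul
        std::shared_ptr<Expr> rhs; // used by Add / Mul
    };
    using ExprPtr = std::shared_ptr<Expr>;

    inline ExprPtr param(const std::string& name)
    {
        auto e = std::make_shared<Expr>();
        e->kind = Expr::Kind::Param;
        e->name = name;
        return e;
    }

    inline ExprPtr constant(int value)
    {
        auto e = std::make_shared<Expr>();
        e->kind = Expr::Kind::Const;
        e->value = value;
        return e;
    }

    inline ExprPtr binary(Expr::Kind kind, ExprPtr lhs, ExprPtr rhs)
    {
        assert(kind == Expr::Kind::Add || kind == Expr::Kind::Mul);
        auto e = std::make_shared<Expr>();
        e->kind = kind;
        e->lhs = std::move(lhs);
        e->rhs = std::move(rhs);
        return e;
    }

    inline bool is_const(const ExprPtr& e, int value)
    {
        return e->kind == Expr::Kind::Const && e->value == value;
    }

    // Bottom-up simplification: operands are simplified first, then the identity element
    // (1 for Mul, 0 for Add) is dropped from either side of the operation.
    inline ExprPtr simplify(const ExprPtr& e)
    {
        if (e->kind != Expr::Kind::Add && e->kind != Expr::Kind::Mul)
        {
            return e; // leaves are returned unchanged
        }
        ExprPtr l = simplify(e->lhs);
        ExprPtr r = simplify(e->rhs);
        const int identity = (e->kind == Expr::Kind::Mul) ? 1 : 0;
        if (is_const(r, identity))
        {
            return l;
        }
        if (is_const(l, identity))
        {
            return r;
        }
        if (l == e->lhs && r == e->rhs)
        {
            return e; // nothing changed below this node
        }
        return binary(e->kind, l, r);
    }
}
```

// In this sketch, simplifying b + (0 + ((a + 0) * 1)) yields b + a, which parallels the
// block of TEST(pattern, graph_rewrite) that mixes both identities.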
static void run_passes(pass::Manager& pass_manager,
shared_ptr<Node> graph,
std::vector<shared_ptr<op::Parameter>> parms)
{
auto func = make_shared<Function>(graph, ParameterVector{parms});
pass_manager.run_passes(func);
}
TEST(pattern, graph_rewrite)
{
Shape shape{};
pass::Manager pass_manager;
pass_manager.register_pass<TestGraphRewrite>();
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto c = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst0 = construct_constant_node(0);
auto graph_a = make_shared<op::v1::Add>(a, iconst0);
auto graph_b = make_shared<op::v1::Add>(b, iconst0);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, graph_a, c, graph_b},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(graph_a->get_output_target_inputs(0).empty());
ASSERT_TRUE(graph_b->get_output_target_inputs(0).empty());
auto expected = ngraph::NodeVector{a, b, a, c, b};
ASSERT_TRUE(count_ops_of_type<op::v1::Add>(f) == 0);
}
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst0 = construct_constant_node(0);
auto sum = make_shared<op::v1::Add>(a, iconst0);
auto graph = make_shared<op::v1::Add>(b, sum);
run_passes(pass_manager, graph, {a, b});
ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
ASSERT_TRUE(sum->output(0)
.get_target_inputs()
.empty()); // graph's input is removed from sum's target inputs
ASSERT_TRUE(a->get_output_target_inputs(0).count(
graph->input(1))); // a's output feeds into graph's input
}
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst1 = construct_constant_node(1);
auto mul = make_shared<op::v1::Multiply>(a, iconst1);
auto graph = make_shared<op::v1::Add>(b, mul);
run_passes(pass_manager, graph, {a, b});
ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
ASSERT_TRUE(mul->output(0)
.get_target_inputs()
.empty()); // graph's input is removed from mul's target inputs
ASSERT_TRUE(a->get_output_target_inputs(0).count(
graph->input(1))); // a's output feeds into graph's input
}
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst1 = construct_constant_node(1);
auto multiply =
make_shared<op::v1::Multiply>(make_shared<op::v1::Multiply>(a, iconst1), iconst1);
multiply = make_shared<op::v1::Multiply>(make_shared<op::v1::Multiply>(multiply, iconst1),
iconst1);
auto graph = make_shared<op::v1::Add>(multiply, b);
run_passes(pass_manager, graph, {a, b});
ASSERT_EQ(graph->input_value(0).get_node_shared_ptr(), a);
ASSERT_EQ(graph->input_value(0), a->output(0)); // graph's input points to a's output
ASSERT_TRUE(a->get_output_target_inputs(0).count(
graph->input(0))); // a's output feeds into graph's input
}
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst0 = construct_constant_node(0);
auto iconst1 = construct_constant_node(1);
auto mul = make_shared<op::v1::Multiply>(make_shared<op::v1::Add>(a, iconst0), iconst1);
auto graph = make_shared<op::v1::Add>(b, make_shared<op::v1::Add>(iconst0, mul));
run_passes(pass_manager, graph, {a, b});
ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
ASSERT_TRUE(a->get_output_target_inputs(0).count(
graph->input(1))); // a's output feeds into graph's input
}
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst1 = construct_constant_node(1);
auto mul =
make_shared<op::v1::Multiply>(iconst1, make_shared<op::v1::Multiply>(iconst1, a));
mul = make_shared<op::v1::Multiply>(iconst1, make_shared<op::v1::Multiply>(iconst1, mul));
auto graph = make_shared<op::v1::Add>(b, mul);
run_passes(pass_manager, graph, {a, b});
ASSERT_EQ(graph->input_value(1).get_node_shared_ptr(), a);
ASSERT_EQ(graph->input_value(1), a->output(0)); // graph's input points to a's output
ASSERT_TRUE(a->get_output_target_inputs(0).count(
graph->input(1))); // a's output feeds into graph's input
}
}
TEST(pattern, matcher)
{
Shape shape{};
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
TestMatcher n;
ASSERT_TRUE(n.match(a, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
auto abs = make_shared<op::Abs>(a);
auto any = std::make_shared<pattern::op::Skip>(a);
ASSERT_TRUE(n.match(any, abs));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{abs, a}));
auto false_pred = [](std::shared_ptr<Node> /* no */) { return false; };
auto any_false = std::make_shared<pattern::op::Skip>(a, false_pred);
ASSERT_TRUE(n.match(any_false, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a, a}));
auto pattern = std::make_shared<pattern::op::Label>(a);
ASSERT_TRUE(n.match(pattern, a));
ASSERT_EQ(n.get_pattern_map()[pattern], a);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{a}));
auto pattern_false = std::make_shared<pattern::op::Label>(a, false_pred);
ASSERT_FALSE(n.match(pattern_false, a));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto is_bea = [](std::shared_ptr<Node> node) -> bool {
return op::is_binary_elementwise_arithmetic(node);
};
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto add_ab = std::make_shared<op::v1::Add>(a, b);
ASSERT_TRUE(n.match(bea, add_ab));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_ab, a, b}));
ASSERT_TRUE(n.match(bea, std::make_shared<op::v1::Add>(b, a)));
auto bea_false = std::make_shared<pattern::op::Any>(a, false_pred, NodeVector{a, b});
ASSERT_FALSE(n.match(bea_false, std::make_shared<op::v1::Add>(a, b)));
auto add_abs_b = std::make_shared<op::v1::Add>(abs, b);
auto bea_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs});
ASSERT_TRUE(n.match(bea_any_of, add_abs_b));
auto add_b_abs = std::make_shared<op::v1::Add>(b, abs);
ASSERT_TRUE(n.match(bea_any_of, add_b_abs));
auto bea_any_of_label =
std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea_any_of});
ASSERT_TRUE(n.match(bea_any_of_label, add_b_abs));
ASSERT_EQ(n.get_pattern_map()[bea_any_of_label], add_b_abs);
auto abs_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{abs});
auto bea_label_any_of = std::make_shared<pattern::op::AnyOf>(a, is_bea, NodeVector{abs_label});
ASSERT_TRUE(n.match(bea_label_any_of, add_b_abs));
ASSERT_EQ(n.get_pattern_map()[abs_label], abs);
auto bea_label = std::make_shared<pattern::op::Label>(a, nullptr, NodeVector{bea});
auto ab = std::make_shared<op::v1::Add>(a, b);
ASSERT_TRUE(n.match(bea_label, ab));
ASSERT_EQ(n.get_pattern_map()[bea_label], ab);
auto d = make_shared<op::Parameter>(element::Type_t::i32, shape);
ASSERT_FALSE(n.match(d, b));
ASSERT_FALSE(
n.match(std::make_shared<op::v1::Add>(abs, b), std::make_shared<op::v1::Add>(b, b)));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{}));
auto add_absb = std::make_shared<op::v1::Add>(abs, b);
ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(any, b), add_absb));
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, a, b}));
ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(pattern, b), add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
ASSERT_TRUE(n.match(std::make_shared<op::v1::Add>(b, pattern), add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add_absb, abs, b}));
auto c = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto mul_add_absb = std::make_shared<op::v1::Multiply>(c, add_absb);
ASSERT_TRUE(
n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(b, pattern)),
mul_add_absb));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, b}));
ASSERT_TRUE(
n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any, b)),
mul_add_absb)); // nested any
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_add_absb, c, add_absb, abs, a, b}));
ASSERT_TRUE(
n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any, b)),
std::make_shared<op::v1::Multiply>(std::make_shared<op::v1::Add>(b, abs),
c))); // permutations w/ any
auto mul_c_add_ab = make_shared<op::v1::Multiply>(c, add_ab);
ASSERT_TRUE(
n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any_false, b)),
std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(a, b)))); // nested any
ASSERT_TRUE(
n.match(std::make_shared<op::v1::Multiply>(c, std::make_shared<op::v1::Add>(any_false, b)),
mul_c_add_ab)); // permutations w/ any_false
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{mul_c_add_ab, c, add_ab, a, a, b}));
auto iconst1_0 = construct_constant_node(1);
auto iconst1_1 = construct_constant_node(1);
ASSERT_TRUE(n.match(make_shared<op::v1::Multiply>(pattern, iconst1_0),
make_shared<op::v1::Multiply>(a, iconst1_1))); // different iconst
ASSERT_EQ(n.get_pattern_map()[pattern], a);
auto fconst1_0 = op::Constant::create(element::Type_t::f32, shape, {1});
auto patternf = std::make_shared<pattern::op::Label>(fconst1_0);
ASSERT_TRUE(n.match(make_shared<op::v1::Multiply>(patternf, fconst1_0),
make_shared<op::v1::Multiply>(a, iconst1_1))); // f32 pattern constant vs. i32 graph constant
// Subgraph labels
auto add = std::make_shared<op::v1::Add>(a, b);
auto label = std::make_shared<pattern::op::Label>(add, nullptr, NodeVector{add});
ASSERT_TRUE(n.match(label, add));
ASSERT_EQ(n.get_pattern_map()[label], add);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{add, add, a, b}));
ASSERT_FALSE(n.match(label, std::make_shared<op::v1::Subtract>(a, b)));
ASSERT_TRUE(n.match(make_shared<op::Abs>(label), make_shared<op::Abs>(add)));
ASSERT_EQ(n.get_pattern_map()[label], add);
// Correct argument order
ASSERT_FALSE(n.match(make_shared<op::v1::Subtract>(b, a), make_shared<op::v1::Subtract>(a, b)));
auto aab = make_shared<op::v1::Multiply>(a, make_shared<op::v1::Subtract>(a, b));
auto paab = make_shared<op::v1::Multiply>(pattern, make_shared<op::v1::Subtract>(pattern, b));
ASSERT_TRUE(n.match(paab, aab));
auto aba = make_shared<op::v1::Multiply>(a, make_shared<op::v1::Subtract>(b, a));
ASSERT_FALSE(n.match(paab, aba));
auto paba = make_shared<op::v1::Multiply>(pattern, make_shared<op::v1::Subtract>(b, pattern));
ASSERT_FALSE(n.match(paba, aab));
// Correlations
auto label1 = std::make_shared<pattern::op::Label>(a);
auto tmp = std::make_shared<op::v1::Add>(label1, b);
auto label2 = std::make_shared<pattern::op::Label>(tmp, nullptr, NodeVector{tmp});
auto sub_label1 = std::make_shared<op::v1::Subtract>(label1, label2);
auto sub_add = std::make_shared<op::v1::Subtract>(a, add);
ASSERT_TRUE(n.match(sub_label1, sub_add));
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
ASSERT_EQ(n.get_matched_nodes(), (NodeVector{sub_add, a, add, add, a, b}));
ASSERT_FALSE(n.match(sub_label1, std::make_shared<op::v1::Subtract>(add, a)));
auto add_label1 = std::make_shared<op::v1::Add>(label1, label2);
ASSERT_TRUE(n.match(add_label1, std::make_shared<op::v1::Add>(add, a)));
ASSERT_EQ(n.get_pattern_map()[label1], a);
ASSERT_EQ(n.get_pattern_map()[label2], add);
// Or
ASSERT_TRUE(
n.match(std::make_shared<pattern::op::Or>(OutputVector{
std::make_shared<op::v1::Add>(a, b), std::make_shared<op::v1::Subtract>(a, b)}),
std::make_shared<op::v1::Add>(a, b)));
ASSERT_TRUE(
n.match(std::make_shared<pattern::op::Or>(OutputVector{
std::make_shared<op::v1::Add>(a, b), std::make_shared<op::v1::Subtract>(a, b)}),
std::make_shared<op::v1::Subtract>(a, b)));
// Branch
{
auto branch = std::make_shared<pattern::op::Branch>();
auto star = std::make_shared<pattern::op::Or>(
OutputVector{branch, std::make_shared<pattern::op::True>()});
auto pattern = std::make_shared<op::v1::Add>(star, star);
branch->set_destination(pattern);
auto arg = std::make_shared<op::v1::Add>(std::make_shared<op::v1::Add>(a, b),
std::make_shared<op::v1::Add>(b, a));
ASSERT_TRUE(n.match(pattern, std::make_shared<op::v1::Add>(arg, a)));
ASSERT_EQ(n.get_matched_nodes().size(), 4);
}
// strict mode
{
TestMatcher sm(Output<Node>{}, "TestMatcher", true);
// exact shape and type
auto scalar_param = make_shared<op::Parameter>(element::Type_t::i32, Shape{});
auto label_dynamic_shape =
make_shared<pattern::op::Label>(element::Type_t::i32, PartialShape::dynamic());
auto param = make_shared<op::Parameter>(element::Type_t::f32, Shape{});
ASSERT_TRUE(sm.match(label_dynamic_shape, scalar_param));
// wrong type
auto scalar_param_wrong_type = make_shared<op::Parameter>(element::Type_t::f32, Shape{});
ASSERT_FALSE(sm.match(label, scalar_param_wrong_type));
// dynamic dimension
auto label_dynamic_dimension = make_shared<pattern::op::Label>(
element::Type_t::i32, PartialShape{Dimension::dynamic()});
auto vector_param = make_shared<op::Parameter>(element::Type_t::i32, Shape{10});
ASSERT_TRUE(sm.match(label_dynamic_dimension, vector_param));
// dynamic type
auto label_dynamic_type = make_shared<pattern::op::Label>(
element::Type_t::dynamic, PartialShape{Dimension::dynamic()});
ASSERT_TRUE(sm.match(label_dynamic_type, vector_param));
}
}
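// Illustrative aside (not part of the original test file): a minimal sketch of three behaviours
// the assertions above exercise: labels bind to graph nodes, a label that appears twice must bind
// to the same node, and commutative operations may match with swapped arguments while Subtract
// may not. Node, PatternMap, make, match_args and match_node are hypothetical names for this
// sketch (PatternMap is a plain std::map here, not Matcher::PatternMap). It does not model Skip,
// Any, AnyOf, Or, Branch, predicates or strict mode, and it is not how ngraph implements matching.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace matcher_sketch
{
    // Hypothetical toy graph node: an operation name plus ordered arguments.
    struct Node
    {
        std::string op; // e.g. "Param", "Label", "Abs", "Add", "Sub", "Mul"
        std::vector<std::shared_ptr<Node>> args;
    };
    using NodePtr = std::shared_ptr<Node>;
    using PatternMap = std::map<NodePtr, NodePtr>; // label -> bound graph node

    inline NodePtr make(const std::string& op, std::vector<NodePtr> args = {})
    {
        auto n = std::make_shared<Node>();
        n->op = op;
        n->args = std::move(args);
        return n;
    }

    inline bool match_node(const NodePtr& pattern, const NodePtr& graph, PatternMap& bound);

    // Match arguments pairwise. Bindings are collected in a copy and committed only on
    // success, so a failed attempt leaves `bound` untouched.
    inline bool match_args(const std::vector<NodePtr>& p,
                           const std::vector<NodePtr>& g,
                           PatternMap& bound)
    {
        PatternMap trial = bound;
        for (std::size_t i = 0; i < p.size(); ++i)
        {
            if (!match_node(p[i], g[i], trial))
            {
                return false;
            }
        }
        bound = std::move(trial);
        return true;
    }

    inline bool match_node(const NodePtr& pattern, const NodePtr& graph, PatternMap& bound)
    {
        if (pattern->op == "Label")
        {
            auto it = bound.find(pattern);
            if (it != bound.end())
            {
                return it->second == graph; // must agree with the earlier binding
            }
            bound[pattern] = graph;
            return true;
        }
        if (pattern->args.empty())
        {
            return pattern == graph; // in this sketch, leaves match by identity only
        }
        if (pattern->op != graph->op || pattern->args.size() != graph->args.size())
        {
            return false;
        }
        if (match_args(pattern->args, graph->args, bound))
        {
            return true;
        }
        // Commutative operations get a second attempt with the graph arguments swapped.
        const bool commutative = pattern->op == "Add" || pattern->op == "Mul";
        if (commutative && graph->args.size() == 2)
        {
            std::vector<NodePtr> swapped{graph->args[1], graph->args[0]};
            return match_args(pattern->args, swapped, bound);
        }
        return false;
    }
}
```

// With this sketch, Add(b, label) matches Add(Abs(a), b) and binds label to the Abs node,
// while Sub(b, a) does not match Sub(a, b).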
TEST(pattern, mean)
{
// construct mean
TestMatcher n;
auto input = std::make_shared<op::Parameter>(element::Type_t::f32, Shape{2, 3});
auto N = op::Constant::create(element::Type_t::f32, Shape{3}, {2, 2, 2});
auto sum_input1 = std::make_shared<op::v1::ReduceSum>(
input, op::Constant::create(element::Type_t::i64, {1}, {0}));
auto mean = std::make_shared<op::v1::Divide>(sum_input1, N);
auto mean_graph = construct_mean_graph();
ASSERT_TRUE(n.match(mean_graph, mean));
ASSERT_EQ(n.get_pattern_map()[mean_graph], mean);
}
TEST(pattern, variance)
{
// construct variance
TestMatcher n;
auto N = op::Constant::create(element::Type_t::f32, Shape{3}, {2, 2, 2});
auto input = std::make_shared<pattern::op::Label>(element::Type_t::f32, Shape{2, 3});
auto input_sq = std::make_shared<op::v1::Multiply>(input, input);
auto sum_input = std::make_shared<op::v1::ReduceSum>(
input, op::Constant::create(element::Type_t::i64, {1}, {0}));
auto square_sumed_input = std::make_shared<op::v1::Multiply>(sum_input, sum_input);
auto sum_squared_input = std::make_shared<op::v1::ReduceSum>(
input_sq, op::Constant::create(element::Type_t::i64, {1}, {0}));
auto avg_input_sum_sq = std::make_shared<op::v1::Divide>(square_sumed_input, N);
auto xmu = std::make_shared<op::v1::Subtract>(sum_squared_input, avg_input_sum_sq);
auto variance = std::make_shared<op::v1::Divide>(xmu, N);
auto var_graph = construct_variance_graph();
ASSERT_TRUE(n.match(var_graph, variance));
ASSERT_EQ(n.get_pattern_map()[var_graph], variance);
}
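// Illustrative aside (not part of the original test file): a numeric sketch of what the mean and
// variance pattern graphs above compute for a 2x3 input reduced along axis 0. Per column,
// mean = sum(x) / N and variance = (sum(x^2) - sum(x)^2 / N) / N, with N = 2 rows, matching the
// constant {2, 2, 2}. The names kRows, kCols, Matrix, Row, column_mean and column_variance are
// hypothetical and chosen only for this sketch; no ngraph API is used.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

namespace stats_sketch
{
    constexpr std::size_t kRows = 2;
    constexpr std::size_t kCols = 3;
    using Row = std::array<float, kCols>;
    using Matrix = std::array<Row, kRows>;

    // Per-column mean: reduce-sum over rows, then divide by the row count.
    inline Row column_mean(const Matrix& x)
    {
        Row out{};
        const float n = static_cast<float>(kRows);
        for (std::size_t c = 0; c < kCols; ++c)
        {
            float sum = 0.0f;
            for (std::size_t r = 0; r < kRows; ++r)
            {
                sum += x[r][c];
            }
            out[c] = sum / n;
        }
        return out;
    }

    // Per-column variance, following the same steps as the pattern graph:
    // (sum of squares - (square of sum) / N) / N.
    inline Row column_variance(const Matrix& x)
    {
        Row out{};
        const float n = static_cast<float>(kRows);
        for (std::size_t c = 0; c < kCols; ++c)
        {
            float sum = 0.0f;
            float sum_sq = 0.0f;
            for (std::size_t r = 0; r < kRows; ++r)
            {
                sum += x[r][c];
                sum_sq += x[r][c] * x[r][c];
            }
            const float avg_sum_sq = (sum * sum) / n; // avg_input_sum_sq in the graph
            const float xmu = sum_sq - avg_sum_sq;    // xmu in the graph
            out[c] = xmu / n;
        }
        return out;
    }
}
```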
TEST(pattern, previous_matches)
{
using ngraph::pattern::Matcher;
Shape shape{};
Matcher::PatternMap previous_matches;
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto pattern = std::make_shared<pattern::op::Label>(b);
auto abs = make_shared<op::Abs>(a);
auto add = make_shared<op::v1::Add>(abs, b);
{
Matcher n(make_shared<op::v1::Add>(pattern, b));
ASSERT_TRUE(n.match(add, previous_matches));
ASSERT_EQ(n.get_pattern_map()[pattern], abs);
}
{
Matcher n(make_shared<op::v1::Add>(pattern, b));
previous_matches.insert(std::make_pair(pattern, a));
ASSERT_FALSE(n.match(add, previous_matches));
}
}
TEST(pattern, test_sort)
{
using ngraph::pattern::Matcher;
Shape shape{};
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto abs1 = make_shared<op::Abs>(a);
auto abs2 = make_shared<op::Abs>(b);
shared_ptr<Node> add = make_shared<op::v1::Add>(abs1, abs2);
auto pa = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto pb = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto pabs1 = make_shared<op::Abs>(pa);
auto pabs1_label = std::make_shared<pattern::op::Label>(pabs1);
auto pabs2 = make_shared<op::Abs>(b);
shared_ptr<Node> padd = make_shared<op::v1::Add>(pabs1_label, pabs2);
{
Matcher n1(padd);
ASSERT_TRUE(n1.match(add));
auto r1 = n1.get_pattern_map()[pabs1_label];
ASSERT_TRUE(n1.match(add));
ASSERT_EQ(r1, n1.get_pattern_map()[pabs1_label]);
}
}
TEST(pattern, recurrent_pattern)
{
using ngraph::pattern::RecurrentMatcher;
Shape shape{};
ngraph::pattern::Matcher::PatternMap previous_matches;
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto rpattern = std::make_shared<pattern::op::Label>(b);
auto iconst0 = construct_constant_node(0);
auto abs = make_shared<op::Abs>(a);
auto add1 = make_shared<op::v1::Add>(iconst0, b);
auto add2 = make_shared<op::v1::Add>(iconst0, add1);
auto add3 = make_shared<op::v1::Add>(iconst0, add2);
auto padd = make_shared<op::v1::Add>(iconst0, rpattern);
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
RecurrentMatcher rm(padd, rpattern, empty_correlated_matches);
ASSERT_TRUE(rm.match(add3));
ASSERT_EQ(rm.get_number_of_bound_labels(), 3);
auto recurrent_matches = rm.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
// Multiple labels in a recurring pattern
auto iconst1 = construct_constant_node(1);
auto iconst_label = std::make_shared<pattern::op::Label>(iconst1, nullptr, NodeVector{iconst1});
auto add2_2 = make_shared<op::v1::Add>(iconst1, add1);
auto add3_2 = make_shared<op::v1::Add>(iconst0, add2_2);
auto padd2 = make_shared<op::v1::Add>(iconst_label, rpattern);
RecurrentMatcher rm2(padd2, rpattern, empty_correlated_matches);
ASSERT_TRUE(rm2.match(add3_2));
ASSERT_EQ(rm2.get_number_of_bound_labels(), 4);
recurrent_matches = rm2.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2_2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
auto iconst_matches = rm2.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst1);
ASSERT_EQ(iconst_matches.at(2), iconst0);
// Non-matching correlated labels
std::set<std::shared_ptr<pattern::op::Label>> correlated_matches;
correlated_matches.insert(iconst_label);
RecurrentMatcher rm3(padd2, rpattern, correlated_matches);
ASSERT_TRUE(rm3.match(add3_2));
ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.size(), 1);
ASSERT_EQ(iconst_matches.at(0), iconst0);
// Matching correlated labels and
// testing if RecurrentMatcher can be reused for different nodes
ASSERT_TRUE(rm3.match(add3));
ASSERT_EQ(rm3.get_number_of_bound_labels(), 4);
recurrent_matches = rm3.get_bound_nodes_for_pattern(rpattern);
ASSERT_EQ(recurrent_matches.at(0), add2);
ASSERT_EQ(recurrent_matches.at(1), add1);
ASSERT_EQ(recurrent_matches.at(2), b);
iconst_matches = rm3.get_bound_nodes_for_pattern(iconst_label);
ASSERT_EQ(iconst_matches.at(0), iconst0);
ASSERT_EQ(iconst_matches.at(1), iconst0);
ASSERT_EQ(iconst_matches.at(2), iconst0);
}
class TestRecurrentGraphRewrite : public ngraph::pass::RecurrentGraphRewrite
{
public:
void construct_recurrent_add()
{
Shape shape{};
auto iconst0 = construct_constant_node(0);
auto iconst_label =
std::make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
auto rpattern = std::make_shared<pattern::op::Label>(element::Type_t::i32, shape);
auto padd = make_shared<op::v1::Add>(iconst_label, rpattern);
auto callback = [iconst_label, rpattern](pattern::RecurrentMatcher& rm) {
NGRAPH_DEBUG << "In a callback for construct_recurrent_add against "
<< rm.get_match_root()->get_name();
auto iconst_matches = rm.get_bound_nodes_for_pattern(iconst_label);
auto is_iconst_zero = [](std::shared_ptr<Node> n) {
// Evaluate the predicate once and reuse the result for logging and return
bool result = ngraph::is_zero(n);
NGRAPH_DEBUG << n->get_name() << " is " << (result ? "a zero" : "not a zero");
return result;
};
bool are_all_iconst_zeros =
std::all_of(iconst_matches.begin(), iconst_matches.end(), is_iconst_zero);
if (!are_all_iconst_zeros)
{
return false;
}
auto number_of_adds = rm.get_number_of_recurrent_matches();
// replace the topmost add with the seed (i.e. the first parameter to add)
// matches are added in reverse order (i.e. the first match is the topmost node)
auto arg = rm.get_bound_nodes_for_pattern(rpattern).at(number_of_adds - 1);
NGRAPH_DEBUG << "Replacing " << rm.get_match_root()->get_name() << " with "
<< arg->get_name();
ngraph::replace_node(rm.get_match_root(), arg);
return true;
};
std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
auto rm = make_shared<pattern::RecurrentMatcher>(padd, rpattern, empty_correlated_matches);
NGRAPH_SUPPRESS_DEPRECATED_START
this->add_matcher(rm, callback);
NGRAPH_SUPPRESS_DEPRECATED_END
}
TestRecurrentGraphRewrite()
: RecurrentGraphRewrite()
{
construct_recurrent_add();
}
};
TEST(pattern, recurrent_graph_rewrite)
{
Shape shape{};
pass::Manager pass_manager;
pass_manager.register_pass<TestRecurrentGraphRewrite>();
{
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto iconst0 = construct_constant_node(0);
auto add_a1 = make_shared<op::v1::Add>(a, iconst0);
auto add_a2 = make_shared<op::v1::Add>(add_a1, iconst0);
auto add_a3 = make_shared<op::v1::Add>(add_a2, iconst0);
auto abs_add_a3 = std::make_shared<op::Abs>(add_a3);
auto b = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto add_b1 = make_shared<op::v1::Add>(b, iconst0);
auto add_b2 = make_shared<op::v1::Add>(add_b1, iconst0);
auto abs_add_b2 = std::make_shared<op::Abs>(add_b2);
auto graph = make_shared<op::v1::Multiply>(abs_add_a3, abs_add_b2);
auto f = std::make_shared<Function>(ngraph::NodeVector{graph}, ParameterVector{a, b});
pass_manager.run_passes(f);
auto left_abs = graph->input_value(0).get_node_shared_ptr();
auto add_a = left_abs->input_value(0).get_node_shared_ptr();
ASSERT_EQ(add_a, a);
auto right_abs = graph->input_value(1).get_node_shared_ptr();
auto add_b = right_abs->input_value(0).get_node_shared_ptr();
ASSERT_EQ(add_b, b);
}
}
TEST(pattern, label_on_skip)
{
Shape shape{2, 2};
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto b = make_shared<op::Parameter>(element::Type_t::i32, Shape{});
auto iconst = ngraph::make_zero(element::Type_t::i32, Shape{});
auto label = std::make_shared<pattern::op::Label>(iconst);
auto const_label =
std::make_shared<pattern::op::Label>(iconst, ngraph::is_zero, NodeVector{iconst});
auto bcst_pred = [](std::shared_ptr<Node> n) {
return as_type_ptr<op::v1::Broadcast>(n) != nullptr;
};
auto shape_const = op::Constant::create(element::Type_t::u64, Shape{shape.size()}, shape);
auto axes_const = op::Constant::create(element::Type_t::u8, Shape{}, {0});
auto bcst = std::make_shared<pattern::op::Skip>(
OutputVector{const_label, shape_const, axes_const}, bcst_pred);
auto bcst_label = std::make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher = std::make_shared<pattern::Matcher>(
std::make_shared<op::v1::Multiply>(label, bcst_label), "label_on_skip");
auto const_broadcast = make_shared<op::v1::Broadcast>(iconst, shape_const);
std::shared_ptr<Node> mul = std::make_shared<op::v1::Multiply>(a, const_broadcast);
std::shared_ptr<Node> mul_scalar = std::make_shared<op::v1::Multiply>(b, iconst);
// Broadcast present: the Skip predicate accepts it, so bcst_label binds the Broadcast
ASSERT_TRUE(matcher->match(mul));
ASSERT_EQ(matcher->get_pattern_map()[bcst_label], const_broadcast);
ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[label], a);
// No Broadcast: the Skip is bypassed, so bcst_label binds the constant directly
ASSERT_TRUE(matcher->match(mul_scalar));
ASSERT_EQ(matcher->get_pattern_map()[bcst_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[const_label], iconst);
ASSERT_EQ(matcher->get_pattern_map()[label], b);
}
TEST(pattern, is_contained_match)
{
Shape shape{};
auto a = make_shared<op::Parameter>(element::Type_t::i32, shape);
auto absn = make_shared<op::Abs>(a);
TestMatcher n;
auto label_a = std::make_shared<pattern::op::Label>(a);
auto label_abs = make_shared<op::Abs>(a);
ASSERT_TRUE(n.match(label_abs, absn));
auto result_absn = make_shared<op::Result>(absn);
ASSERT_TRUE(n.is_contained_match());
auto absn2 = make_shared<op::Abs>(absn);
auto result_absn2 = make_shared<op::Result>(absn2);
auto label_abs2 = make_shared<op::Abs>(label_abs);
ASSERT_TRUE(n.match(label_abs2, absn2));
// absn also feeds result_absn, which lies outside the matched subgraph
ASSERT_FALSE(n.is_contained_match());
}
TEST(pattern, wrap_type)
{
auto a = make_shared<op::Parameter>(element::Type_t::f32, Shape{1, 3, 64, 64});
auto b = make_shared<op::Abs>(a);
auto c = make_shared<op::Relu>(a);
auto mul1 =
make_shared<op::v1::Multiply>(a, op::Constant::create(element::Type_t::f32, Shape{}, {1}));
auto mul2 =
make_shared<op::v1::Multiply>(op::Constant::create(element::Type_t::f32, Shape{}, {1}), a);
{
auto m = pattern::wrap_type<op::Abs>();
auto matcher = std::make_shared<pattern::Matcher>(m, "AbsMatcher");
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
ASSERT_EQ(matcher->get_matched_nodes()[0], b);
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
}
{
auto m1 = pattern::wrap_type<op::Parameter>();
auto m2 = pattern::wrap_type<op::Abs>({m1});
auto matcher = std::make_shared<pattern::Matcher>(m2, "ParamAbsMatcher");
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(b)));
ASSERT_EQ(matcher->get_matched_nodes().size(), 2);
ASSERT_EQ(matcher->get_pattern_map().count(m1), 1);
ASSERT_EQ(matcher->get_pattern_map().count(m2), 1);
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
}
{
// Multiply is commutative, so both input orders are expected to match
auto m1 = pattern::wrap_type<op::v1::Multiply>(
{pattern::any_input(), pattern::wrap_type<op::Constant>()});
auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
}
{
auto m1 = pattern::wrap_type<op::v1::Multiply>(
{pattern::wrap_type<op::Constant>(), pattern::any_input()});
auto matcher = std::make_shared<pattern::Matcher>(m1, "MultiplyMatcher");
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul1)));
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
}
}