WrapType Improvements (#4040)

* Extended WrapType to consume multiple types; Added variadic wrap_type support

* Updated transformations to use wrap_type

* Fix BatchNormDecomposition

* Added tests
This commit is contained in:
Gleb Kazantaev
2021-02-02 09:27:05 +03:00
committed by GitHub
parent 3a86b3a17e
commit cca0d568e0
12 changed files with 131 additions and 115 deletions

View File

@@ -19,7 +19,6 @@ namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API BatchNormDecomposition;
class TRANSFORMATIONS_API BatchNormV5Decomposition;
} // namespace pass
} // namespace ngraph
@@ -29,9 +28,3 @@ public:
NGRAPH_RTTI_DECLARATION;
BatchNormDecomposition();
};
// Decomposes opset5::BatchNormInference into primitive arithmetic ops
// (Subtract/Multiply/Add/Sqrt/Reshape) so backends without a native
// BatchNorm kernel can execute it.
// NOTE(review): kept separate from BatchNormDecomposition (which handles
// the v1 op) — presumably to be unified later; confirm against the pass
// registration in CommonOptimizations.
class ngraph::pass::BatchNormV5Decomposition: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
BatchNormV5Decomposition();
};

View File

@@ -83,7 +83,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
auto common_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
common_fusions->add_matcher<ngraph::pass::ConvertScatterElementsToScatter>();
common_fusions->add_matcher<ngraph::pass::DepthToSpaceFusion>();
//common_fusions->add_matcher<ngraph::pass::MishFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusFusion>();
common_fusions->add_matcher<ngraph::pass::SoftPlusToMishFusion>();
common_fusions->add_matcher<ngraph::pass::SwishFusion>();
@@ -115,7 +114,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::ConvertDepthToSpace>();
decomp->add_matcher<ngraph::pass::ConvertSpaceToDepth>();
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::BatchNormV5Decomposition>();
decomp->set_name("ngraph::pass::CommonDecompositions");
// CF is required after all decompositions

View File

@@ -164,7 +164,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvAddFusion, "ConvAddFusion", 0);
ngraph::pass::ConvAddFusion::ConvAddFusion() {
MATCHER_SCOPE(ConvAddFusion);
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::consumers_count(1));
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, pattern::any_input()});
matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
return conv_callback<op::ConvolutionIE>(m);
@@ -179,7 +179,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvMultiplyFusion, "ConvMultiplyFusion", 0
ngraph::pass::ConvMultiplyFusion::ConvMultiplyFusion() {
MATCHER_SCOPE(ConvMultiplyFusion);
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::consumers_count(1));
auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, std::make_shared<pattern::op::Label>()});
auto add = ngraph::pattern::wrap_type<opset1::Multiply>({conv, pattern::any_input()});
matcher_pass_callback callback = [](ngraph::pattern::Matcher &m) {
return conv_callback<op::ConvolutionIE>(m);
@@ -194,7 +194,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::DeconvAddFusion, "DeconvAddFusion", 0);
ngraph::pass::DeconvAddFusion::DeconvAddFusion() {
MATCHER_SCOPE(DeconvAddFusion);
auto conv = ngraph::pattern::wrap_type<op::DeconvolutionIE>(pattern::consumers_count(1));
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, std::make_shared<pattern::op::Label>()});
auto add = ngraph::pattern::wrap_type<opset1::Add>({conv, pattern::any_input()});
matcher_pass_callback callback = [](ngraph::pattern::Matcher &m){
return conv_callback<op::DeconvolutionIE>(m);

View File

@@ -19,7 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposi
ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
MATCHER_SCOPE(BatchNormDecomposition);
auto bn = pattern::wrap_type<opset1::BatchNormInference>({
auto bn = pattern::wrap_type<opset1::BatchNormInference, opset5::BatchNormInference>({
pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
@@ -28,20 +28,30 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
});
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
if (!m_bn) {
auto m_bn = m.get_match_root();
Output<Node> m_input, m_gamma, m_beta, m_mean, m_var;
double eps;
if (auto m_bn_v1 = dynamic_pointer_cast<opset1::BatchNormInference>(m_bn)) {
m_gamma = m_bn_v1->input_value(0);
m_beta = m_bn_v1->input_value(1);
m_input = m_bn_v1->input_value(2);
m_mean = m_bn_v1->input_value(3);
m_var = m_bn_v1->input_value(4);
eps = m_bn_v1->get_eps_value();
} else if (auto m_bn_v5 = dynamic_pointer_cast<opset5::BatchNormInference>(m_bn)) {
m_input = m_bn_v5->input_value(0);
m_gamma = m_bn_v5->input_value(1);
m_beta = m_bn_v5->input_value(2);
m_mean = m_bn_v5->input_value(3);
m_var = m_bn_v5->input_value(4);
eps = m_bn_v5->get_eps_value();
} else {
return false;
}
auto m_gamma = m_bn->input_value(0);
auto m_beta = m_bn->input_value(1);
auto m_input = m_bn->input_value(2);
auto m_mean = m_bn->input_value(3);
auto m_var = m_bn->input_value(4);
const auto& input_type = m_input.get_element_type();
// scale_add = variance + eps
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {eps}));
// scale = sqrt(variance + eps)
auto scale = make_shared<opset5::Sqrt>(scale_add);
// Divide `gamma` by `sqrt(variance + eps)`
@@ -79,67 +89,3 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
// TODO: this pass will be unified with BatchNormDecomposition pass
// Matcher pass: rewrites opset5::BatchNormInference as
//   y = (x - mean) * (gamma / sqrt(variance + eps)) + beta
// using elementwise ops, reshaping the per-channel parameters so they
// broadcast over the spatial dimensions of the input.
ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
MATCHER_SCOPE(BatchNormV5Decomposition);
// Input 0 (data) only needs a static rank (used below to compute how many
// trailing 1-dims to append); gamma/beta/mean/variance need static shapes
// because gamma's shape is materialized into a Constant for the Reshape.
auto bn = pattern::wrap_type<opset5::BatchNormInference>({
pattern::any_input(pattern::has_static_rank()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape())
});
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
// Defensive downcast: wrap_type guarantees type compatibility, but the
// root may still fail the cast (e.g. a derived/incompatible node).
auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
if (!m_bn) {
return false;
}
// v5 input order: data, gamma, beta, mean, variance.
auto m_input = m_bn->input_value(0);
auto m_gamma = m_bn->input_value(1);
auto m_beta = m_bn->input_value(2);
auto m_mean = m_bn->input_value(3);
auto m_var = m_bn->input_value(4);
const auto& input_type = m_input.get_element_type();
// scale_add = variance + eps
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
// scale = sqrt(variance + eps)
auto scale = make_shared<opset5::Sqrt>(scale_add);
// Divide `gamma` by `sqrt(variance + eps)`
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
// Append one trailing 1-dim per spatial axis (rank - batch - channel) so
// the 1-D channel parameters broadcast across N,C,spatial... layout.
int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
Shape input_aligned_shape = m_gamma.get_shape();
for (int64_t i = 0; i < dims_to_add; ++i)
input_aligned_shape.push_back(1);
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
auto gamma_div_scale_aligned = make_shared<opset5::Reshape>(gamma_div_scale, new_shape, true);
auto beta_aligned = make_shared<opset5::Reshape>(m_beta, new_shape, true);
auto mean_aligned = make_shared<opset5::Reshape>(m_mean, new_shape, true);
// input_sub_mean = input - mean
// register_new_node (vs plain make_shared) registers the node with the
// MatcherPass so follow-up matchers can be re-run on it.
auto input_sub_mean = register_new_node<opset5::Subtract>(m_input, mean_aligned);
// Multiply `input - mean` and `gamma / sqrt(variance + eps)`
auto mul = std::make_shared<opset5::Multiply>(input_sub_mean, gamma_div_scale_aligned);
// Add `(input - mean) * gamma / sqrt(variance + eps)` and `beta`
auto add = std::make_shared<opset5::Add>(mul, beta_aligned);
// Preserve the original node's name and runtime info on the new subgraph.
add->set_friendly_name(m_bn->get_friendly_name());
copy_runtime_info(m_bn, {scale_add, scale, gamma_div_scale, gamma_div_scale_aligned,
beta_aligned, input_sub_mean, mul, add});
replace_node(m_bn, add);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, matcher_name);
this->register_matcher(m, callback);
}

View File

@@ -10,13 +10,13 @@
#include <ngraph/ngraph.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGELU, "ConvertGELU", 0);
ngraph::pass::ConvertGELU::ConvertGELU() {
MATCHER_SCOPE(ConvertGELU);
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto gelu = std::make_shared<ngraph::opset2::Gelu>(input);
auto gelu = pattern::wrap_type<ngraph::opset2::Gelu>();
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto gelu = std::dynamic_pointer_cast<ngraph::opset2::Gelu>(m.get_match_root());

View File

@@ -11,6 +11,7 @@
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
using namespace ngraph;
@@ -18,8 +19,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertShuffleChannels3, "ConvertShuffleCha
ngraph::pass::ConvertShuffleChannels3::ConvertShuffleChannels3() {
MATCHER_SCOPE(ConvertShuffleChannels3);
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
auto shuffle_channels = std::make_shared<::opset3::ShuffleChannels>(input);
auto shuffle_channels = pattern::wrap_type<opset3::ShuffleChannels>();
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto shuffle_channels = std::dynamic_pointer_cast<::opset3::ShuffleChannels>(m.get_match_root());

View File

@@ -24,8 +24,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTensorIteratorToGRUSequence, "Conver
ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSequence() {
MATCHER_SCOPE(ConvertTensorIteratorToLSTMSequence);
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
if (!ti || transformation_callback(ti))
@@ -201,8 +201,8 @@ ngraph::pass::ConvertTensorIteratorToLSTMSequence::ConvertTensorIteratorToLSTMSe
ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequence() {
MATCHER_SCOPE(ConvertTensorIteratorToRNNSequence);
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
if (!ti || transformation_callback(ti))
@@ -357,8 +357,8 @@ ngraph::pass::ConvertTensorIteratorToRNNSequence::ConvertTensorIteratorToRNNSequ
ngraph::pass::ConvertTensorIteratorToGRUSequence::ConvertTensorIteratorToGRUSequence() {
MATCHER_SCOPE(ConvertTensorIteratorToGRUSequence);
auto tensor_iterator = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32,
ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset5::TensorIterator>());
auto tensor_iterator = pattern::wrap_type<ngraph::opset5::TensorIterator>();
ngraph::matcher_pass_callback callback = [this](pattern::Matcher &m) {
auto ti = std::dynamic_pointer_cast<ngraph::opset5::TensorIterator>(m.get_match_root());
if (!ti || transformation_callback(ti))

View File

@@ -18,10 +18,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::LSTMCellDecomposition, "LSTMCellDecompositi
ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() {
MATCHER_SCOPE(LSTMCellDecomposition);
auto is_supported_lstm_cell = [](const std::shared_ptr<Node>& n) {
return pattern::has_class<ngraph::opset1::LSTMCell>()(n) || pattern::has_class<ngraph::opset4::LSTMCell>()(n);
};
auto any_lstm = std::make_shared<pattern::op::Label>(element::f32, Shape{}, is_supported_lstm_cell);
auto any_lstm = pattern::wrap_type<opset1::LSTMCell, opset4::LSTMCell>();
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(m.get_match_root());
if (!lstm_cell || transformation_callback(lstm_cell)) {

View File

@@ -36,7 +36,17 @@ namespace ngraph
[](const Output<Node>& output) { return true; },
const OutputVector& input_values = {})
: Pattern(input_values, pred)
, m_wrapped_type(wrapped_type)
, m_wrapped_types({wrapped_type})
{
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
explicit WrapType(std::vector<NodeTypeInfo> wrapped_types,
const ValuePredicate& pred =
[](const Output<Node>& output) { return true; },
const OutputVector& input_values = {})
: Pattern(input_values, pred)
, m_wrapped_types(std::move(wrapped_types))
{
set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
}
@@ -45,30 +55,33 @@ namespace ngraph
const Output<Node>& pattern_value,
const Output<Node>& graph_value) override;
NodeTypeInfo get_wrapped_type() const { return m_wrapped_type; }
NodeTypeInfo get_wrapped_type() const;
const std::vector<NodeTypeInfo>& get_wrapped_types() const;
private:
NodeTypeInfo m_wrapped_type;
std::vector<NodeTypeInfo> m_wrapped_types;
};
}
template <class T>
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs,
const pattern::op::ValuePredicate& pred)
{
static_assert(std::is_base_of<Node, T>::value, "Unexpected template type");
return std::make_shared<op::WrapType>(T::type_info, pred, inputs);
std::vector<DiscreteTypeInfo> info{Args::type_info...};
return std::make_shared<op::WrapType>(info, pred, inputs);
}
template <class T>
template <class... Args>
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {})
{
return wrap_type<T>(inputs, [](const Output<Node>& output) { return true; });
return wrap_type<Args...>(inputs, [](const Output<Node>& output) { return true; });
}
template <class T>
template <class... Args>
std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred)
{
return wrap_type<T>({}, pred);
return wrap_type<Args...>({}, pred);
}
}
}

View File

@@ -109,12 +109,14 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
// it's type
// and use it in unordered_map as key for fast MatcherPass search. Otherwise type is unknown
// and default algorithm is used.
NodeTypeInfo root_type_info = root->get_type_info();
if (auto p = dynamic_pointer_cast<pattern::op::Pattern>(root))
{
if (auto any_type = dynamic_pointer_cast<pattern::op::WrapType>(p))
{
root_type_info = any_type->get_wrapped_type();
for (const auto& root_type_info : any_type->get_wrapped_types())
{
type_to_matcher[root_type_info].push_back(matcher_index);
}
}
else
{
@@ -122,7 +124,10 @@ bool pass::GraphRewrite::run_on_function(shared_ptr<Function> f)
break;
}
}
type_to_matcher[root_type_info].push_back(matcher_index);
else
{
type_to_matcher[root->get_type_info()].push_back(matcher_index);
}
// TODO: traverse parents for root_type_info in order to register complete list of matchers
// including ones triggered by parent type info.

View File

@@ -31,7 +31,12 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)
{
if (graph_value.get_node_shared_ptr()->get_type_info().is_castable(get_wrapped_type()) &&
if (std::any_of(m_wrapped_types.begin(),
m_wrapped_types.end(),
[&](const NodeTypeInfo& type_info) {
return graph_value.get_node_shared_ptr()->get_type_info().is_castable(
type_info);
}) &&
m_predicate(graph_value))
{
auto& pattern_map = matcher->get_pattern_value_map();
@@ -44,3 +49,17 @@ bool pattern::op::WrapType::match_value(Matcher* matcher,
}
return false;
}
// Legacy single-type accessor. Valid only when this pattern wraps exactly
// one node type; patterns created with multiple types must use
// get_wrapped_types() instead.
NodeTypeInfo pattern::op::WrapType::get_wrapped_type() const
{
    if (m_wrapped_types.size() <= 1)
    {
        return m_wrapped_types.at(0);
    }
    throw ngraph::ngraph_error("get_wrapped_type() called on WrapType with more than one type");
}
// Returns every node type info this pattern can match against; used by
// GraphRewrite to index matchers by root type.
const std::vector<NodeTypeInfo>& pattern::op::WrapType::get_wrapped_types() const
{
return m_wrapped_types;
}

View File

@@ -810,7 +810,7 @@ TEST(pattern, is_contained_match)
ASSERT_FALSE(n.is_contained_match());
}
TEST(pattern, wrap_type)
TEST(pattern, wrap_type_single_op)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
auto b = make_shared<op::Abs>(a);
@@ -852,3 +852,47 @@ TEST(pattern, wrap_type)
ASSERT_TRUE(matcher->match(static_pointer_cast<Node>(mul2)));
}
}
// Verifies the variadic wrap_type: a single pattern node that matches any of
// several concrete op types, and (second case) matching via a common base
// class through type_info castability.
TEST(pattern, wrap_type_multi_op)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{1, 3, 64, 64});
auto b = make_shared<op::Abs>(a);
auto c = make_shared<op::Relu>(a);
auto mul = make_shared<op::v1::Multiply>(a, op::Constant::create(element::f32, Shape{}, {1}));
auto add = make_shared<op::v1::Add>(op::Constant::create(element::f32, Shape{}, {1}), a);
{
// Explicit two-type list: matches Multiply OR Add roots, nothing else.
auto m = pattern::wrap_type<op::v1::Multiply, op::v1::Add>();
auto matcher = std::make_shared<pattern::Matcher>(m, "MulAddMatcher");
ASSERT_TRUE(matcher->match(mul->output(0)));
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
ASSERT_EQ(matcher->get_matched_nodes()[0], mul);
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
ASSERT_TRUE(matcher->match(add->output(0)));
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
ASSERT_EQ(matcher->get_matched_nodes()[0], add);
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
// Parameter/Abs/Relu are neither Multiply nor Add — must not match.
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(a)));
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(b)));
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
}
{
// Base-class match: Multiply and Add are both castable to
// BinaryElementwiseArithmetic; unary ops (Abs, Relu) are not.
auto m = pattern::wrap_type<op::util::BinaryElementwiseArithmetic>();
auto matcher = std::make_shared<pattern::Matcher>(m, "ElementwiseMatcher");
ASSERT_TRUE(matcher->match(mul->output(0)));
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
ASSERT_EQ(matcher->get_matched_nodes()[0], mul);
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
ASSERT_TRUE(matcher->match(add->output(0)));
ASSERT_EQ(matcher->get_matched_nodes().size(), 1);
ASSERT_EQ(matcher->get_matched_nodes()[0], add);
ASSERT_EQ(matcher->get_pattern_map().count(m), 1);
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(a)));
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(b)));
ASSERT_FALSE(matcher->match(static_pointer_cast<Node>(c)));
}
}