AUGRUCell fusion transformation (#12844)

* AUGRUCell fusion transformation and tests

* add missed includes

* Apply review comments

* Apply review comments

* fix build

* Enable accuracy tests

* change supported  weights format from zrh to rzh

* update submodule version

* resolve review comments

* add additional checks to pattern

* fix conflict with master
This commit is contained in:
Ivan Tikhonov 2022-09-22 14:21:05 +03:00 committed by GitHub
parent 249df503eb
commit 77ba4ab6dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 274 additions and 0 deletions

View File

@ -0,0 +1,37 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <vector>
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
class TRANSFORMATIONS_API AUGRUCellFusion;
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief AUGRUCellFusion transformation replaces a sequence of
* operations with AUGRUCell op.
*
* Supported activations: 1st is Sigmoid, 2nd is Tanh
* Clip attribute is not supported.
* Linear_before_reset attribute is not supported.
* Supported weights format: 'rzh'
*
*/
class ov::pass::AUGRUCellFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("AUGRUCellFusion", "0");
AUGRUCellFusion();
};

View File

@ -0,0 +1,102 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
#include <memory>
#include "itt.hpp"
#include "ngraph_ops/augru_cell.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset9.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
using namespace std;
using namespace ov::opset9;
using namespace ov::element;
using namespace ov::pass::pattern;
ov::pass::AUGRUCellFusion::AUGRUCellFusion() {
MATCHER_SCOPE(AUGRUCellFusion);
// we can't determine hidden_size or input_size in this case
const auto is_first_dim_static = [](const Output<Node>& output) -> bool {
const auto& p_shape = output.get_partial_shape();
return !(p_shape.rank().is_dynamic() || p_shape[1].is_dynamic());
};
auto concat_1 = wrap_type<Concat>({any_input(is_first_dim_static), any_input(is_first_dim_static)});
auto matmul_1 = wrap_type<MatMul>({concat_1, any_input(is_first_dim_static)});
auto add_1 = wrap_type<Add>({matmul_1, any_input()});
// only Sigmoid is supported in the current version of AUGRUCell
auto sigmoid = wrap_type<Sigmoid>({add_1});
auto split = wrap_type<Split>({sigmoid, any_input()});
auto multiply = wrap_type<Multiply>({split, any_input()});
auto concat_2 = wrap_type<Concat>({any_input(), multiply});
auto matmul_2 = wrap_type<MatMul>({concat_2, any_input(is_first_dim_static)});
auto add_2 = wrap_type<Add>({matmul_2, any_input()});
// only Tanh is supported in the current version of AUGRUCell
auto tanh = wrap_type<Tanh>({add_2});
auto subtract_1 = wrap_type<Subtract>({any_input(), any_input()});
auto multiply_2 = wrap_type<Multiply>({subtract_1, split});
auto subtract_2 = wrap_type<Subtract>({any_input(), multiply_2});
auto multiply_3 = wrap_type<Multiply>({subtract_2, tanh});
auto multiply_4 = wrap_type<Multiply>({multiply_2, any_input()});
auto add_3 = wrap_type<Add>({multiply_4, multiply_3});
matcher_pass_callback callback = [=](pattern::Matcher& m) {
NodeRegistry rg;
auto pattern_map = m.get_pattern_map();
auto concat = pattern_map.at(concat_1);
auto X = concat->input_value(0);
auto H = concat->input_value(1);
auto h_pshape = H.get_partial_shape();
auto x_pshape = X.get_partial_shape();
auto hidden_size = h_pshape[1].get_length();
auto input_size = x_pshape[1].get_length();
auto axis_0 = rg.make<Constant>(i64, Shape{}, 0);
auto axis_1 = rg.make<Constant>(i64, Shape{}, 1);
auto A = pattern_map.at(subtract_1)->input_value(1);
// biases are required
auto bias_add_1 = pattern_map.at(add_1);
auto split_bias_r_z = rg.make<Split>(bias_add_1->input_value(1), axis_1, 2);
auto bias_add_2 = pattern_map.at(add_2);
auto B = rg.make<Concat>(
OutputVector{split_bias_r_z->output(1), split_bias_r_z->output(0), bias_add_2->input_value(1)},
1);
auto WRrz = pattern_map.at(matmul_1)->input_value(1);
auto WRh = pattern_map.at(matmul_2)->input_value(1);
auto split_lenghts = rg.make<Constant>(i64, Shape{2}, vector<int64_t>{input_size, hidden_size});
auto split_WRrz = rg.make<VariadicSplit>(WRrz, axis_1, split_lenghts);
auto split_W_r_z = rg.make<Split>(split_WRrz->output(0), axis_0, 2);
auto split_R_r_z = rg.make<Split>(split_WRrz->output(1), axis_0, 2);
auto split_WRh = rg.make<VariadicSplit>(WRh, axis_1, split_lenghts);
auto Wzrh =
rg.make<Concat>(OutputVector{split_W_r_z->output(1), split_W_r_z->output(0), split_WRh->output(0)}, 0);
auto Rzrh =
rg.make<Concat>(OutputVector{split_R_r_z->output(1), split_R_r_z->output(0), split_WRh->output(1)}, 0);
auto squeeze_B = rg.make<Squeeze>(B, axis_0);
auto cell =
rg.make<op::internal::AUGRUCell>(X, H, Wzrh, Rzrh, squeeze_B, A, H.get_partial_shape()[1].get_length());
cell->set_friendly_name(m.get_match_root()->get_friendly_name());
copy_runtime_info(m.get_matched_nodes(), rg.get());
replace_node(m.get_match_root(), cell);
return true;
};
auto m = make_shared<Matcher>(add_3, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -0,0 +1,135 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <queue>
#include <gtest/gtest.h>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "ngraph_ops/augru_cell.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/opsets/opset9.hpp"
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
using namespace std;
using namespace testing;
using namespace ov;
using namespace opset9;
using namespace element;
namespace {
shared_ptr<Model> gen_model(size_t batch, size_t hidden_size, size_t input_size, bool use_dyn_shapes) {
auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
if (use_dyn_shapes) {
X = make_shared<Parameter>(f32, PartialShape{static_cast<int64_t>(batch), Dimension::dynamic()});
}
auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
auto WRzr = make_shared<Parameter>(f32, Shape{2 * hidden_size, input_size + hidden_size});
auto Bzr = make_shared<Parameter>(f32, Shape{1, 2 * hidden_size});
auto WRh = make_shared<Parameter>(f32, Shape{hidden_size, input_size + hidden_size});
auto Bh = make_shared<Parameter>(f32, Shape{1, hidden_size});
auto A = make_shared<Parameter>(f32, Shape{batch, 1});
auto concat_1 = make_shared<Concat>(OutputVector{X, H}, 1);
auto matmul_1 = make_shared<MatMul>(concat_1, WRzr, false, true);
auto in_to_activation_1 = make_shared<Add>(matmul_1, Bzr);
auto sigmoid = make_shared<Sigmoid>(in_to_activation_1);
auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
auto split = make_shared<Split>(sigmoid, axis_1, 2);
auto multiply_1 = make_shared<Multiply>(split, H);
auto concat_2 = make_shared<Concat>(OutputVector{X, multiply_1}, 1);
auto matmul_2 = make_shared<MatMul>(concat_2, WRh, false, true);
auto in_to_activation_2 = make_shared<Add>(matmul_2, Bh);
auto tanh = make_shared<Tanh>(in_to_activation_2);
auto one = make_shared<Constant>(f32, Shape{1}, 1);
auto subtract_1 = make_shared<Subtract>(one, A);
auto multiply_2 = make_shared<Multiply>(subtract_1, split->output(1));
auto subtract_2 = make_shared<Subtract>(one, multiply_2);
auto multiply_3 = make_shared<Multiply>(subtract_2, tanh);
auto multiply_4 = make_shared<Multiply>(multiply_2, H);
auto add = make_shared<Add>(multiply_4, multiply_3);
return make_shared<Model>(OutputVector{add}, ParameterVector{X, H, WRzr, WRh, Bzr, Bh, A});
}
shared_ptr<Model> gen_reference(size_t batch, size_t hidden_size, size_t input_size) {
auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
auto WRrz = make_shared<Parameter>(f32, Shape{2 * hidden_size, input_size + hidden_size});
auto WRh = make_shared<Parameter>(f32, Shape{hidden_size, input_size + hidden_size});
auto Brz = make_shared<Parameter>(f32, Shape{1, 2 * hidden_size});
auto Bh = make_shared<Parameter>(f32, Shape{1, hidden_size});
auto A = make_shared<Parameter>(f32, Shape{batch, 1});
ParameterVector params = {X, H, WRrz, WRh, Brz, Bh, A};
auto axis_0 = make_shared<Constant>(i64, Shape{}, 0);
auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
auto split_lenghts = make_shared<Constant>(i64, Shape{2}, vector<size_t>{input_size, hidden_size});
auto split_WRrz = make_shared<VariadicSplit>(WRrz, axis_1, split_lenghts);
auto split_W_r_z = make_shared<Split>(split_WRrz->output(0), axis_0, 2);
auto split_R_r_z = make_shared<Split>(split_WRrz->output(1), axis_0, 2);
auto split_WRh = make_shared<VariadicSplit>(WRh, axis_1, split_lenghts);
auto Wzrh =
make_shared<Concat>(OutputVector{split_W_r_z->output(1), split_W_r_z->output(0), split_WRh->output(0)}, 0);
auto Rzrh =
make_shared<Concat>(OutputVector{split_R_r_z->output(1), split_R_r_z->output(0), split_WRh->output(1)}, 0);
auto split_bias_r_z = make_shared<Split>(Brz, axis_1, 2);
auto B = make_shared<Concat>(OutputVector{split_bias_r_z->output(1), split_bias_r_z->output(0), Bh}, 1);
auto squeeze_B = make_shared<Squeeze>(B, axis_0);
auto cell = make_shared<op::internal::AUGRUCell>(X, H, Wzrh, Rzrh, squeeze_B, A, hidden_size);
return make_shared<Model>(OutputVector {cell}, params);
}
} // namespace
struct AUGRUFusionParams {
size_t batch;
size_t hidden_size;
size_t input_size;
};
class AUGRUFusionTest
: public WithParamInterface<AUGRUFusionParams>,
public TransformationTestsF {
};
TEST_P(AUGRUFusionTest, AUGRUCellPattern) {
const auto& p = GetParam();
{
model = gen_model(p.batch, p.hidden_size, p.input_size, false);
manager.register_pass<pass::AUGRUCellFusion>();
}
{
model_ref = gen_reference(p.batch, p.hidden_size, p.input_size);
}
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
class AUGRUFusionTestDyn
: public WithParamInterface<AUGRUFusionParams>, public TransformationTestsF {
};
TEST_P(AUGRUFusionTestDyn, AUGRUCellPatternDynamicShapes) {
const auto& p = GetParam();
{
model = gen_model(p.batch, p.hidden_size, p.input_size, true);
// the transformation won't be applied because we can't determine hidden_size/input_size,
// they are dynamic.
manager.register_pass<pass::AUGRUCellFusion>();
}
}
static const std::vector<AUGRUFusionParams> params = {
AUGRUFusionParams{1, 1, 1},
AUGRUFusionParams{2, 128, 32},
};
INSTANTIATE_TEST_SUITE_P(AUGRUFusionTest, AUGRUFusionTest, ValuesIn(params));
INSTANTIATE_TEST_SUITE_P(AUGRUFusionTestDyn, AUGRUFusionTestDyn, ValuesIn(params));