AUGRUCell fusion transformation (#12844)
* AUGRUCell fusion transformation and tests * add missed includes * Apply review comments * Apply review comments * fix build * Enable accuracy tests * change supported weights format from zrh to rzh * update submodule version * resolve review comments * add additional checks to pattern * fix conflict with master
This commit is contained in:
@@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "openvino/pass/graph_rewrite.hpp"
|
||||||
|
#include "transformations_visibility.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API AUGRUCellFusion;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief AUGRUCellFusion transformation replaces a sequence of
|
||||||
|
* operations with AUGRUCell op.
|
||||||
|
*
|
||||||
|
* Supported activations: 1st is Sigmoid, 2nd is Tanh
|
||||||
|
* Clip attribute is not supported.
|
||||||
|
* Linear_before_reset attribute is not supported.
|
||||||
|
* Supported weights format: 'rzh'
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ov::pass::AUGRUCellFusion : public ov::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_RTTI("AUGRUCellFusion", "0");
|
||||||
|
AUGRUCellFusion();
|
||||||
|
};
|
||||||
@@ -0,0 +1,102 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "itt.hpp"
|
||||||
|
#include "ngraph_ops/augru_cell.hpp"
|
||||||
|
#include "openvino/core/rt_info.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace ov::opset9;
|
||||||
|
using namespace ov::element;
|
||||||
|
using namespace ov::pass::pattern;
|
||||||
|
|
||||||
|
/**
 * Registers a matcher rooted at the final Add of the decomposed AUGRU cell
 * subgraph together with a callback that replaces the matched subgraph with
 * op::internal::AUGRUCell.
 *
 * The callback re-packs the matched weights and biases: the fused gate tensors
 * are split into their two parts, swapped, and concatenated with the third
 * (candidate) part, producing the Wzrh / Rzrh / B inputs of AUGRUCell.
 */
ov::pass::AUGRUCellFusion::AUGRUCellFusion() {
    MATCHER_SCOPE(AUGRUCellFusion);

    // input_size and hidden_size are read from dimension 1 of X and H in the
    // callback, so that dimension has to exist (static rank >= 2) and be
    // static; otherwise the sizes can't be determined and the match is rejected.
    const auto is_dim_1_static = [](const Output<Node>& output) -> bool {
        const auto& p_shape = output.get_partial_shape();
        const auto shape_rank = p_shape.rank();
        return shape_rank.is_static() && shape_rank.get_length() >= 2 && p_shape[1].is_static();
    };

    // First branch: Concat -> MatMul -> Add -> Sigmoid -> Split.
    auto concat_1 = wrap_type<Concat>({any_input(is_dim_1_static), any_input(is_dim_1_static)});
    auto matmul_1 = wrap_type<MatMul>({concat_1, any_input(is_dim_1_static)});
    auto add_1 = wrap_type<Add>({matmul_1, any_input()});
    // only Sigmoid is supported in the current version of AUGRUCell
    auto sigmoid = wrap_type<Sigmoid>({add_1});
    auto split = wrap_type<Split>({sigmoid, any_input()});
    auto multiply = wrap_type<Multiply>({split, any_input()});

    // Second branch: Concat -> MatMul -> Add -> Tanh.
    auto concat_2 = wrap_type<Concat>({any_input(), multiply});
    auto matmul_2 = wrap_type<MatMul>({concat_2, any_input(is_dim_1_static)});
    auto add_2 = wrap_type<Add>({matmul_2, any_input()});
    // only Tanh is supported in the current version of AUGRUCell
    auto tanh = wrap_type<Tanh>({add_2});

    // Output combination; the second input of subtract_1 is later used as input A.
    auto subtract_1 = wrap_type<Subtract>({any_input(), any_input()});
    auto multiply_2 = wrap_type<Multiply>({subtract_1, split});
    auto subtract_2 = wrap_type<Subtract>({any_input(), multiply_2});
    auto multiply_3 = wrap_type<Multiply>({subtract_2, tanh});

    auto multiply_4 = wrap_type<Multiply>({multiply_2, any_input()});
    auto add_3 = wrap_type<Add>({multiply_4, multiply_3});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        NodeRegistry rg;
        const auto& pattern_map = m.get_pattern_map();

        const auto concat = pattern_map.at(concat_1);
        const auto X = concat->input_value(0);
        const auto H = concat->input_value(1);

        // Dimension 1 of X and H is static here: the pattern predicate checked it.
        const auto input_size = X.get_partial_shape()[1].get_length();
        const auto hidden_size = H.get_partial_shape()[1].get_length();

        auto axis_0 = rg.make<Constant>(i64, Shape{}, 0);
        auto axis_1 = rg.make<Constant>(i64, Shape{}, 1);

        const auto A = pattern_map.at(subtract_1)->input_value(1);

        // biases are required
        const auto bias_add_1 = pattern_map.at(add_1);
        auto split_bias_r_z = rg.make<Split>(bias_add_1->input_value(1), axis_1, 2);
        const auto bias_add_2 = pattern_map.at(add_2);

        // Swap the two halves of the first bias and append the second bias.
        auto B = rg.make<Concat>(
            OutputVector{split_bias_r_z->output(1), split_bias_r_z->output(0), bias_add_2->input_value(1)},
            1);

        const auto WRrz = pattern_map.at(matmul_1)->input_value(1);
        const auto WRh = pattern_map.at(matmul_2)->input_value(1);

        // Along axis 1 each fused weight is separated into its W part
        // (input_size columns) and its R part (hidden_size columns).
        auto split_lengths = rg.make<Constant>(i64, Shape{2}, vector<int64_t>{input_size, hidden_size});
        auto split_WRrz = rg.make<VariadicSplit>(WRrz, axis_1, split_lengths);
        auto split_W_r_z = rg.make<Split>(split_WRrz->output(0), axis_0, 2);
        auto split_R_r_z = rg.make<Split>(split_WRrz->output(1), axis_0, 2);
        auto split_WRh = rg.make<VariadicSplit>(WRh, axis_1, split_lengths);

        // Re-pack along axis 0 with the two gate halves swapped, followed by the h part.
        auto Wzrh =
            rg.make<Concat>(OutputVector{split_W_r_z->output(1), split_W_r_z->output(0), split_WRh->output(0)}, 0);
        auto Rzrh =
            rg.make<Concat>(OutputVector{split_R_r_z->output(1), split_R_r_z->output(0), split_WRh->output(1)}, 0);

        auto squeeze_B = rg.make<Squeeze>(B, axis_0);
        auto cell = rg.make<op::internal::AUGRUCell>(X,
                                                     H,
                                                     Wzrh,
                                                     Rzrh,
                                                     squeeze_B,
                                                     A,
                                                     static_cast<size_t>(hidden_size));

        cell->set_friendly_name(m.get_match_root()->get_friendly_name());
        copy_runtime_info(m.get_matched_nodes(), rg.get());
        replace_node(m.get_match_root(), cell);
        return true;
    };

    auto m = make_shared<Matcher>(add_3, matcher_name);
    this->register_matcher(m, callback);
}
|
||||||
@@ -0,0 +1,135 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "ngraph_ops/augru_cell.hpp"
|
||||||
|
#include "openvino/op/parameter.hpp"
|
||||||
|
#include "openvino/opsets/opset9.hpp"
|
||||||
|
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace opset9;
|
||||||
|
using namespace element;
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Builds the decomposed AUGRU cell subgraph that AUGRUCellFusion is expected to match.
// When use_dyn_shapes is true, dimension 1 of X is dynamic instead of input_size.
shared_ptr<Model> gen_model(size_t batch, size_t hidden_size, size_t input_size, bool use_dyn_shapes) {
    const size_t gates_size = 2 * hidden_size;
    const size_t concat_size = input_size + hidden_size;

    // Model inputs.
    auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
    if (use_dyn_shapes) {
        X = make_shared<Parameter>(f32, PartialShape{static_cast<int64_t>(batch), Dimension::dynamic()});
    }
    auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
    auto WRrz = make_shared<Parameter>(f32, Shape{gates_size, concat_size});
    auto Brz = make_shared<Parameter>(f32, Shape{1, gates_size});
    auto WRh = make_shared<Parameter>(f32, Shape{hidden_size, concat_size});
    auto Bh = make_shared<Parameter>(f32, Shape{1, hidden_size});
    auto A = make_shared<Parameter>(f32, Shape{batch, 1});

    // Sigmoid branch: the result is split in two along axis 1.
    auto xh = make_shared<Concat>(OutputVector{X, H}, 1);
    auto gates_matmul = make_shared<MatMul>(xh, WRrz, false, true);
    auto gates_pre_act = make_shared<Add>(gates_matmul, Brz);

    auto gates = make_shared<Sigmoid>(gates_pre_act);
    auto split_axis = make_shared<Constant>(i64, Shape{}, 1);
    auto gates_split = make_shared<Split>(gates, split_axis, 2);

    // Tanh branch.
    auto gated_h = make_shared<Multiply>(gates_split, H);
    auto x_gated_h = make_shared<Concat>(OutputVector{X, gated_h}, 1);
    auto cand_matmul = make_shared<MatMul>(x_gated_h, WRh, false, true);
    auto cand_pre_act = make_shared<Add>(cand_matmul, Bh);
    auto candidate = make_shared<Tanh>(cand_pre_act);

    // Output combination using A and the second output of the split.
    auto one = make_shared<Constant>(f32, Shape{1}, 1);
    auto one_minus_a = make_shared<Subtract>(one, A);
    auto scaled_gate = make_shared<Multiply>(one_minus_a, gates_split->output(1));
    auto one_minus_gate = make_shared<Subtract>(one, scaled_gate);
    auto new_part = make_shared<Multiply>(one_minus_gate, candidate);

    auto kept_part = make_shared<Multiply>(scaled_gate, H);
    auto result = make_shared<Add>(kept_part, new_part);
    return make_shared<Model>(OutputVector{result}, ParameterVector{X, H, WRrz, WRh, Brz, Bh, A});
}
|
||||||
|
|
||||||
|
// Builds the graph expected after AUGRUCellFusion: the weights and biases are
// re-packed (gate halves swapped, h part appended) and fed to a single AUGRUCell.
shared_ptr<Model> gen_reference(size_t batch, size_t hidden_size, size_t input_size) {
    const size_t gates_size = 2 * hidden_size;
    const size_t concat_size = input_size + hidden_size;

    // Model inputs, same set and order as in the original model.
    auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
    auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
    auto WRrz = make_shared<Parameter>(f32, Shape{gates_size, concat_size});
    auto WRh = make_shared<Parameter>(f32, Shape{hidden_size, concat_size});
    auto Brz = make_shared<Parameter>(f32, Shape{1, gates_size});
    auto Bh = make_shared<Parameter>(f32, Shape{1, hidden_size});
    auto A = make_shared<Parameter>(f32, Shape{batch, 1});
    ParameterVector inputs = {X, H, WRrz, WRh, Brz, Bh, A};

    auto axis_0 = make_shared<Constant>(i64, Shape{}, 0);
    auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);

    // Separate W (input_size columns) from R (hidden_size columns) along axis 1,
    // then split the fused gate weights in two along axis 0.
    auto split_lengths = make_shared<Constant>(i64, Shape{2}, vector<size_t>{input_size, hidden_size});
    auto split_WRrz = make_shared<VariadicSplit>(WRrz, axis_1, split_lengths);
    auto split_W_r_z = make_shared<Split>(split_WRrz->output(0), axis_0, 2);
    auto split_R_r_z = make_shared<Split>(split_WRrz->output(1), axis_0, 2);
    auto split_WRh = make_shared<VariadicSplit>(WRh, axis_1, split_lengths);

    const OutputVector w_parts{split_W_r_z->output(1), split_W_r_z->output(0), split_WRh->output(0)};
    auto Wzrh = make_shared<Concat>(w_parts, 0);
    const OutputVector r_parts{split_R_r_z->output(1), split_R_r_z->output(0), split_WRh->output(1)};
    auto Rzrh = make_shared<Concat>(r_parts, 0);

    // Biases get the same swap, then the leading axis is squeezed away.
    auto split_bias_r_z = make_shared<Split>(Brz, axis_1, 2);
    const OutputVector b_parts{split_bias_r_z->output(1), split_bias_r_z->output(0), Bh};
    auto B = make_shared<Concat>(b_parts, 1);

    auto squeeze_B = make_shared<Squeeze>(B, axis_0);
    auto cell = make_shared<op::internal::AUGRUCell>(X, H, Wzrh, Rzrh, squeeze_B, A, hidden_size);
    return make_shared<Model>(OutputVector{cell}, inputs);
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// Sizes used to build one test case; passed to gen_model / gen_reference.
struct AUGRUFusionParams {
    size_t batch;        // dimension 0 of X, H and A
    size_t hidden_size;  // dimension 1 of H
    size_t input_size;   // dimension 1 of X
};
|
||||||
|
|
||||||
|
class AUGRUFusionTest
|
||||||
|
: public WithParamInterface<AUGRUFusionParams>,
|
||||||
|
public TransformationTestsF {
|
||||||
|
};
|
||||||
|
|
||||||
|
// Static shapes: the decomposed subgraph must be fused into the reference AUGRUCell graph.
TEST_P(AUGRUFusionTest, AUGRUCellPattern) {
    const auto& test_case = GetParam();

    // Graph under test and the pass applied to it.
    model = gen_model(test_case.batch, test_case.hidden_size, test_case.input_size, false);
    manager.register_pass<pass::AUGRUCellFusion>();

    // Expected graph after the transformation.
    model_ref = gen_reference(test_case.batch, test_case.hidden_size, test_case.input_size);

    using Cmp = FunctionsComparator::CmpValues;
    comparator.enable(Cmp::ACCURACY);
    comparator.enable(Cmp::CONST_VALUES);
    comparator.enable(Cmp::ATTRIBUTES);
}
|
||||||
|
|
||||||
|
class AUGRUFusionTestDyn
|
||||||
|
: public WithParamInterface<AUGRUFusionParams>, public TransformationTestsF {
|
||||||
|
};
|
||||||
|
|
||||||
|
// Dynamic shapes: dimension 1 of X is dynamic, so input_size can't be determined
// and the transformation is expected not to be applied.
// NOTE(review): no model_ref is set here; presumably the fixture then compares
// against the unmodified model - confirm against TransformationTestsF.
TEST_P(AUGRUFusionTestDyn, AUGRUCellPatternDynamicShapes) {
    const auto& test_case = GetParam();

    model = gen_model(test_case.batch, test_case.hidden_size, test_case.input_size, true);
    manager.register_pass<pass::AUGRUCellFusion>();
}
|
||||||
|
|
||||||
|
static const std::vector<AUGRUFusionParams> params = {
|
||||||
|
AUGRUFusionParams{1, 1, 1},
|
||||||
|
AUGRUFusionParams{2, 128, 32},
|
||||||
|
};
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(AUGRUFusionTest, AUGRUFusionTest, ValuesIn(params));
|
||||||
|
INSTANTIATE_TEST_SUITE_P(AUGRUFusionTestDyn, AUGRUFusionTestDyn, ValuesIn(params));
|
||||||
Reference in New Issue
Block a user