AUGRUCell fusion transformation (#12844)
* AUGRUCell fusion transformation and tests * add missed includes * Apply review comments * Apply review comments * fix build * Enable accuracy tests * change supported weights format from zrh to rzh * update submodule version * resolve review comments * add additional checks to pattern * fix conflict with master
This commit is contained in:
parent
249df503eb
commit
77ba4ab6dd
@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API AUGRUCellFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief AUGRUCellFusion transformation replaces a sequence of
|
||||
* operations with AUGRUCell op.
|
||||
*
|
||||
* Supported activations: 1st is Sigmoid, 2nd is Tanh
|
||||
* Clip attribute is not supported.
|
||||
* Linear_before_reset attribute is not supported.
|
||||
* Supported weights format: 'rzh'
|
||||
*
|
||||
*/
|
||||
|
||||
class ov::pass::AUGRUCellFusion : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("AUGRUCellFusion", "0");
|
||||
AUGRUCellFusion();
|
||||
};
|
@ -0,0 +1,102 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::element;
|
||||
using namespace ov::pass::pattern;
|
||||
|
||||
ov::pass::AUGRUCellFusion::AUGRUCellFusion() {
|
||||
MATCHER_SCOPE(AUGRUCellFusion);
|
||||
|
||||
// we can't determine hidden_size or input_size in this case
|
||||
const auto is_first_dim_static = [](const Output<Node>& output) -> bool {
|
||||
const auto& p_shape = output.get_partial_shape();
|
||||
return !(p_shape.rank().is_dynamic() || p_shape[1].is_dynamic());
|
||||
};
|
||||
|
||||
auto concat_1 = wrap_type<Concat>({any_input(is_first_dim_static), any_input(is_first_dim_static)});
|
||||
auto matmul_1 = wrap_type<MatMul>({concat_1, any_input(is_first_dim_static)});
|
||||
auto add_1 = wrap_type<Add>({matmul_1, any_input()});
|
||||
// only Sigmoid is supported in the current version of AUGRUCell
|
||||
auto sigmoid = wrap_type<Sigmoid>({add_1});
|
||||
auto split = wrap_type<Split>({sigmoid, any_input()});
|
||||
auto multiply = wrap_type<Multiply>({split, any_input()});
|
||||
|
||||
auto concat_2 = wrap_type<Concat>({any_input(), multiply});
|
||||
auto matmul_2 = wrap_type<MatMul>({concat_2, any_input(is_first_dim_static)});
|
||||
auto add_2 = wrap_type<Add>({matmul_2, any_input()});
|
||||
// only Tanh is supported in the current version of AUGRUCell
|
||||
auto tanh = wrap_type<Tanh>({add_2});
|
||||
|
||||
auto subtract_1 = wrap_type<Subtract>({any_input(), any_input()});
|
||||
auto multiply_2 = wrap_type<Multiply>({subtract_1, split});
|
||||
auto subtract_2 = wrap_type<Subtract>({any_input(), multiply_2});
|
||||
auto multiply_3 = wrap_type<Multiply>({subtract_2, tanh});
|
||||
|
||||
auto multiply_4 = wrap_type<Multiply>({multiply_2, any_input()});
|
||||
auto add_3 = wrap_type<Add>({multiply_4, multiply_3});
|
||||
|
||||
matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
NodeRegistry rg;
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
auto concat = pattern_map.at(concat_1);
|
||||
auto X = concat->input_value(0);
|
||||
auto H = concat->input_value(1);
|
||||
|
||||
auto h_pshape = H.get_partial_shape();
|
||||
auto x_pshape = X.get_partial_shape();
|
||||
|
||||
auto hidden_size = h_pshape[1].get_length();
|
||||
auto input_size = x_pshape[1].get_length();
|
||||
|
||||
auto axis_0 = rg.make<Constant>(i64, Shape{}, 0);
|
||||
auto axis_1 = rg.make<Constant>(i64, Shape{}, 1);
|
||||
|
||||
auto A = pattern_map.at(subtract_1)->input_value(1);
|
||||
// biases are required
|
||||
auto bias_add_1 = pattern_map.at(add_1);
|
||||
auto split_bias_r_z = rg.make<Split>(bias_add_1->input_value(1), axis_1, 2);
|
||||
auto bias_add_2 = pattern_map.at(add_2);
|
||||
|
||||
auto B = rg.make<Concat>(
|
||||
OutputVector{split_bias_r_z->output(1), split_bias_r_z->output(0), bias_add_2->input_value(1)},
|
||||
1);
|
||||
|
||||
auto WRrz = pattern_map.at(matmul_1)->input_value(1);
|
||||
auto WRh = pattern_map.at(matmul_2)->input_value(1);
|
||||
|
||||
auto split_lenghts = rg.make<Constant>(i64, Shape{2}, vector<int64_t>{input_size, hidden_size});
|
||||
auto split_WRrz = rg.make<VariadicSplit>(WRrz, axis_1, split_lenghts);
|
||||
auto split_W_r_z = rg.make<Split>(split_WRrz->output(0), axis_0, 2);
|
||||
auto split_R_r_z = rg.make<Split>(split_WRrz->output(1), axis_0, 2);
|
||||
auto split_WRh = rg.make<VariadicSplit>(WRh, axis_1, split_lenghts);
|
||||
auto Wzrh =
|
||||
rg.make<Concat>(OutputVector{split_W_r_z->output(1), split_W_r_z->output(0), split_WRh->output(0)}, 0);
|
||||
auto Rzrh =
|
||||
rg.make<Concat>(OutputVector{split_R_r_z->output(1), split_R_r_z->output(0), split_WRh->output(1)}, 0);
|
||||
|
||||
auto squeeze_B = rg.make<Squeeze>(B, axis_0);
|
||||
auto cell =
|
||||
rg.make<op::internal::AUGRUCell>(X, H, Wzrh, Rzrh, squeeze_B, A, H.get_partial_shape()[1].get_length());
|
||||
|
||||
cell->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
copy_runtime_info(m.get_matched_nodes(), rg.get());
|
||||
replace_node(m.get_match_root(), cell);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<Matcher>(add_3, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,135 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <queue>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "transformations/common_optimizations/augru_cell_fusion.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace testing;
|
||||
|
||||
using namespace ov;
|
||||
using namespace opset9;
|
||||
using namespace element;
|
||||
|
||||
namespace {
|
||||
shared_ptr<Model> gen_model(size_t batch, size_t hidden_size, size_t input_size, bool use_dyn_shapes) {
|
||||
auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
|
||||
if (use_dyn_shapes) {
|
||||
X = make_shared<Parameter>(f32, PartialShape{static_cast<int64_t>(batch), Dimension::dynamic()});
|
||||
}
|
||||
auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
|
||||
auto WRzr = make_shared<Parameter>(f32, Shape{2 * hidden_size, input_size + hidden_size});
|
||||
auto Bzr = make_shared<Parameter>(f32, Shape{1, 2 * hidden_size});
|
||||
auto WRh = make_shared<Parameter>(f32, Shape{hidden_size, input_size + hidden_size});
|
||||
auto Bh = make_shared<Parameter>(f32, Shape{1, hidden_size});
|
||||
auto A = make_shared<Parameter>(f32, Shape{batch, 1});
|
||||
auto concat_1 = make_shared<Concat>(OutputVector{X, H}, 1);
|
||||
auto matmul_1 = make_shared<MatMul>(concat_1, WRzr, false, true);
|
||||
auto in_to_activation_1 = make_shared<Add>(matmul_1, Bzr);
|
||||
|
||||
auto sigmoid = make_shared<Sigmoid>(in_to_activation_1);
|
||||
auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
|
||||
auto split = make_shared<Split>(sigmoid, axis_1, 2);
|
||||
|
||||
auto multiply_1 = make_shared<Multiply>(split, H);
|
||||
auto concat_2 = make_shared<Concat>(OutputVector{X, multiply_1}, 1);
|
||||
auto matmul_2 = make_shared<MatMul>(concat_2, WRh, false, true);
|
||||
auto in_to_activation_2 = make_shared<Add>(matmul_2, Bh);
|
||||
auto tanh = make_shared<Tanh>(in_to_activation_2);
|
||||
|
||||
auto one = make_shared<Constant>(f32, Shape{1}, 1);
|
||||
auto subtract_1 = make_shared<Subtract>(one, A);
|
||||
auto multiply_2 = make_shared<Multiply>(subtract_1, split->output(1));
|
||||
auto subtract_2 = make_shared<Subtract>(one, multiply_2);
|
||||
auto multiply_3 = make_shared<Multiply>(subtract_2, tanh);
|
||||
|
||||
auto multiply_4 = make_shared<Multiply>(multiply_2, H);
|
||||
auto add = make_shared<Add>(multiply_4, multiply_3);
|
||||
return make_shared<Model>(OutputVector{add}, ParameterVector{X, H, WRzr, WRh, Bzr, Bh, A});
|
||||
}
|
||||
|
||||
shared_ptr<Model> gen_reference(size_t batch, size_t hidden_size, size_t input_size) {
|
||||
auto X = make_shared<Parameter>(f32, Shape{batch, input_size});
|
||||
auto H = make_shared<Parameter>(f32, Shape{batch, hidden_size});
|
||||
auto WRrz = make_shared<Parameter>(f32, Shape{2 * hidden_size, input_size + hidden_size});
|
||||
auto WRh = make_shared<Parameter>(f32, Shape{hidden_size, input_size + hidden_size});
|
||||
auto Brz = make_shared<Parameter>(f32, Shape{1, 2 * hidden_size});
|
||||
auto Bh = make_shared<Parameter>(f32, Shape{1, hidden_size});
|
||||
auto A = make_shared<Parameter>(f32, Shape{batch, 1});
|
||||
ParameterVector params = {X, H, WRrz, WRh, Brz, Bh, A};
|
||||
|
||||
auto axis_0 = make_shared<Constant>(i64, Shape{}, 0);
|
||||
auto axis_1 = make_shared<Constant>(i64, Shape{}, 1);
|
||||
auto split_lenghts = make_shared<Constant>(i64, Shape{2}, vector<size_t>{input_size, hidden_size});
|
||||
auto split_WRrz = make_shared<VariadicSplit>(WRrz, axis_1, split_lenghts);
|
||||
auto split_W_r_z = make_shared<Split>(split_WRrz->output(0), axis_0, 2);
|
||||
auto split_R_r_z = make_shared<Split>(split_WRrz->output(1), axis_0, 2);
|
||||
auto split_WRh = make_shared<VariadicSplit>(WRh, axis_1, split_lenghts);
|
||||
auto Wzrh =
|
||||
make_shared<Concat>(OutputVector{split_W_r_z->output(1), split_W_r_z->output(0), split_WRh->output(0)}, 0);
|
||||
auto Rzrh =
|
||||
make_shared<Concat>(OutputVector{split_R_r_z->output(1), split_R_r_z->output(0), split_WRh->output(1)}, 0);
|
||||
|
||||
auto split_bias_r_z = make_shared<Split>(Brz, axis_1, 2);
|
||||
auto B = make_shared<Concat>(OutputVector{split_bias_r_z->output(1), split_bias_r_z->output(0), Bh}, 1);
|
||||
|
||||
auto squeeze_B = make_shared<Squeeze>(B, axis_0);
|
||||
auto cell = make_shared<op::internal::AUGRUCell>(X, H, Wzrh, Rzrh, squeeze_B, A, hidden_size);
|
||||
return make_shared<Model>(OutputVector {cell}, params);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
struct AUGRUFusionParams {
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
};
|
||||
|
||||
class AUGRUFusionTest
|
||||
: public WithParamInterface<AUGRUFusionParams>,
|
||||
public TransformationTestsF {
|
||||
};
|
||||
|
||||
TEST_P(AUGRUFusionTest, AUGRUCellPattern) {
|
||||
const auto& p = GetParam();
|
||||
{
|
||||
model = gen_model(p.batch, p.hidden_size, p.input_size, false);
|
||||
manager.register_pass<pass::AUGRUCellFusion>();
|
||||
}
|
||||
|
||||
{
|
||||
model_ref = gen_reference(p.batch, p.hidden_size, p.input_size);
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
}
|
||||
|
||||
class AUGRUFusionTestDyn
|
||||
: public WithParamInterface<AUGRUFusionParams>, public TransformationTestsF {
|
||||
};
|
||||
|
||||
TEST_P(AUGRUFusionTestDyn, AUGRUCellPatternDynamicShapes) {
|
||||
const auto& p = GetParam();
|
||||
{
|
||||
model = gen_model(p.batch, p.hidden_size, p.input_size, true);
|
||||
// the transformation won't be applied because we can't determine hidden_size/input_size,
|
||||
// they are dynamic.
|
||||
manager.register_pass<pass::AUGRUCellFusion>();
|
||||
}
|
||||
}
|
||||
|
||||
static const std::vector<AUGRUFusionParams> params = {
|
||||
AUGRUFusionParams{1, 1, 1},
|
||||
AUGRUFusionParams{2, 128, 32},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(AUGRUFusionTest, AUGRUFusionTest, ValuesIn(params));
|
||||
INSTANTIATE_TEST_SUITE_P(AUGRUFusionTestDyn, AUGRUFusionTestDyn, ValuesIn(params));
|
Loading…
Reference in New Issue
Block a user