LSTMCellFusion - support transposed/not transposed weights (#21780)
* LSTMCellFusion - support transposed/not transposed weights
* add comment describing fused subgraph
This commit is contained in:
parent
bc121c06c7
commit
2fcaa88af5
|
@ -17,10 +17,97 @@
|
|||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/squeeze.hpp"
|
||||
#include "openvino/op/tanh.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/variadic_split.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
#include "validation_util.hpp"
|
||||
|
||||
/*
|
||||
The following graph is fused to LSTMCell
|
||||
|
||||
+-----+ +-----+
|
||||
| X | | H |
|
||||
+--+--+ +--+--+
|
||||
| |
|
||||
+---+ +---+
|
||||
| |
|
||||
v v
|
||||
+--+--+--+ +------+
|
||||
| Concat | | WR |
|
||||
+----+---+ +---+--+
|
||||
| |
|
||||
| +--------+
|
||||
| |
|
||||
v v
|
||||
+--+--+--+ +------+
|
||||
| MatMul | | Bias |
|
||||
+----+---+ +--+---+
|
||||
| |
|
||||
| +------+
|
||||
| |
|
||||
v v
|
||||
+--+---+--+
|
||||
| Add |
|
||||
+----+----+
|
||||
|
|
||||
|
|
||||
v
|
||||
+------+-------+
|
||||
| Split |
|
||||
+--+--+--+--+--+
|
||||
| | | |
|
||||
+--------------+ | | +------------------------------+
|
||||
| | | |
|
||||
v | +------+ +-------+ v
|
||||
+------+-----+ +-----+ | | const | +------+-----+
|
||||
| Activation | | | +---+---+ | Activation |
|
||||
| (i_t) | | | | | (o_t) |
|
||||
+------+-----+ | | +---+ +------+-----+
|
||||
| v | | |
|
||||
| +------+-----+ v v |
|
||||
| | Activation | +-+---+-+ |
|
||||
| | (c_t) | | Add | |
|
||||
| +------+-----+ +---+---+ |
|
||||
| | | |
|
||||
| | v |
|
||||
+---+ +---+ +------+-----+ |
|
||||
| | | Activation | +-----+ |
|
||||
v v | (f_t) | | C | |
|
||||
+--+---+---+ +------------+ +-----+ |
|
||||
| Multiply | | | |
|
||||
+----+-----+ | +--------+ |
|
||||
| | | |
|
||||
| v v |
|
||||
| +---+---+--+ |
|
||||
| | Multiply | |
|
||||
| +----+-----+ |
|
||||
| | |
|
||||
| | |
|
||||
+---------+ +--------+ |
|
||||
| | |
|
||||
v v |
|
||||
+-+-----+-+ |
|
||||
| Add | |
|
||||
| (C out) | |
|
||||
+----+----+ |
|
||||
| |
|
||||
v |
|
||||
+-----+------+ |
|
||||
| Activation | |
|
||||
+-----+------+ |
|
||||
| |
|
||||
| |
|
||||
+----------+ +-------------------+
|
||||
| |
|
||||
v v
|
||||
+--+----+--+
|
||||
| Multiply |
|
||||
| (H out) |
|
||||
+----------+
|
||||
|
||||
*/
|
||||
|
||||
static std::string get_activation_name(const std::shared_ptr<ov::Node>& node) {
|
||||
std::string name = node->get_type_name();
|
||||
name[0] = std::tolower(name[0]);
|
||||
|
@ -37,9 +124,7 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
|
|||
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output);
|
||||
});
|
||||
auto matmul_label = pattern::wrap_type<op::v0::MatMul>({concat_label, weights_label});
|
||||
auto bias_label = pattern::any_input([](const Output<Node>& output) {
|
||||
return pattern::has_static_shape()(output) && pattern::rank_equals(2)(output);
|
||||
});
|
||||
auto bias_label = pattern::any_input(pattern::has_static_shape());
|
||||
auto bias_add_label = pattern::wrap_type<op::v1::Add>({matmul_label, bias_label});
|
||||
auto axis_label = pattern::wrap_type<op::v0::Constant>();
|
||||
auto split_label = pattern::wrap_type<op::v1::Split>({bias_add_label, axis_label});
|
||||
|
@ -62,51 +147,67 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
|
|||
const auto& X = pattern_map.at(x_label);
|
||||
const auto& H = pattern_map.at(h_label);
|
||||
const auto& C = pattern_map.at(c_label);
|
||||
const auto& WR = pattern_map.at(weights_label);
|
||||
const auto& B = pattern_map.at(bias_label);
|
||||
auto WR = pattern_map.at(weights_label);
|
||||
auto B = pattern_map.at(bias_label);
|
||||
const auto& ft_additional_bias = pattern_map.at(ft_additional_bias_label);
|
||||
auto Ho = pattern_map.at(Ho_label);
|
||||
auto Co = pattern_map.at(Co_label);
|
||||
const auto matmul = ov::as_type_ptr<op::v0::MatMul>(pattern_map.at(matmul_label).get_node_shared_ptr());
|
||||
if (!matmul)
|
||||
return false;
|
||||
if (matmul->get_transpose_a())
|
||||
return false;
|
||||
|
||||
bool weights_transposed = matmul->get_transpose_b();
|
||||
const auto& WR_shape = WR.get_shape();
|
||||
const auto& B_shape = B.get_shape();
|
||||
const auto& ft_additional_bias_shape = ft_additional_bias.get_shape();
|
||||
|
||||
if (WR_shape[0] % 4 != 0)
|
||||
size_t input_size_plus_hidden_size = weights_transposed ? WR_shape[1] : WR_shape[0];
|
||||
size_t hidden_size_times_4 = weights_transposed ? WR_shape[0] : WR_shape[1];
|
||||
if (hidden_size_times_4 % 4 != 0)
|
||||
return false;
|
||||
if (WR_shape[0] != B_shape[1])
|
||||
return false;
|
||||
if (B_shape[0] != 1)
|
||||
if (B_shape.size() == 2) {
|
||||
if (hidden_size_times_4 != B_shape[1])
|
||||
return false;
|
||||
if (B_shape[0] != 1)
|
||||
return false;
|
||||
} else if (B_shape.size() == 1) {
|
||||
if (hidden_size_times_4 != B_shape[0])
|
||||
return false;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
if (shape_size(ft_additional_bias_shape) != 1)
|
||||
return false;
|
||||
|
||||
size_t hidden_size = WR_shape[0] / 4;
|
||||
size_t hidden_size = hidden_size_times_4 / 4;
|
||||
|
||||
if (WR_shape[1] <= hidden_size)
|
||||
if (input_size_plus_hidden_size <= hidden_size)
|
||||
return false;
|
||||
|
||||
size_t input_size = WR_shape[1] - hidden_size;
|
||||
size_t input_size = input_size_plus_hidden_size - hidden_size;
|
||||
|
||||
const auto& X_shape = X.get_partial_shape();
|
||||
const auto& H_shape = H.get_partial_shape();
|
||||
const auto& C_shape = C.get_partial_shape();
|
||||
|
||||
if (!H_shape[0].compatible(X_shape[0]))
|
||||
if (!H_shape[0].compatible(X_shape[0])) // batch size
|
||||
return false;
|
||||
|
||||
if (!C_shape[0].compatible(X_shape[0]))
|
||||
if (!C_shape[0].compatible(X_shape[0])) // batch size
|
||||
return false;
|
||||
|
||||
if (!X_shape[1].compatible(input_size))
|
||||
return false;
|
||||
|
||||
if (!H_shape[1].compatible(hidden_size))
|
||||
return false;
|
||||
|
||||
if (!C_shape[1].compatible(hidden_size))
|
||||
return false;
|
||||
|
||||
const auto split_axis = ov::as_type_ptr<op::v0::Constant>(pattern_map.at(axis_label).get_node_shared_ptr());
|
||||
int64_t split_axis_value = split_axis->cast_vector<int64_t>()[0];
|
||||
if (split_axis_value != 1 && split_axis_value != -1)
|
||||
return false;
|
||||
|
||||
NodeVector split_consumers{pattern_map.at(it_label).get_node_shared_ptr(),
|
||||
pattern_map.at(ct_label).get_node_shared_ptr(),
|
||||
pattern_map.at(ot_label).get_node_shared_ptr(),
|
||||
|
@ -142,14 +243,46 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
|
|||
auto Co_activation = pattern_map.at(Co_activation_label).get_node_shared_ptr();
|
||||
std::string h_activation_name = get_activation_name(Co_activation);
|
||||
|
||||
auto zero = op::v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto WR_split = std::make_shared<op::v1::Split>(WR, zero /* axis */, 4);
|
||||
if (!weights_transposed) {
|
||||
WR = std::make_shared<op::v1::Transpose>(WR, op::v0::Constant::create(element::i32, Shape{0}, {}));
|
||||
}
|
||||
// Split WR into W and R and convert them to the layout that OpenVINO supports
|
||||
//
|
||||
// WR layout (icfo):
|
||||
//
|
||||
// W R
|
||||
// +-------+---+
|
||||
// i | | |
|
||||
// +-------+---+
|
||||
// c | | |
|
||||
// +-------+---+
|
||||
// f | | |
|
||||
// +-------+---+
|
||||
// o | | |
|
||||
// +-------+---+
|
||||
//
|
||||
//
|
||||
// W and R layouts that are supported by OpenVINO (fico):
|
||||
//
|
||||
// W R
|
||||
// +-------+ +---+
|
||||
// f | | f | |
|
||||
// +-------+ +---+
|
||||
// i | | i | |
|
||||
// +-------+ +---+
|
||||
// c | | c | |
|
||||
// +-------+ +---+
|
||||
// o | | o | |
|
||||
// +-------+ +---+
|
||||
//
|
||||
auto zero_axis = op::v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto WR_split = std::make_shared<op::v1::Split>(WR, zero_axis, 4);
|
||||
auto WR_fico = std::make_shared<op::v0::Concat>(
|
||||
OutputVector{WR_split->output(2), WR_split->output(0), WR_split->output(1), WR_split->output(3)},
|
||||
0);
|
||||
auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
|
||||
auto vsplit_axis = op::v0::Constant::create(element::i32, Shape{}, {1});
|
||||
auto split_lengths = op::v0::Constant::create(element::i32, Shape{2}, {input_size, hidden_size});
|
||||
auto vsplit = std::make_shared<op::v1::VariadicSplit>(WR_fico, one /* axis */, split_lengths);
|
||||
auto vsplit = std::make_shared<op::v1::VariadicSplit>(WR_fico, vsplit_axis, split_lengths);
|
||||
Output<Node> W = vsplit->output(0);
|
||||
if (auto constant = ov::util::constantfold_subgraph(W))
|
||||
W = constant;
|
||||
|
@ -157,7 +290,11 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
|
|||
if (auto constant = ov::util::constantfold_subgraph(R))
|
||||
R = constant;
|
||||
|
||||
auto B_split = std::make_shared<op::v1::Split>(std::make_shared<op::v0::Squeeze>(B, zero), zero /* axis */, 4);
|
||||
if (B_shape.size() > 1)
|
||||
B = std::make_shared<op::v0::Squeeze>(B, zero_axis);
|
||||
|
||||
// Convert B layout from icfo to fico
|
||||
auto B_split = std::make_shared<op::v1::Split>(B, zero_axis, 4);
|
||||
auto B_f =
|
||||
std::make_shared<op::v1::Add>(B_split->output(2), std::make_shared<op::v0::Squeeze>(ft_additional_bias));
|
||||
|
||||
|
@ -176,13 +313,18 @@ ov::pass::LSTMCellFusion::LSTMCellFusion() {
|
|||
B_fico,
|
||||
hidden_size,
|
||||
std::vector<std::string>{f_activation_name, g_activation_name, h_activation_name});
|
||||
|
||||
if (transformation_callback(lstm_cell)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
lstm_cell->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
|
||||
copy_runtime_info(
|
||||
{
|
||||
pattern_map.at(concat_label).get_node_shared_ptr(),
|
||||
WR.get_node_shared_ptr(),
|
||||
pattern_map.at(matmul_label).get_node_shared_ptr(),
|
||||
matmul,
|
||||
B.get_node_shared_ptr(),
|
||||
pattern_map.at(bias_add_label).get_node_shared_ptr(),
|
||||
pattern_map.at(split_label).get_node_shared_ptr(),
|
||||
|
|
|
@ -16,30 +16,41 @@
|
|||
#include "openvino/op/sigmoid.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
#include "openvino/op/tanh.hpp"
|
||||
#include "openvino/pass/constant_folding.hpp"
|
||||
|
||||
using namespace ov;
|
||||
|
||||
TEST_F(TransformationTestsF, LSTMCellFusion) {
|
||||
using LSTMCellFusionParam = std::tuple<bool, // true if second input to matmul is transposed
|
||||
int, // rank of bias (B)
|
||||
int>; // split axis
|
||||
|
||||
class LSTMCellFusionTestSuite : public testing::WithParamInterface<LSTMCellFusionParam>, public TransformationTestsF {};
|
||||
|
||||
TEST_P(LSTMCellFusionTestSuite, SubgraphFusedToLSTMCell) {
|
||||
const auto& param = GetParam();
|
||||
bool weights_transposed = std::get<0>(param);
|
||||
int B_rank = std::get<1>(param);
|
||||
int split_axis_value = std::get<2>(param);
|
||||
size_t input_size = 3;
|
||||
size_t hidden_size = 2;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, input_size});
|
||||
auto H = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
|
||||
auto C = std::make_shared<op::v0::Parameter>(element::f32, Shape{1, hidden_size});
|
||||
auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
|
||||
Shape WR_shape{4 * hidden_size, input_size + hidden_size};
|
||||
Shape WR_shape = weights_transposed ? Shape{4 * hidden_size, input_size + hidden_size}
|
||||
: Shape{input_size + hidden_size, 4 * hidden_size};
|
||||
std::vector<float> WR_values(shape_size(WR_shape));
|
||||
std::iota(WR_values.begin(), WR_values.end(), 0.0f);
|
||||
auto WR = op::v0::Constant::create(element::f32, WR_shape, WR_values);
|
||||
auto matmul = std::make_shared<op::v0::MatMul>(concat, WR, false, true);
|
||||
Shape B_shape{1, 4 * hidden_size};
|
||||
auto matmul = std::make_shared<op::v0::MatMul>(concat, WR, false, weights_transposed);
|
||||
Shape B_shape = B_rank == 2 ? Shape{1, 4 * hidden_size} : Shape{4 * hidden_size};
|
||||
std::vector<float> B_values(shape_size(B_shape));
|
||||
std::iota(B_values.begin(), B_values.end(), 0.0f);
|
||||
auto B = op::v0::Constant::create(element::f32, B_shape, B_values);
|
||||
auto biasadd = std::make_shared<op::v1::Add>(matmul, B);
|
||||
auto one = op::v0::Constant::create(element::i32, Shape{}, {1});
|
||||
auto split = std::make_shared<op::v1::Split>(biasadd, one /* axis */, 4 /* num splits */);
|
||||
auto split_axis = op::v0::Constant::create(element::i32, Shape{}, {split_axis_value});
|
||||
auto split = std::make_shared<op::v1::Split>(biasadd, split_axis, 4 /* num splits */);
|
||||
auto it = std::make_shared<op::v0::Sigmoid>(split->output(0));
|
||||
auto ct = std::make_shared<op::v0::Tanh>(split->output(1));
|
||||
auto ft = std::make_shared<op::v0::Sigmoid>(
|
||||
|
@ -62,28 +73,15 @@ TEST_F(TransformationTestsF, LSTMCellFusion) {
|
|||
auto concat = std::make_shared<op::v0::Concat>(OutputVector{X, H}, 1);
|
||||
Shape W_shape{4 * hidden_size, input_size};
|
||||
Shape R_shape{4 * hidden_size, hidden_size};
|
||||
std::vector<float> W_values{
|
||||
20, 21, 22, 25, 26, 27, 0, 1, 2, 5, 6, 7, 10, 11, 12, 15, 16, 17, 30, 31, 32, 35, 36, 37,
|
||||
};
|
||||
std::vector<float> W_values = weights_transposed
|
||||
? std::vector<float>{20, 21, 22, 25, 26, 27, 0, 1, 2, 5, 6, 7,
|
||||
10, 11, 12, 15, 16, 17, 30, 31, 32, 35, 36, 37}
|
||||
: std::vector<float>{4, 12, 20, 5, 13, 21, 0, 8, 16, 1, 9, 17,
|
||||
2, 10, 18, 3, 11, 19, 6, 14, 22, 7, 15, 23};
|
||||
auto W = op::v0::Constant::create(element::f32, W_shape, W_values);
|
||||
std::vector<float> R_values{
|
||||
23,
|
||||
24,
|
||||
28,
|
||||
29,
|
||||
3,
|
||||
4,
|
||||
8,
|
||||
9,
|
||||
13,
|
||||
14,
|
||||
18,
|
||||
19,
|
||||
33,
|
||||
34,
|
||||
38,
|
||||
39,
|
||||
};
|
||||
std::vector<float> R_values =
|
||||
weights_transposed ? std::vector<float>{23, 24, 28, 29, 3, 4, 8, 9, 13, 14, 18, 19, 33, 34, 38, 39}
|
||||
: std::vector<float>{28, 36, 29, 37, 24, 32, 25, 33, 26, 34, 27, 35, 30, 38, 31, 39};
|
||||
auto R = op::v0::Constant::create(element::f32, R_shape, R_values);
|
||||
Shape B_shape{4 * hidden_size};
|
||||
std::vector<float> B_values{5, 6, 0, 1, 2, 3, 6, 7};
|
||||
|
@ -106,3 +104,7 @@ TEST_F(TransformationTestsF, LSTMCellFusion) {
|
|||
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(LSTMCellFusion,
|
||||
LSTMCellFusionTestSuite,
|
||||
testing::Combine(testing::Values(false, true), testing::Values(1, 2), testing::Values(1, -1)));
|
||||
|
|
Loading…
Reference in New Issue
Block a user