GNA matmul fix (#9026)

* [GNA] Handle additional case of matmul transpose with POT
* [GNA] Fix MVN decomposition initialization issue

parent 55955f7ae9
commit 36b4dfedae
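For context, the "additional case" in the first commit is a FakeQuantize that POT (Post-Training Optimization Tool) places directly on the MatMul output, before the optional Add, on top of the already-handled FakeQuantize after the Add. A minimal sketch of that topology, mirroring the helpers in the diff below (shapes, element types, and quantization ranges are illustrative, not taken from the PR):

```cpp
#include <memory>
#include <vector>
#include <ngraph/opsets/opset7.hpp>

// Builds MatMul -> FQ (fq1) -> Add -> FQ (fq2); both FQs are optional in the
// transformation, and this PR adds support for the fq1 position.
std::shared_ptr<ngraph::Node> BuildPotStyleMatmul() {
    auto input = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::Shape{4, 1});
    auto weights = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1, 8},
                                                    std::vector<float>(8, 1.0f));
    std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input, weights);

    // Helper: wrap a node in a symmetric 8-bit FakeQuantize (illustrative ranges).
    auto fq = [](const std::shared_ptr<ngraph::Node>& in) -> std::shared_ptr<ngraph::Node> {
        auto lo = ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1});
        auto hi = ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1});
        return std::make_shared<ngraph::opset7::FakeQuantize>(in, lo, hi, lo, hi, 255);
    };

    node = fq(node);  // fq1: quantization between MatMul and Add (the new case)
    auto bias = ngraph::opset7::Constant::create(ngraph::element::f32, node->get_output_shape(0), {1});
    node = std::make_shared<ngraph::opset7::Add>(node, bias);
    node = fq(node);  // fq2: quantization after Add (already handled before)
    return node;
}
```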
@@ -78,7 +78,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
         bool enable_last_reshape,
         bool enable_add,
         bool matmul_on_left_side,
-        bool enable_fq) {
+        bool enable_fq1,
+        bool enable_fq2) {
     auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

     std::vector<size_t> data(ngraph::shape_size(matmul_shape));
@@ -86,16 +87,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
     auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
     std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
     const auto matmul_output_shape = node->get_output_shape(0);
-    if (enable_add) {
-        auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
-        if (matmul_on_left_side) {
-            node = std::make_shared<ngraph::opset7::Add>(add_const, node);
-        } else {
-            node = std::make_shared<ngraph::opset7::Add>(node, add_const);
-        }
-    }

-    if (enable_fq) {
+    if (enable_fq1) {
         node = std::make_shared<ngraph::opset7::FakeQuantize>(
             node,
             ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
@@ -105,6 +98,25 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
             255);
     }

+    if (enable_add) {
+        auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
+        if (matmul_on_left_side) {
+            node = std::make_shared<ngraph::opset7::Add>(add_const, node);
+        } else {
+            node = std::make_shared<ngraph::opset7::Add>(node, add_const);
+        }
+
+        if (enable_fq2) {
+            node = std::make_shared<ngraph::opset7::FakeQuantize>(
+                node,
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
+                255);
+        }
+    }
+
     if (create_reshape_before_transpose) {
         auto matmul_output_shape = node->get_output_shape(0);
         std::swap(matmul_output_shape[0], matmul_output_shape[1]);
@@ -133,7 +145,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
         bool enable_last_reshape,
         bool enable_add,
         bool matmul_on_left_side,
-        bool enable_fq) {
+        bool enable_fq1,
+        bool enable_fq2) {
     auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);

     std::vector<size_t> data(ngraph::shape_size(matmul_shape));
@@ -141,16 +154,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
     auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
     std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
     const auto matmul_output_shape = node->get_output_shape(0);
-    if (enable_add) {
-        auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
-        if (matmul_on_left_side) {
-            node = std::make_shared<ngraph::opset7::Add>(add_const, node);
-        } else {
-            node = std::make_shared<ngraph::opset7::Add>(node, add_const);
-        }
-    }

-    if (enable_fq) {
+    if (enable_fq1) {
         node = std::make_shared<ngraph::opset7::FakeQuantize>(
             node,
             ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
@@ -160,6 +165,25 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
             255);
     }

+    if (enable_add) {
+        auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
+        if (matmul_on_left_side) {
+            node = std::make_shared<ngraph::opset7::Add>(add_const, node);
+        } else {
+            node = std::make_shared<ngraph::opset7::Add>(node, add_const);
+        }
+
+        if (enable_fq2) {
+            node = std::make_shared<ngraph::opset7::FakeQuantize>(
+                node,
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
+                ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
+                255);
+        }
+    }
+
     std::shared_ptr<ngraph::Node> reshape;
     auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
     if (create_reshape_instead_of_transpose) {
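To read the positional arguments in the test calls below against the reworked signatures, a usage sketch (the first four parameter names are inferred from how they are used in the helper bodies; only the trailing five appear in this diff):

```cpp
auto func = handle_transpose_after_matmul::CreateMatmulFunction(
    {4, 1},   // input_shape
    {1, 8},   // matmul_shape
    {2, 16},  // reshape_shape
    false,    // create_reshape_instead_of_transpose
    true,     // enable_last_reshape
    true,     // enable_add
    false,    // matmul_on_left_side
    true,     // enable_fq1: FakeQuantize between MatMul and Add
    false);   // enable_fq2: FakeQuantize after Add
```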
@@ -235,19 +259,21 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
 }

 TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
-    for (auto enable_add : { true, false}) {
-        for (auto matmul_on_left_side : { true, false}) {
-            for (auto enable_fq : { true, false}) {
-                RunTest(
-                    handle_transpose_after_matmul::CreateMatmulFunction(
-                        {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
-                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
-                        {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
-                RunTest(
-                    handle_transpose_after_matmul::CreateMatmulFunction(
-                        {1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq),
-                    handle_transpose_after_matmul::CreateMatmulFunction(
-                        {1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq));
+    for (auto enable_add : { true, false }) {
+        for (auto matmul_on_left_side : { true, false }) {
+            for (auto enable_fq1 : { true, false }) {
+                for (auto enable_fq2 : { true, false }) {
+                    RunTest(
+                        handle_transpose_after_matmul::CreateMatmulFunction(
+                            {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
+                        handle_transpose_after_matmul::CreateMatmulTransposeFunction(
+                            {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
+                    RunTest(
+                        handle_transpose_after_matmul::CreateMatmulFunction(
+                            {1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
+                        handle_transpose_after_matmul::CreateMatmulFunction(
+                            {1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
+                }
             }
         }
     }
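RunTest itself is outside this diff; a plausible sketch of what each pair of functions above is used for, assuming OpenVINO's common test utilities and gtest (the header paths, the GNAPluginNS namespace, and the comparator calls are assumptions, not taken from the PR):

```cpp
#include <gtest/gtest.h>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"               // FunctionsComparator (assumed path)
#include "transformations/handle_transposes_around_matmul.hpp"   // assumed header for the pass

// Assumed sketch: apply the pass to `func`, then require the result to be
// isomorphic to `reference_func`.
static void RunTest(const std::shared_ptr<ngraph::Function>& func,
                    const std::shared_ptr<ngraph::Function>& reference_func) {
    ngraph::pass::Manager manager;
    manager.register_pass<GNAPluginNS::HandleTransposeAfterMatMul>();
    manager.run_passes(func);
    const auto result = FunctionsComparator::with_default().compare(func, reference_func);
    ASSERT_TRUE(result.valid) << result.message;
}
```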
@@ -256,12 +282,14 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
 TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
     for (auto enable_add : { true, false }) {
         for (auto matmul_on_left_side : { true, false }) {
-            for (auto enable_fq : { true, false }) {
-                RunTest(
-                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
-                        {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
-                    handle_transpose_after_matmul::CreateMatmulFunction(
-                        {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
+            for (auto enable_fq1 : { true, false }) {
+                for (auto enable_fq2 : { true, false }) {
+                    RunTest(
+                        handle_transpose_after_matmul::CreateMatmulTransposeFunction(
+                            {4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
+                        handle_transpose_after_matmul::CreateMatmulFunction(
+                            {4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
+                }
             }
         }
     }
@@ -270,12 +298,14 @@ TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
 TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
     for (auto enable_add : { true, false }) {
         for (auto matmul_on_left_side : { true, false }) {
-            for (auto enable_fq : { true, false }) {
-                RunTest(
-                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
-                        {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq),
-                    handle_transpose_after_matmul::CreateMatmulTransposeFunction(
-                        {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq));
+            for (auto enable_fq1 : { true, false }) {
+                for (auto enable_fq2 : { true, false }) {
+                    RunTest(
+                        handle_transpose_after_matmul::CreateMatmulTransposeFunction(
+                            {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
+                        handle_transpose_after_matmul::CreateMatmulTransposeFunction(
+                            {4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
+                }
             }
         }
     }
@@ -285,12 +315,14 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
     for (auto enable_last_reshape : { true, false }) {
         for (auto enable_add : { true, false }) {
             for (auto matmul_on_left_side : { true, false }) {
-                for (auto enable_fq : { true, false }) {
-                    RunTest(
-                        handle_transpose_after_matmul::CreateMatmulFunction(
-                            {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq),
-                        handle_transpose_after_matmul::CreateMatmulFunction(
-                            {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq));
+                for (auto enable_fq1 : { true, false }) {
+                    for (auto enable_fq2 : { true, false }) {
+                        RunTest(
+                            handle_transpose_after_matmul::CreateMatmulFunction(
+                                {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
+                            handle_transpose_after_matmul::CreateMatmulFunction(
+                                {4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
+                    }
                 }
             }
        }
@@ -75,6 +75,8 @@ static bool GetVerifiedMVNData(const std::shared_ptr<opset8::MVN> mvn, MVNData&
         mvn_data.C = mvn_shape[0];
         mvn_data.H = mvn_shape[1];
         mvn_data.W = mvn_shape[2];
+    } else {
+        THROW_GNA_EXCEPTION << "Unsupported MVN shape size: " << mvn_shape_size;
     }

     // Check if average must be split
@@ -224,7 +226,7 @@ static void Decompose(const std::shared_ptr<opset8::MVN> mvn, const MVNData& mvn

 static bool Convert(std::shared_ptr<Node> mvn_node) {
     const auto mvn = std::dynamic_pointer_cast<opset8::MVN>(mvn_node);
-    MVNData mvn_data;
+    MVNData mvn_data = {};

     if (!GetVerifiedMVNData(mvn, mvn_data))
         return false;
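The `MVNData mvn_data = {};` change is plain C++ value-initialization: with `MVNData mvn_data;` the struct's scalar members stay indeterminate if `GetVerifiedMVNData` takes an early-out path before assigning them, whereas `= {}` zeroes every member up front. A standalone illustration (the members here are an illustrative subset, not the plugin's full MVNData):

```cpp
#include <cstddef>
#include <cstdio>

struct MVNData {            // illustrative subset of the plugin's struct
    std::size_t C, H, W;    // scalar members: indeterminate under default-initialization
    bool normalize_variance;
};

int main() {
    MVNData a;       // default-initialized: reading a.C here would be undefined behavior
    MVNData b = {};  // value-initialized: every member is zero
    std::printf("%zu %zu %zu %d\n", b.C, b.H, b.W, static_cast<int>(b.normalize_variance));
    (void)a;
    return 0;
}
```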
@@ -144,12 +144,15 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
     auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({}, [](const ngraph::Output<ngraph::Node>& node) {
         auto out_shape = node.get_node_shared_ptr()->get_output_shape(0);
         return std::count_if(out_shape.begin(), out_shape.end(), [](size_t n) { return n > 1; }) > 1; });
-    auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
-    auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
-    auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
-    auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
+    auto fq1 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul, ngraph::pattern::any_input(),
         ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
-    auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
+    auto add_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, fq1});
+    auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, ngraph::pattern::any_input()});
+    auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), add_input});
+    auto fq2_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
+    auto fq2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq2_input, ngraph::pattern::any_input(),
+        ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
+    auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq2_input, fq2});
     auto act = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
         ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
         ngraph::opset8::Sign, ngraph::opset8::Clamp>({act_input});
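Net effect of the pattern rewrite: a FakeQuantize can now be matched in two positions, so the recognized chain is MatMul -> [fq1] -> [Add] -> [fq2] -> [activation] -> Reshape/Transpose, with each bracketed stage optional; previously a single optional FQ was matched only after the whole MatMul/Add group. The callback's fallback lookup below is updated accordingly (fq -> fq2).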
@@ -169,7 +172,7 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
         if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
         auto iter = pattern_map.find(act);
         if (iter == pattern_map.end() &&
-            (iter = pattern_map.find(fq)) == pattern_map.end() &&
+            (iter = pattern_map.find(fq2)) == pattern_map.end() &&
             (iter = pattern_map.find(add_left)) == pattern_map.end() &&
             (iter = pattern_map.find(add_right)) == pattern_map.end() &&
             (iter = pattern_map.find(matmul)) == pattern_map.end()) {