GNA matmul fix (#9026)

* [GNA] Handle additional case of matmul transpose with POT

* [GNA] Fix MVN decomposition initialization issue
This commit is contained in:
Szymon Irzabek 2021-12-08 10:47:41 +01:00 committed by GitHub
parent 55955f7ae9
commit 36b4dfedae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 95 additions and 58 deletions

View File

@ -78,7 +78,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
bool enable_last_reshape,
bool enable_add,
bool matmul_on_left_side,
bool enable_fq) {
bool enable_fq1,
bool enable_fq2) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
@ -86,16 +87,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
const auto matmul_output_shape = node->get_output_shape(0);
if (enable_add) {
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
if (matmul_on_left_side) {
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
} else {
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
}
}
if (enable_fq) {
if (enable_fq1) {
node = std::make_shared<ngraph::opset7::FakeQuantize>(
node,
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
@ -105,6 +98,25 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
255);
}
if (enable_add) {
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
if (matmul_on_left_side) {
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
} else {
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
}
if (enable_fq2) {
node = std::make_shared<ngraph::opset7::FakeQuantize>(
node,
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
255);
}
}
if (create_reshape_before_transpose) {
auto matmul_output_shape = node->get_output_shape(0);
std::swap(matmul_output_shape[0], matmul_output_shape[1]);
@ -133,7 +145,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
bool enable_last_reshape,
bool enable_add,
bool matmul_on_left_side,
bool enable_fq) {
bool enable_fq1,
bool enable_fq2) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
@ -141,16 +154,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
const auto matmul_output_shape = node->get_output_shape(0);
if (enable_add) {
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
if (matmul_on_left_side) {
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
} else {
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
}
}
if (enable_fq) {
if (enable_fq1) {
node = std::make_shared<ngraph::opset7::FakeQuantize>(
node,
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
@ -160,6 +165,25 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
255);
}
if (enable_add) {
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
if (matmul_on_left_side) {
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
} else {
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
}
if (enable_fq2) {
node = std::make_shared<ngraph::opset7::FakeQuantize>(
node,
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
255);
}
}
std::shared_ptr<ngraph::Node> reshape;
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
if (create_reshape_instead_of_transpose) {
@ -235,19 +259,21 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
}
TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
for (auto enable_add : { true, false}) {
for (auto matmul_on_left_side : { true, false}) {
for (auto enable_fq : { true, false}) {
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulFunction(
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq));
for (auto enable_add : { true, false }) {
for (auto matmul_on_left_side : { true, false }) {
for (auto enable_fq1 : { true, false }) {
for (auto enable_fq2 : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
handle_transpose_after_matmul::CreateMatmulFunction(
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
}
}
}
}
@ -256,12 +282,14 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
for (auto enable_add : { true, false }) {
for (auto matmul_on_left_side : { true, false }) {
for (auto enable_fq : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
for (auto enable_fq1 : { true, false }) {
for (auto enable_fq2 : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
}
}
}
}
@ -270,12 +298,14 @@ TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
for (auto enable_add : { true, false }) {
for (auto matmul_on_left_side : { true, false }) {
for (auto enable_fq : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq));
for (auto enable_fq1 : { true, false }) {
for (auto enable_fq2 : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
}
}
}
}
@ -285,12 +315,14 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
for (auto enable_last_reshape : { true, false }) {
for (auto enable_add : { true, false }) {
for (auto matmul_on_left_side : { true, false }) {
for (auto enable_fq : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq),
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq));
for (auto enable_fq1 : { true, false }) {
for (auto enable_fq2 : { true, false }) {
RunTest(
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
handle_transpose_after_matmul::CreateMatmulFunction(
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
}
}
}
}

View File

@ -75,6 +75,8 @@ static bool GetVerifiedMVNData(const std::shared_ptr<opset8::MVN> mvn, MVNData&
mvn_data.C = mvn_shape[0];
mvn_data.H = mvn_shape[1];
mvn_data.W = mvn_shape[2];
} else {
THROW_GNA_EXCEPTION << "Unsupported MVN shape size: " << mvn_shape_size;
}
// Check if average must be split
@ -224,7 +226,7 @@ static void Decompose(const std::shared_ptr<opset8::MVN> mvn, const MVNData& mvn
static bool Convert(std::shared_ptr<Node> mvn_node) {
const auto mvn = std::dynamic_pointer_cast<opset8::MVN>(mvn_node);
MVNData mvn_data;
MVNData mvn_data = {};
if (!GetVerifiedMVNData(mvn, mvn_data))
return false;

View File

@ -144,12 +144,15 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({}, [](const ngraph::Output<ngraph::Node>& node) {
auto out_shape = node.get_node_shared_ptr()->get_output_shape(0);
return std::count_if(out_shape.begin(), out_shape.end(), [](size_t n) { return n > 1; }) > 1; });
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
auto fq1 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
auto add_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, fq1});
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, ngraph::pattern::any_input()});
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), add_input});
auto fq2_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
auto fq2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq2_input, ngraph::pattern::any_input(),
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq2_input, fq2});
auto act = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
ngraph::opset8::Sign, ngraph::opset8::Clamp>({act_input});
@ -169,7 +172,7 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
auto iter = pattern_map.find(act);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(fq)) == pattern_map.end() &&
(iter = pattern_map.find(fq2)) == pattern_map.end() &&
(iter = pattern_map.find(add_left)) == pattern_map.end() &&
(iter = pattern_map.find(add_right)) == pattern_map.end() &&
(iter = pattern_map.find(matmul)) == pattern_map.end()) {