Gna matmul fix (#9026)
* [GNA] Handle additional case of matmul transpose with POT * [GNA] Fix MVN decomposition initialization issue
This commit is contained in:
parent
55955f7ae9
commit
36b4dfedae
@ -78,7 +78,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
|
|||||||
bool enable_last_reshape,
|
bool enable_last_reshape,
|
||||||
bool enable_add,
|
bool enable_add,
|
||||||
bool matmul_on_left_side,
|
bool matmul_on_left_side,
|
||||||
bool enable_fq) {
|
bool enable_fq1,
|
||||||
|
bool enable_fq2) {
|
||||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
||||||
|
|
||||||
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
||||||
@ -86,16 +87,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
|
|||||||
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
||||||
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
||||||
const auto matmul_output_shape = node->get_output_shape(0);
|
const auto matmul_output_shape = node->get_output_shape(0);
|
||||||
if (enable_add) {
|
|
||||||
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
|
|
||||||
if (matmul_on_left_side) {
|
|
||||||
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
|
|
||||||
} else {
|
|
||||||
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (enable_fq) {
|
if (enable_fq1) {
|
||||||
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
||||||
node,
|
node,
|
||||||
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
@ -105,6 +98,25 @@ std::shared_ptr<ngraph::Function> CreateMatmulTransposeFunction(
|
|||||||
255);
|
255);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enable_add) {
|
||||||
|
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
|
||||||
|
if (matmul_on_left_side) {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
|
||||||
|
} else {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enable_fq2) {
|
||||||
|
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
||||||
|
node,
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
255);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (create_reshape_before_transpose) {
|
if (create_reshape_before_transpose) {
|
||||||
auto matmul_output_shape = node->get_output_shape(0);
|
auto matmul_output_shape = node->get_output_shape(0);
|
||||||
std::swap(matmul_output_shape[0], matmul_output_shape[1]);
|
std::swap(matmul_output_shape[0], matmul_output_shape[1]);
|
||||||
@ -133,7 +145,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
|
|||||||
bool enable_last_reshape,
|
bool enable_last_reshape,
|
||||||
bool enable_add,
|
bool enable_add,
|
||||||
bool matmul_on_left_side,
|
bool matmul_on_left_side,
|
||||||
bool enable_fq) {
|
bool enable_fq1,
|
||||||
|
bool enable_fq2) {
|
||||||
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
|
||||||
|
|
||||||
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
std::vector<size_t> data(ngraph::shape_size(matmul_shape));
|
||||||
@ -141,16 +154,8 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
|
|||||||
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
auto matmul_constant = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_shape, data);
|
||||||
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
std::shared_ptr<ngraph::Node> node = std::make_shared<ngraph::opset7::MatMul>(input_params, matmul_constant);
|
||||||
const auto matmul_output_shape = node->get_output_shape(0);
|
const auto matmul_output_shape = node->get_output_shape(0);
|
||||||
if (enable_add) {
|
|
||||||
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
|
|
||||||
if (matmul_on_left_side) {
|
|
||||||
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
|
|
||||||
} else {
|
|
||||||
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (enable_fq) {
|
if (enable_fq1) {
|
||||||
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
||||||
node,
|
node,
|
||||||
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
@ -160,6 +165,25 @@ std::shared_ptr<ngraph::Function> CreateMatmulFunction(
|
|||||||
255);
|
255);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enable_add) {
|
||||||
|
auto add_const = ngraph::opset7::Constant::create(ngraph::element::i64, matmul_output_shape, {1});
|
||||||
|
if (matmul_on_left_side) {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(add_const, node);
|
||||||
|
} else {
|
||||||
|
node = std::make_shared<ngraph::opset7::Add>(node, add_const);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enable_fq2) {
|
||||||
|
node = std::make_shared<ngraph::opset7::FakeQuantize>(
|
||||||
|
node,
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {-0.1}),
|
||||||
|
ngraph::opset7::Constant::create(ngraph::element::f32, {1}, {0.1}),
|
||||||
|
255);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Node> reshape;
|
std::shared_ptr<ngraph::Node> reshape;
|
||||||
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
|
auto shape_const = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_shape.size()}, reshape_shape);
|
||||||
if (create_reshape_instead_of_transpose) {
|
if (create_reshape_instead_of_transpose) {
|
||||||
@ -235,19 +259,21 @@ TEST(TransformationTests, RemoveTransposeBeforeMatmulTestReshapeInOutEq) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
|
TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
|
||||||
for (auto enable_add : { true, false}) {
|
for (auto enable_add : { true, false }) {
|
||||||
for (auto matmul_on_left_side : { true, false}) {
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
for (auto enable_fq : { true, false}) {
|
for (auto enable_fq1 : { true, false }) {
|
||||||
RunTest(
|
for (auto enable_fq2 : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction(
|
RunTest(
|
||||||
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
|
||||||
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
RunTest(
|
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction(
|
RunTest(
|
||||||
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction(
|
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
|
||||||
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq));
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{1, 256}, {256, 256}, {8, 32}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -256,12 +282,14 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTest) {
|
|||||||
TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
|
TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
|
||||||
for (auto enable_add : { true, false }) {
|
for (auto enable_add : { true, false }) {
|
||||||
for (auto matmul_on_left_side : { true, false }) {
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
for (auto enable_fq : { true, false }) {
|
for (auto enable_fq1 : { true, false }) {
|
||||||
RunTest(
|
for (auto enable_fq2 : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
RunTest(
|
||||||
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction(
|
{4, 1}, {1, 8}, {2, 16}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
|
||||||
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq));
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{4, 1}, {1, 8}, {2, 16}, true, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -270,12 +298,14 @@ TEST(TransformationTests, RemoveTransposeAfterMatmulTest) {
|
|||||||
TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
|
TEST(TransformationTests, RemoveTransposeAfterMatmulTestReshapeInOutEq) {
|
||||||
for (auto enable_add : { true, false }) {
|
for (auto enable_add : { true, false }) {
|
||||||
for (auto matmul_on_left_side : { true, false }) {
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
for (auto enable_fq : { true, false }) {
|
for (auto enable_fq1 : { true, false }) {
|
||||||
RunTest(
|
for (auto enable_fq2 : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
RunTest(
|
||||||
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq),
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
|
||||||
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq));
|
handle_transpose_after_matmul::CreateMatmulTransposeFunction(
|
||||||
|
{4, 1}, {1, 8}, {8, 4}, false, true, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -285,12 +315,14 @@ TEST(TransformationTests, InsertTransposeAfterMatmulTestReshapeInOutEq) {
|
|||||||
for (auto enable_last_reshape : { true, false }) {
|
for (auto enable_last_reshape : { true, false }) {
|
||||||
for (auto enable_add : { true, false }) {
|
for (auto enable_add : { true, false }) {
|
||||||
for (auto matmul_on_left_side : { true, false }) {
|
for (auto matmul_on_left_side : { true, false }) {
|
||||||
for (auto enable_fq : { true, false }) {
|
for (auto enable_fq1 : { true, false }) {
|
||||||
RunTest(
|
for (auto enable_fq2 : { true, false }) {
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction(
|
RunTest(
|
||||||
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq),
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
handle_transpose_after_matmul::CreateMatmulFunction(
|
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq1, enable_fq2),
|
||||||
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq));
|
handle_transpose_after_matmul::CreateMatmulFunction(
|
||||||
|
{4, 1}, {1, 8}, {4, 8}, false, enable_last_reshape, enable_add, matmul_on_left_side, enable_fq1, enable_fq2));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,6 +75,8 @@ static bool GetVerifiedMVNData(const std::shared_ptr<opset8::MVN> mvn, MVNData&
|
|||||||
mvn_data.C = mvn_shape[0];
|
mvn_data.C = mvn_shape[0];
|
||||||
mvn_data.H = mvn_shape[1];
|
mvn_data.H = mvn_shape[1];
|
||||||
mvn_data.W = mvn_shape[2];
|
mvn_data.W = mvn_shape[2];
|
||||||
|
} else {
|
||||||
|
THROW_GNA_EXCEPTION << "Unsupported MVN shape size: " << mvn_shape_size;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if average must be split
|
// Check if average must be split
|
||||||
@ -224,7 +226,7 @@ static void Decompose(const std::shared_ptr<opset8::MVN> mvn, const MVNData& mvn
|
|||||||
|
|
||||||
static bool Convert(std::shared_ptr<Node> mvn_node) {
|
static bool Convert(std::shared_ptr<Node> mvn_node) {
|
||||||
const auto mvn = std::dynamic_pointer_cast<opset8::MVN>(mvn_node);
|
const auto mvn = std::dynamic_pointer_cast<opset8::MVN>(mvn_node);
|
||||||
MVNData mvn_data;
|
MVNData mvn_data = {};
|
||||||
|
|
||||||
if (!GetVerifiedMVNData(mvn, mvn_data))
|
if (!GetVerifiedMVNData(mvn, mvn_data))
|
||||||
return false;
|
return false;
|
||||||
|
@ -144,12 +144,15 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
|
|||||||
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({}, [](const ngraph::Output<ngraph::Node>& node) {
|
auto matmul = ngraph::pattern::wrap_type<ngraph::opset8::MatMul>({}, [](const ngraph::Output<ngraph::Node>& node) {
|
||||||
auto out_shape = node.get_node_shared_ptr()->get_output_shape(0);
|
auto out_shape = node.get_node_shared_ptr()->get_output_shape(0);
|
||||||
return std::count_if(out_shape.begin(), out_shape.end(), [](size_t n) { return n > 1; }) > 1; });
|
return std::count_if(out_shape.begin(), out_shape.end(), [](size_t n) { return n > 1; }) > 1; });
|
||||||
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({matmul, ngraph::pattern::any_input()});
|
auto fq1 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({matmul, ngraph::pattern::any_input(),
|
||||||
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), matmul});
|
|
||||||
auto fq_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
|
|
||||||
auto fq = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq_input, ngraph::pattern::any_input(),
|
|
||||||
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
||||||
auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq_input, fq});
|
auto add_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, fq1});
|
||||||
|
auto add_left = ngraph::pattern::wrap_type<ngraph::opset8::Add>({add_input, ngraph::pattern::any_input()});
|
||||||
|
auto add_right = ngraph::pattern::wrap_type<ngraph::opset8::Add>({ngraph::pattern::any_input(), add_input});
|
||||||
|
auto fq2_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{matmul, add_left, add_right});
|
||||||
|
auto fq2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({fq2_input, ngraph::pattern::any_input(),
|
||||||
|
ngraph::pattern::any_input(), ngraph::pattern::any_input(), ngraph::pattern::any_input()});
|
||||||
|
auto act_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{fq2_input, fq2});
|
||||||
auto act = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
|
auto act = ngraph::pattern::wrap_type<ngraph::opset8::Relu, ngraph::opset8::Sigmoid,
|
||||||
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
|
ngraph::opset8::Tanh, ngraph::opset8::Abs, ngraph::opset8::Log, ngraph::opset8::Exp,
|
||||||
ngraph::opset8::Sign, ngraph::opset8::Clamp>({act_input});
|
ngraph::opset8::Sign, ngraph::opset8::Clamp>({act_input});
|
||||||
@ -169,7 +172,7 @@ HandleTransposeAfterMatMul::HandleTransposeAfterMatMul() {
|
|||||||
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
|
if (!GNALimitations::IsTransposeSupported(reshape_node->get_input_shape(0))) return false;
|
||||||
auto iter = pattern_map.find(act);
|
auto iter = pattern_map.find(act);
|
||||||
if (iter == pattern_map.end() &&
|
if (iter == pattern_map.end() &&
|
||||||
(iter = pattern_map.find(fq)) == pattern_map.end() &&
|
(iter = pattern_map.find(fq2)) == pattern_map.end() &&
|
||||||
(iter = pattern_map.find(add_left)) == pattern_map.end() &&
|
(iter = pattern_map.find(add_left)) == pattern_map.end() &&
|
||||||
(iter = pattern_map.find(add_right)) == pattern_map.end() &&
|
(iter = pattern_map.find(add_right)) == pattern_map.end() &&
|
||||||
(iter = pattern_map.find(matmul)) == pattern_map.end()) {
|
(iter = pattern_map.find(matmul)) == pattern_map.end()) {
|
||||||
|
Loading…
Reference in New Issue
Block a user