[GNA] Fixed failed to inserting reshape around MatMul (#7833)

* Fixed failed to inserting reshape around MatMul.

* remove unnecessary

* added comments
This commit is contained in:
Dmitrii Khurtin
2021-10-07 12:18:22 +03:00
committed by GitHub
parent d11ff04130
commit 6f43b8d0af
2 changed files with 167 additions and 70 deletions

View File

@@ -26,7 +26,7 @@ static bool InsertReshape(
const std::shared_ptr<ngraph::Node>& matmul2,
const std::shared_ptr<ngraph::Node>& add1 = nullptr,
const std::shared_ptr<ngraph::Node>& add2 = nullptr,
const std::shared_ptr<ngraph::Node>& fake_quantize2 = nullptr,
const std::shared_ptr<ngraph::Node>& fake_quantize = nullptr,
const std::shared_ptr<ngraph::Node>& transpose = nullptr) {
const auto& pattern_map = matcher.get_pattern_value_map();
size_t matmul_input_index = 1;
@@ -41,38 +41,58 @@ static bool InsertReshape(
}
std::shared_ptr<ngraph::Node> matmul_node = iter->second.get_node_shared_ptr();
auto matmul_node_shape = matmul_node->get_output_shape(0);
if ((iter = pattern_map.find(input)) == std::end(pattern_map)) {
return false;
}
std::shared_ptr<ngraph::Node> first_node = iter->second.get_node_shared_ptr();
auto first_node = iter->second.get_node_shared_ptr();
std::vector<std::shared_ptr<ngraph::Node>> nodes = { matmul_node };
for (auto node : {add2, add1, fake_quantize, transpose}) {
iter = pattern_map.find(node);
if (iter != pattern_map.end()) {
nodes.push_back(iter->second.get_node_shared_ptr());
}
}
auto last_node_shape = nodes.back()->get_output_shape(0);
auto reshape_input_node = std::dynamic_pointer_cast<ngraph::opset8::Reshape>(first_node);
bool need_reshape_before = !reshape_input_node || reshape_input_node->get_output_shape(0).size() != 2;
if (need_reshape_before) {
auto input_shape = first_node->get_output_shape(0);
std::vector<size_t> before_shape(2, 1);
std::copy_if(input_shape.begin(), input_shape.end(), before_shape.begin(), [](size_t e) { return e > 1; });
std::vector<int> before_shape = {-1, static_cast<int>(first_node->get_output_shape(0).back())};
auto reshape_before_node = std::make_shared<ngraph::opset8::Reshape>(first_node,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{before_shape.size()}, before_shape), false);
reshape_before_node->set_friendly_name(matmul_node->get_friendly_name() + "/reshape_before_matmul");
ngraph::copy_runtime_info(first_node, reshape_before_node);
matmul_node->input(matmul_input_index).replace_source_output(reshape_before_node->output(0));
if (auto transpose_node = std::dynamic_pointer_cast<ngraph::opset8::Transpose>(nodes.back())) {
nodes.pop_back();
std::reverse(nodes.begin(), nodes.end());
while (!nodes.empty()) {
auto node_copy = nodes.back()->clone_with_new_inputs(nodes.back()->input_values());
ngraph::copy_runtime_info(nodes.back(), node_copy);
ngraph::replace_node(nodes.back(), node_copy);
nodes.pop_back();
}
auto transpose_input_shape = transpose_node->input_values()[0].get_node_shared_ptr()->get_output_shape(0);
auto transpose_constant_shape = transpose_node->input_values()[1].get_node_shared_ptr()->get_output_shape(0);
if (std::count_if(transpose_input_shape.begin(), transpose_input_shape.end(), [](size_t n) { return n > 1; }) > 2) {
THROW_GNA_EXCEPTION << "The number of dimensions that are greater than 1 is greater than 2"
<< " for Transpose layer (" << transpose_node->get_friendly_name() << ")."
<< " For this reason, there is no way to determine permutation shape.";
}
std::vector<int> permutation_shape = {1, 0};
auto transpose_node_copy = transpose_node->clone_with_new_inputs(
{transpose_node->input_values()[0],
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{permutation_shape.size()}, permutation_shape)});
ngraph::copy_runtime_info(transpose_node, transpose_node_copy);
ngraph::replace_node(transpose_node, transpose_node_copy);
nodes.push_back(transpose_node_copy);
}
}
std::shared_ptr<ngraph::Node> last_node;
iter = pattern_map.find(transpose);
if (iter == pattern_map.end() &&
(iter = pattern_map.find(fake_quantize2)) == pattern_map.end() &&
(iter = pattern_map.find(add1)) == pattern_map.end() &&
(iter = pattern_map.find(add2)) == pattern_map.end()) {
last_node = matmul_node;
} else {
last_node = iter->second.get_node_shared_ptr();
}
auto consumers = last_node->output(0).get_target_inputs();
auto last_node_shape = last_node->get_output_shape(0);
auto consumers = nodes.back()->output(0).get_target_inputs();
bool need_reshape_after = false;
for (auto consumer : consumers) {
auto reshape_output_node = dynamic_cast<ngraph::opset8::Reshape*>(consumer.get_node());
@@ -83,10 +103,11 @@ static bool InsertReshape(
}
if (need_reshape_after) {
auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(last_node,
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{last_node_shape.size()}, last_node_shape), false);
reshape_after_node->set_friendly_name(last_node->get_friendly_name());
ngraph::copy_runtime_info(last_node, reshape_after_node);
auto reshape_after_node = std::make_shared<ngraph::opset8::Reshape>(nodes.back(),
std::make_shared<ngraph::opset8::Constant>(ngraph::element::Type_t::i64,
ngraph::Shape{last_node_shape.size()}, last_node_shape), false);
reshape_after_node->set_friendly_name(nodes.back()->get_friendly_name());
ngraph::copy_runtime_info(nodes.back(), reshape_after_node);
for (auto consumer : consumers) {
consumer.replace_source_output(reshape_after_node);
}

View File

@@ -13,7 +13,7 @@
#include <transformations/init_node_info.hpp>
#include <numeric>
template<bool ADD, bool ADD_FIRST_INPUT_NOT_CONSTANT, bool FQ>
template<bool ADD = false, bool ADD_FIRST_INPUT_NOT_CONSTANT = false, bool FQ = false, bool TRANSPOSE = false>
struct InsertReshapeAroundMatmulTest {
static std::shared_ptr<ngraph::Node> CreateAdd(std::shared_ptr<ngraph::Node> input, const ngraph::Shape& constant_shape) {
std::vector<size_t> data(ngraph::shape_size(constant_shape));
@@ -24,7 +24,8 @@ struct InsertReshapeAroundMatmulTest {
static std::shared_ptr<ngraph::Node> CreateMatmul(
std::shared_ptr<ngraph::Node> input,
const ngraph::Shape& matmul_constant_shape) {
const ngraph::Shape& matmul_constant_shape,
const ngraph::Shape& permutation_shape) {
std::vector<size_t> data(ngraph::shape_size(matmul_constant_shape));
std::iota(std::begin(data), std::end(data), 1);
auto constant = ngraph::opset8::Constant::create(ngraph::element::i64, matmul_constant_shape, data);
@@ -55,16 +56,22 @@ struct InsertReshapeAroundMatmulTest {
255);
}
if (TRANSPOSE) {
node = std::make_shared<ngraph::opset8::Transpose>(
node,
ngraph::opset8::Constant::create(ngraph::element::i64, {permutation_shape.size()}, permutation_shape));
}
return node;
}
static std::shared_ptr<ngraph::Function> CreateFunction(
const ngraph::Shape& input_shape,
const ngraph::Shape& matmul_constant_shape,
const ngraph::Shape& result_shape) {
const ngraph::Shape& permutation_shape = ngraph::Shape()) {
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
auto before = std::make_shared<ngraph::opset8::Relu>(input);
auto matmul = CreateMatmul(before, matmul_constant_shape);
auto matmul = CreateMatmul(before, matmul_constant_shape, permutation_shape);
auto after = std::make_shared<ngraph::opset8::Relu>(matmul);
return std::make_shared<ngraph::Function>(
ngraph::ResultVector{std::make_shared<ngraph::opset8::Result>(after)},
@@ -73,16 +80,16 @@ struct InsertReshapeAroundMatmulTest {
static std::shared_ptr<ngraph::Function> CreateReferenceFunction(
const ngraph::Shape& input_shape,
const ngraph::Shape& reshape_before_shape,
const std::vector<int>& reshape_before_shape,
const ngraph::Shape& matmul_constant_shape,
const ngraph::Shape& reshape_after_shape,
const ngraph::Shape& result_shape) {
const ngraph::Shape& permutation_shape = ngraph::Shape()) {
auto input = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::i64, input_shape);
auto before = std::make_shared<ngraph::opset8::Relu>(input);
auto reshape_before_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape_before_shape.size()}, reshape_before_shape);
auto reshape_before = std::make_shared<ngraph::opset8::Reshape>(before, reshape_before_constant, false);
auto matmul = CreateMatmul(reshape_before, matmul_constant_shape);
auto matmul = CreateMatmul(reshape_before, matmul_constant_shape, permutation_shape);
auto reshape_after_constant = ngraph::opset8::Constant::create(ngraph::element::i64,
ngraph::Shape{reshape_after_shape.size()}, reshape_after_shape);
auto reshape_after = std::make_shared<ngraph::opset8::Reshape>(matmul, reshape_after_constant, false);
@@ -116,75 +123,144 @@ void RunTest(const std::shared_ptr<ngraph::Function>& func, const std::shared_pt
TEST(TransformationTests, InsertReshapeAroundMatmul) {
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
InsertReshapeAroundMatmulTest<>::
CreateFunction({1, 6, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateFunction({1, 6, 1, 8}, {8, 10}, {1, 6, 1, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
InsertReshapeAroundMatmulTest<>::
CreateFunction({1, 6, 1, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 6, 1, 8}, {-1, 8}, {8, 10}, {1, 6, 1, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}),
InsertReshapeAroundMatmulTest<false, false, false>::
CreateReferenceFunction({1, 6, 1, 8}, {6, 8}, {8, 10}, {1, 6, 1, 10}, {1, 6, 1, 10}));
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 6, 1, 8}, {-1, 8}, {8, 10}, {1, 6, 1, 10}),
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 6, 1, 8}, {-1, 8}, {8, 10}, {1, 6, 1, 10}));
RunTest(
InsertReshapeAroundMatmulTest<>::
CreateFunction({1, 1, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 1, 10}));
RunTest(
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 1, 10}),
InsertReshapeAroundMatmulTest<>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 1, 10}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd) {
RunTest(
InsertReshapeAroundMatmulTest<true, true, false>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
InsertReshapeAroundMatmulTest<true, true>::
CreateFunction({1, 6, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<true, true, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
InsertReshapeAroundMatmulTest<true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithAdd_AddFirstInputConstant) {
RunTest(
InsertReshapeAroundMatmulTest<true, false, false>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
InsertReshapeAroundMatmulTest<true>::
CreateFunction({1, 6, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<true, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, false, false>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
InsertReshapeAroundMatmulTest<true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithFq) {
RunTest(
InsertReshapeAroundMatmulTest<false, false, true>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
CreateFunction({1, 6, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<false, false, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<false, false, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithAddAndFq) {
RunTest(
InsertReshapeAroundMatmulTest<true, true, true>::
CreateFunction({1, 6, 8}, {8, 10}, {1, 6, 10}),
CreateFunction({1, 6, 8}, {8, 10}),
InsertReshapeAroundMatmulTest<true, true, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
RunTest(
InsertReshapeAroundMatmulTest<true, true, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}),
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}),
InsertReshapeAroundMatmulTest<true, true, true>::
CreateReferenceFunction({1, 6, 8}, {6, 8}, {8, 10}, {1, 6, 10}, {1, 6, 10}));
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 6, 10}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithTranspose) {
RunTest(
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateFunction({1, 6, 8}, {8, 10}, {0, 2, 1}),
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}),
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateFunction({1, 1, 8}, {8, 10}, {0, 2, 1}),
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 10, 1}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 10, 1}, {1, 0}),
InsertReshapeAroundMatmulTest<false, false, false, true>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 10, 1}, {1, 0}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithFqAndTranspose) {
RunTest(
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateFunction({1, 6, 8}, {8, 10}, {0, 2, 1}),
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}),
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateFunction({1, 1, 8}, {8, 10}, {0, 2, 1}),
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 10, 1}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 10, 1}, {1, 0}),
InsertReshapeAroundMatmulTest<false, false, true, true>::
CreateReferenceFunction({1, 1, 8}, {-1, 8}, {8, 10}, {1, 10, 1}, {1, 0}));
}
TEST(TransformationTests, InsertReshapeAroundMatmulWithAddAndFqAndTranspose) {
RunTest(
InsertReshapeAroundMatmulTest<true, true, true, true>::
CreateFunction({1, 6, 8}, {8, 10}, {0, 2, 1}),
InsertReshapeAroundMatmulTest<true, true, true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}));
RunTest(
InsertReshapeAroundMatmulTest<true, true, true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}),
InsertReshapeAroundMatmulTest<true, true, true, true>::
CreateReferenceFunction({1, 6, 8}, {-1, 8}, {8, 10}, {1, 10, 6}, {1, 0}));
}