Disable MatMul to FC conversion for VPU (#1520)

* Split the MatMulToFCorGemm transformation into separate ConvertMatMulToFC and ConvertMatMulToGemm passes

* Updated the VPU transformation predicate to check whether a MatMul has a DSR (DynamicShapeResolver) as its input, so such MatMuls are not converted
Gleb Kazantaev authored on 2020-07-30 11:50:52 +03:00, committed by GitHub
parent 0c3da56ae1
commit 531a7209d5
5 changed files with 108 additions and 54 deletions
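
For context, a minimal sketch of the resulting registration pattern (it mirrors the updated unit tests at the end of this diff): the single composite pass is replaced by two passes that a consumer can register, or skip, independently. Here f is assumed to be an existing std::shared_ptr<ngraph::Function>.

    ngraph::pass::Manager m;
    m.register_pass<ngraph::pass::InitNodeInfo>();
    // Before this commit: m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
    m.register_pass<ngraph::pass::ConvertMatMulToFC>();
    m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
    m.run_passes(f);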


@@ -19,11 +19,26 @@ namespace ngraph {
 namespace pass {
 
 class TRANSFORMATIONS_API ConvertMatMulToFCorGemm;
+class TRANSFORMATIONS_API ConvertMatMulToFC;
+class TRANSFORMATIONS_API ConvertMatMulToGemm;
 
 }  // namespace pass
 }  // namespace ngraph
 
-class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::MatcherPass {
+class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
 public:
-    ConvertMatMulToFCorGemm();
+    ConvertMatMulToFCorGemm() {
+        add_matcher<ngraph::pass::ConvertMatMulToFC>();
+        add_matcher<ngraph::pass::ConvertMatMulToGemm>();
+    }
 };
+
+class ngraph::pass::ConvertMatMulToFC: public ngraph::pass::MatcherPass {
+public:
+    ConvertMatMulToFC();
+};
+
+class ngraph::pass::ConvertMatMulToGemm: public ngraph::pass::MatcherPass {
+public:
+    ConvertMatMulToGemm();
+};


@@ -18,14 +18,14 @@
 #include <transformations/utils/utils.hpp>
 
-ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
+ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
     auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
     auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
     auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
 
     ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
         auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
-        if (!matmul) {
+        if (!matmul || m_transformation_callback(matmul)) {
             return false;
         }
@@ -150,54 +150,82 @@ ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
             ngraph::copy_runtime_info(matmul, new_ops);
             ngraph::replace_node(matmul, fc);
-        } else {
-            // WA for IE that Gemm must have inputs with the same length.
-            if (shape_a.size() < shape_b.size()) {
-                // Reshape first input (fc_input_a)
-                Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
-                reshape_shape.insert(reshape_shape.end(), shape_a.begin(), shape_a.end());
-                fc_input_a = op::util::reshapeTo(fc_input_a, reshape_shape);
-                new_ops.push_back(fc_input_a.get_node_shared_ptr());
-            } else if (shape_b.size() < shape_a.size()) {
-                // Reshape second input (fc_input_b)
-                Shape reshape_shape;
-                if (shape_b.size() == 1) {
-                    // In case if shape_b has only one dimension we reshape it to [...,1,X,1]
-                    reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
-                    reshape_shape.push_back(shape_b[0]);  // add X dimension
-                    reshape_shape.push_back(1);  // add last 1 dimension
-                } else {
-                    // In this case we reshape shape_b to [...,1,1,X]
-                    reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
-                    reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
-                }
-                fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
-                new_ops.push_back(fc_input_b.get_node_shared_ptr());
-            }
-
-            auto gemm = std::make_shared<opset1::MatMul>(fc_input_a, fc_input_b, matmul->get_transpose_a(), matmul->get_transpose_b());
-            new_ops.push_back(gemm);
-
-            if (gemm->get_shape() != output_shape) {
-                // This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
-                // and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
-                // So to preserve output shape we insert additional reshape operation
-                auto reshape_output = op::util::reshapeTo(gemm, output_shape);
-                new_ops.push_back(reshape_output);
-                gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
-                reshape_output->set_friendly_name(matmul->get_friendly_name());
-                ngraph::copy_runtime_info(matmul, new_ops);
-                ngraph::replace_node(matmul, reshape_output);
-            } else {
-                gemm->set_friendly_name(matmul->get_friendly_name());
-                ngraph::copy_runtime_info(matmul, new_ops);
-                ngraph::replace_node(matmul, gemm);
-            }
-        }
-        return true;
-    };
-
-    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFCorGemm");
-    this->register_matcher(m, callback);
-}
+            return true;
+        }
+        return false;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFC");
+    this->register_matcher(m, callback);
+}
+
+ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
+    auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
+    auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
+    auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
+
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
+        auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
+        if (!matmul) {
+            return false;
+        }
+
+        auto input_a = matmul->input(0).get_source_output();
+        auto input_b = matmul->input(1).get_source_output();
+
+        auto shape_a = input_a.get_shape();
+        auto shape_b = input_b.get_shape();
+        auto output_shape = matmul->get_shape();
+
+        auto fc_input_a = input_a, fc_input_b = input_b;
+        NodeVector new_ops;
+
+        // WA for IE that Gemm must have inputs with the same length.
+        if (shape_a.size() < shape_b.size()) {
+            // Reshape first input (fc_input_a)
+            Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
+            reshape_shape.insert(reshape_shape.end(), shape_a.begin(), shape_a.end());
+            fc_input_a = op::util::reshapeTo(fc_input_a, reshape_shape);
+            new_ops.push_back(fc_input_a.get_node_shared_ptr());
+        } else if (shape_b.size() < shape_a.size()) {
+            // Reshape second input (fc_input_b)
+            Shape reshape_shape;
+            if (shape_b.size() == 1) {
+                // In case if shape_b has only one dimension we reshape it to [...,1,X,1]
+                reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
+                reshape_shape.push_back(shape_b[0]);  // add X dimension
+                reshape_shape.push_back(1);  // add last 1 dimension
+            } else {
+                // In this case we reshape shape_b to [...,1,1,X]
+                reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
+                reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
+            }
+            fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
+            new_ops.push_back(fc_input_b.get_node_shared_ptr());
+        }
+
+        auto gemm = std::make_shared<opset1::MatMul>(fc_input_a, fc_input_b, matmul->get_transpose_a(), matmul->get_transpose_b());
+        new_ops.push_back(gemm);
+
+        if (gemm->get_shape() != output_shape) {
+            // This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
+            // and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
+            // So to preserve output shape we insert additional reshape operation
+            auto reshape_output = op::util::reshapeTo(gemm, output_shape);
+            new_ops.push_back(reshape_output);
+            gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
+            reshape_output->set_friendly_name(matmul->get_friendly_name());
+            ngraph::copy_runtime_info(matmul, new_ops);
+            ngraph::replace_node(matmul, reshape_output);
+        } else {
+            gemm->set_friendly_name(matmul->get_friendly_name());
+            ngraph::copy_runtime_info(matmul, new_ops);
+            ngraph::replace_node(matmul, gemm);
+        }
+        return true;
+    };
+
+    auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToGemm");
+    this->register_matcher(m, callback);
+}

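The new "!matmul || m_transformation_callback(matmul)" guard above is what lets a plugin veto the FC conversion per node. A minimal sketch of that mechanism, assuming (as in the nGraph API of this period) that the predicate passed to ngraph::pass::Manager::set_callback is what the matcher sees as m_transformation_callback; f is an existing std::shared_ptr<ngraph::Function>:

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConvertMatMulToFC>();
    // Any node for which this predicate returns true is skipped by the pass:
    // the matcher callback bails out before rewriting the MatMul.
    manager.set_callback([](const std::shared_ptr<const ngraph::Node>& node) -> bool {
        return std::dynamic_pointer_cast<const ngraph::opset1::MatMul>(node) != nullptr;
    });
    manager.run_passes(f);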

@@ -80,7 +80,8 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
     decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
     decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
     decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
-    decomp->add_matcher<ngraph::pass::ConvertMatMulToFCorGemm>();
+    decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
+    decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
     decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
 
     decomp->set_name("ngraph::pass::Decompositions");


@@ -26,6 +26,7 @@
 #include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
 #include <vpu/ngraph/transformations/merge_subsequent_dsr_operations.hpp>
+#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
 
 namespace vpu {
@@ -380,8 +381,10 @@ ModelPtr FrontEnd::runCommonPasses(ie::ICNNNetwork& network, const UnsupportedLa
     auto convertNetwork = [&convertedNetwork, &originalOrConvertNetwork]() {
         // disable GeLU decomposition
-        const auto transformationsPredicate = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
-            return std::dynamic_pointer_cast<const ::ngraph::opset3::Gelu>(node) != nullptr;
+        const auto transformationsPredicate = [](const std::shared_ptr<const ngraph::Node> &node) -> bool {
+            return std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
+                   (std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) &&
+                    std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr()));
         };
 
         auto nGraphFunc = originalOrConvertNetwork->getFunction();

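For illustration, a hypothetical graph fragment that the extended predicate matches; the DynamicShapeResolver constructor arguments (data tensor, then dims tensor) are assumed from its usual VPU usage:

    // A MatMul whose first input is produced by a DynamicShapeResolver: the
    // predicate returns true for it, so the MatMul-to-FC/Gemm conversion is
    // skipped and the node stays a MatMul for the VPU dynamic-shape path.
    auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
    auto dims = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::i64, ngraph::Shape{2});
    auto dsr  = std::make_shared<ngraph::vpu::op::DynamicShapeResolver>(data, dims);
    auto w    = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, ngraph::Shape{3, 4});
    auto mm   = std::make_shared<ngraph::opset3::MatMul>(dsr, w);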

@@ -38,7 +38,8 @@ TEST(TransformationTests, ConvertMatMulTest1) {
 
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
         m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -69,7 +70,8 @@ TEST(TransformationTests, ConvertMatMulTest2) {
 
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
         m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -99,7 +101,8 @@ TEST(TransformationTests, ConvertMatMulTest3) {
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
         m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -129,7 +132,8 @@ TEST(TransformationTests, ConvertMatMulTest4) {
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
         m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -156,7 +160,8 @@ TEST(TransformationTests, ConvertMatMulTest5) {
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
        m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
     }
@@ -184,7 +189,8 @@ TEST(TransformationTests, ConvertMatMulTest6) {
         f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
         m.register_pass<ngraph::pass::ReshapeFullyConnected>();
         m.run_passes(f);
         ASSERT_NO_THROW(check_rt_info(f));
@@ -216,7 +222,8 @@ TEST(TransformationTests, ConvertMatMulTest7) {
 
         ngraph::pass::Manager m;
         m.register_pass<ngraph::pass::InitNodeInfo>();
-        m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
+        m.register_pass<ngraph::pass::ConvertMatMulToFC>();
+        m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
         m.register_pass<ngraph::pass::ReshapeFullyConnected>();
 
         auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {