Disable MatMul to FC conversion for VPU (#1520)

* Splited MatmulToFCorGemm transformation

* Updated VPU transformation predicate to check that MatMul has DSR as input
This commit is contained in:
Gleb Kazantaev 2020-07-30 11:50:52 +03:00 committed by GitHub
parent 0c3da56ae1
commit 531a7209d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 108 additions and 54 deletions

View File

@ -19,11 +19,26 @@ namespace ngraph {
namespace pass { namespace pass {
class TRANSFORMATIONS_API ConvertMatMulToFCorGemm; class TRANSFORMATIONS_API ConvertMatMulToFCorGemm;
class TRANSFORMATIONS_API ConvertMatMulToFC;
class TRANSFORMATIONS_API ConvertMatMulToGemm;
} // namespace pass } // namespace pass
} // namespace ngraph } // namespace ngraph
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::MatcherPass { class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
public: public:
ConvertMatMulToFCorGemm(); ConvertMatMulToFCorGemm() {
add_matcher<ngraph::pass::ConvertMatMulToFC>();
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
}
};
class ngraph::pass::ConvertMatMulToFC: public ngraph::pass::MatcherPass {
public:
ConvertMatMulToFC();
};
class ngraph::pass::ConvertMatMulToGemm: public ngraph::pass::MatcherPass {
public:
ConvertMatMulToGemm();
}; };

View File

@ -18,14 +18,14 @@
#include <transformations/utils/utils.hpp> #include <transformations/utils/utils.hpp>
ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() { ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1}); auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1}); auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1); auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root()); auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
if (!matmul) { if (!matmul || m_transformation_callback(matmul)) {
return false; return false;
} }
@ -150,7 +150,36 @@ ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
ngraph::copy_runtime_info(matmul, new_ops); ngraph::copy_runtime_info(matmul, new_ops);
ngraph::replace_node(matmul, fc); ngraph::replace_node(matmul, fc);
} else { return true;
}
return false;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFC");
this->register_matcher(m, callback);
}
ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
if (!matmul) {
return false;
}
auto input_a = matmul->input(0).get_source_output();
auto input_b = matmul->input(1).get_source_output();
auto shape_a = input_a.get_shape();
auto shape_b = input_b.get_shape();
auto output_shape = matmul->get_shape();
auto fc_input_a = input_a, fc_input_b = input_b;
NodeVector new_ops;
// WA for IE that Gemm must have inputs with the same length. // WA for IE that Gemm must have inputs with the same length.
if (shape_a.size() < shape_b.size()) { if (shape_a.size() < shape_b.size()) {
// Reshape first input (fc_input_a) // Reshape first input (fc_input_a)
@ -193,11 +222,10 @@ ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
ngraph::copy_runtime_info(matmul, new_ops); ngraph::copy_runtime_info(matmul, new_ops);
ngraph::replace_node(matmul, gemm); ngraph::replace_node(matmul, gemm);
} }
}
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFCorGemm"); auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToGemm");
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }

View File

@ -80,7 +80,8 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
decomp->add_matcher<ngraph::pass::ConvertDeconvolution>(); decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>(); decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>(); decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::ConvertMatMulToFCorGemm>(); decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>(); decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
decomp->set_name("ngraph::pass::Decompositions"); decomp->set_name("ngraph::pass::Decompositions");

View File

@ -26,6 +26,7 @@
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp> #include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp> #include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
#include <vpu/ngraph/transformations/merge_subsequent_dsr_operations.hpp> #include <vpu/ngraph/transformations/merge_subsequent_dsr_operations.hpp>
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
namespace vpu { namespace vpu {
@ -380,8 +381,10 @@ ModelPtr FrontEnd::runCommonPasses(ie::ICNNNetwork& network, const UnsupportedLa
auto convertNetwork = [&convertedNetwork, &originalOrConvertNetwork]() { auto convertNetwork = [&convertedNetwork, &originalOrConvertNetwork]() {
// disable GeLU decomposition // disable GeLU decomposition
const auto transformationsPredicate = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool { const auto transformationsPredicate = [](const std::shared_ptr<const ngraph::Node> &node) -> bool {
return std::dynamic_pointer_cast<const ::ngraph::opset3::Gelu>(node) != nullptr; return std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) &&
std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr()));
}; };
auto nGraphFunc = originalOrConvertNetwork->getFunction(); auto nGraphFunc = originalOrConvertNetwork->getFunction();

View File

@ -38,7 +38,8 @@ TEST(TransformationTests, ConvertMatMulTest1) {
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.run_passes(f); m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f)); ASSERT_NO_THROW(check_rt_info(f));
} }
@ -69,7 +70,8 @@ TEST(TransformationTests, ConvertMatMulTest2) {
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.run_passes(f); m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f)); ASSERT_NO_THROW(check_rt_info(f));
} }
@ -99,7 +101,8 @@ TEST(TransformationTests, ConvertMatMulTest3) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2}); f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.run_passes(f); m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f)); ASSERT_NO_THROW(check_rt_info(f));
} }
@ -129,7 +132,8 @@ TEST(TransformationTests, ConvertMatMulTest4) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2}); f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.run_passes(f); m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f)); ASSERT_NO_THROW(check_rt_info(f));
} }
@ -156,7 +160,8 @@ TEST(TransformationTests, ConvertMatMulTest5) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.run_passes(f); m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f)); ASSERT_NO_THROW(check_rt_info(f));
} }
@ -184,7 +189,8 @@ TEST(TransformationTests, ConvertMatMulTest6) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1}); f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.register_pass<ngraph::pass::ReshapeFullyConnected>(); m.register_pass<ngraph::pass::ReshapeFullyConnected>();
m.run_passes(f); m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f)); ASSERT_NO_THROW(check_rt_info(f));
@ -216,7 +222,8 @@ TEST(TransformationTests, ConvertMatMulTest7) {
ngraph::pass::Manager m; ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>(); m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>(); m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.register_pass<ngraph::pass::ReshapeFullyConnected>(); m.register_pass<ngraph::pass::ReshapeFullyConnected>();
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool { auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {