Disable MatMul to FC conversion for VPU (#1520)
* Splited MatmulToFCorGemm transformation * Updated VPU transformation predicate to check that MatMul has DSR as input
This commit is contained in:
parent
0c3da56ae1
commit
531a7209d5
@ -19,11 +19,26 @@ namespace ngraph {
|
|||||||
namespace pass {
|
namespace pass {
|
||||||
|
|
||||||
class TRANSFORMATIONS_API ConvertMatMulToFCorGemm;
|
class TRANSFORMATIONS_API ConvertMatMulToFCorGemm;
|
||||||
|
class TRANSFORMATIONS_API ConvertMatMulToFC;
|
||||||
|
class TRANSFORMATIONS_API ConvertMatMulToGemm;
|
||||||
|
|
||||||
} // namespace pass
|
} // namespace pass
|
||||||
} // namespace ngraph
|
} // namespace ngraph
|
||||||
|
|
||||||
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::MatcherPass {
|
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
|
||||||
public:
|
public:
|
||||||
ConvertMatMulToFCorGemm();
|
ConvertMatMulToFCorGemm() {
|
||||||
|
add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class ngraph::pass::ConvertMatMulToFC: public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
ConvertMatMulToFC();
|
||||||
|
};
|
||||||
|
|
||||||
|
class ngraph::pass::ConvertMatMulToGemm: public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
ConvertMatMulToGemm();
|
||||||
};
|
};
|
||||||
|
@ -18,14 +18,14 @@
|
|||||||
#include <transformations/utils/utils.hpp>
|
#include <transformations/utils/utils.hpp>
|
||||||
|
|
||||||
|
|
||||||
ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
|
ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||||
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||||
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
|
||||||
|
|
||||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||||
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
|
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
|
||||||
if (!matmul) {
|
if (!matmul || m_transformation_callback(matmul)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -150,54 +150,82 @@ ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
|
|||||||
|
|
||||||
ngraph::copy_runtime_info(matmul, new_ops);
|
ngraph::copy_runtime_info(matmul, new_ops);
|
||||||
ngraph::replace_node(matmul, fc);
|
ngraph::replace_node(matmul, fc);
|
||||||
} else {
|
return true;
|
||||||
// WA for IE that Gemm must have inputs with the same length.
|
}
|
||||||
if (shape_a.size() < shape_b.size()) {
|
return false;
|
||||||
// Reshape first input (fc_input_a)
|
};
|
||||||
Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
|
|
||||||
reshape_shape.insert(reshape_shape.end(), shape_a.begin(), shape_a.end());
|
|
||||||
fc_input_a = op::util::reshapeTo(fc_input_a, reshape_shape);
|
|
||||||
new_ops.push_back(fc_input_a.get_node_shared_ptr());
|
|
||||||
} else if (shape_b.size() < shape_a.size()) {
|
|
||||||
// Reshape second input (fc_input_b)
|
|
||||||
Shape reshape_shape;
|
|
||||||
if (shape_b.size() == 1) {
|
|
||||||
// In case if shape_b has only one dimension we reshape it to [...,1,X,1]
|
|
||||||
reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
|
|
||||||
reshape_shape.push_back(shape_b[0]); // add X dimension
|
|
||||||
reshape_shape.push_back(1); // add last 1 dimension
|
|
||||||
} else {
|
|
||||||
// In this case we reshape shape_b to [...,1,1,X]
|
|
||||||
reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
|
|
||||||
reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
|
|
||||||
}
|
|
||||||
fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
|
|
||||||
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
|
||||||
}
|
|
||||||
|
|
||||||
auto gemm = std::make_shared<opset1::MatMul>(fc_input_a, fc_input_b, matmul->get_transpose_a(), matmul->get_transpose_b());
|
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFC");
|
||||||
new_ops.push_back(gemm);
|
this->register_matcher(m, callback);
|
||||||
|
}
|
||||||
|
|
||||||
if (gemm->get_shape() != output_shape) {
|
ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
|
||||||
// This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
|
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||||
// and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
|
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||||
// So to preserve output shape we insert additional reshape operation
|
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
|
||||||
auto reshape_output = op::util::reshapeTo(gemm, output_shape);
|
|
||||||
new_ops.push_back(reshape_output);
|
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||||
gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
|
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
|
||||||
reshape_output->set_friendly_name(matmul->get_friendly_name());
|
if (!matmul) {
|
||||||
ngraph::copy_runtime_info(matmul, new_ops);
|
return false;
|
||||||
ngraph::replace_node(matmul, reshape_output);
|
}
|
||||||
|
|
||||||
|
auto input_a = matmul->input(0).get_source_output();
|
||||||
|
auto input_b = matmul->input(1).get_source_output();
|
||||||
|
|
||||||
|
auto shape_a = input_a.get_shape();
|
||||||
|
auto shape_b = input_b.get_shape();
|
||||||
|
auto output_shape = matmul->get_shape();
|
||||||
|
|
||||||
|
auto fc_input_a = input_a, fc_input_b = input_b;
|
||||||
|
NodeVector new_ops;
|
||||||
|
|
||||||
|
// WA for IE that Gemm must have inputs with the same length.
|
||||||
|
if (shape_a.size() < shape_b.size()) {
|
||||||
|
// Reshape first input (fc_input_a)
|
||||||
|
Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
|
||||||
|
reshape_shape.insert(reshape_shape.end(), shape_a.begin(), shape_a.end());
|
||||||
|
fc_input_a = op::util::reshapeTo(fc_input_a, reshape_shape);
|
||||||
|
new_ops.push_back(fc_input_a.get_node_shared_ptr());
|
||||||
|
} else if (shape_b.size() < shape_a.size()) {
|
||||||
|
// Reshape second input (fc_input_b)
|
||||||
|
Shape reshape_shape;
|
||||||
|
if (shape_b.size() == 1) {
|
||||||
|
// In case if shape_b has only one dimension we reshape it to [...,1,X,1]
|
||||||
|
reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
|
||||||
|
reshape_shape.push_back(shape_b[0]); // add X dimension
|
||||||
|
reshape_shape.push_back(1); // add last 1 dimension
|
||||||
} else {
|
} else {
|
||||||
gemm->set_friendly_name(matmul->get_friendly_name());
|
// In this case we reshape shape_b to [...,1,1,X]
|
||||||
ngraph::copy_runtime_info(matmul, new_ops);
|
reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
|
||||||
ngraph::replace_node(matmul, gemm);
|
reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
|
||||||
}
|
}
|
||||||
|
fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
|
||||||
|
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto gemm = std::make_shared<opset1::MatMul>(fc_input_a, fc_input_b, matmul->get_transpose_a(), matmul->get_transpose_b());
|
||||||
|
new_ops.push_back(gemm);
|
||||||
|
|
||||||
|
if (gemm->get_shape() != output_shape) {
|
||||||
|
// This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
|
||||||
|
// and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
|
||||||
|
// So to preserve output shape we insert additional reshape operation
|
||||||
|
auto reshape_output = op::util::reshapeTo(gemm, output_shape);
|
||||||
|
new_ops.push_back(reshape_output);
|
||||||
|
gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
|
||||||
|
reshape_output->set_friendly_name(matmul->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info(matmul, new_ops);
|
||||||
|
ngraph::replace_node(matmul, reshape_output);
|
||||||
|
} else {
|
||||||
|
gemm->set_friendly_name(matmul->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info(matmul, new_ops);
|
||||||
|
ngraph::replace_node(matmul, gemm);
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFCorGemm");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToGemm");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(m, callback);
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,8 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
|||||||
decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
|
decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
|
||||||
decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
||||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToFCorGemm>();
|
decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||||
decomp->set_name("ngraph::pass::Decompositions");
|
decomp->set_name("ngraph::pass::Decompositions");
|
||||||
|
|
||||||
|
@ -26,6 +26,7 @@
|
|||||||
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
|
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
|
||||||
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
|
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
|
||||||
#include <vpu/ngraph/transformations/merge_subsequent_dsr_operations.hpp>
|
#include <vpu/ngraph/transformations/merge_subsequent_dsr_operations.hpp>
|
||||||
|
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
|
||||||
|
|
||||||
namespace vpu {
|
namespace vpu {
|
||||||
|
|
||||||
@ -380,8 +381,10 @@ ModelPtr FrontEnd::runCommonPasses(ie::ICNNNetwork& network, const UnsupportedLa
|
|||||||
|
|
||||||
auto convertNetwork = [&convertedNetwork, &originalOrConvertNetwork]() {
|
auto convertNetwork = [&convertedNetwork, &originalOrConvertNetwork]() {
|
||||||
// disable GeLU decomposition
|
// disable GeLU decomposition
|
||||||
const auto transformationsPredicate = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
const auto transformationsPredicate = [](const std::shared_ptr<const ngraph::Node> &node) -> bool {
|
||||||
return std::dynamic_pointer_cast<const ::ngraph::opset3::Gelu>(node) != nullptr;
|
return std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
|
||||||
|
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) &&
|
||||||
|
std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr()));
|
||||||
};
|
};
|
||||||
|
|
||||||
auto nGraphFunc = originalOrConvertNetwork->getFunction();
|
auto nGraphFunc = originalOrConvertNetwork->getFunction();
|
||||||
|
@ -38,7 +38,8 @@ TEST(TransformationTests, ConvertMatMulTest1) {
|
|||||||
|
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
}
|
}
|
||||||
@ -69,7 +70,8 @@ TEST(TransformationTests, ConvertMatMulTest2) {
|
|||||||
|
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
}
|
}
|
||||||
@ -99,7 +101,8 @@ TEST(TransformationTests, ConvertMatMulTest3) {
|
|||||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
}
|
}
|
||||||
@ -129,7 +132,8 @@ TEST(TransformationTests, ConvertMatMulTest4) {
|
|||||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
}
|
}
|
||||||
@ -156,7 +160,8 @@ TEST(TransformationTests, ConvertMatMulTest5) {
|
|||||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
}
|
}
|
||||||
@ -184,7 +189,8 @@ TEST(TransformationTests, ConvertMatMulTest6) {
|
|||||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
|
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
|
||||||
m.run_passes(f);
|
m.run_passes(f);
|
||||||
ASSERT_NO_THROW(check_rt_info(f));
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
@ -216,7 +222,8 @@ TEST(TransformationTests, ConvertMatMulTest7) {
|
|||||||
|
|
||||||
ngraph::pass::Manager m;
|
ngraph::pass::Manager m;
|
||||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||||
|
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||||
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
|
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
|
||||||
|
|
||||||
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||||
|
Loading…
Reference in New Issue
Block a user