Disable MatMul to FC conversion for VPU (#1520)
* Splited MatmulToFCorGemm transformation * Updated VPU transformation predicate to check that MatMul has DSR as input
This commit is contained in:
parent
0c3da56ae1
commit
531a7209d5
@ -19,11 +19,26 @@ namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertMatMulToFCorGemm;
|
||||
class TRANSFORMATIONS_API ConvertMatMulToFC;
|
||||
class TRANSFORMATIONS_API ConvertMatMulToGemm;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::MatcherPass {
|
||||
class ngraph::pass::ConvertMatMulToFCorGemm: public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
ConvertMatMulToFCorGemm();
|
||||
ConvertMatMulToFCorGemm() {
|
||||
add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
}
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertMatMulToFC: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertMatMulToFC();
|
||||
};
|
||||
|
||||
class ngraph::pass::ConvertMatMulToGemm: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertMatMulToGemm();
|
||||
};
|
||||
|
@ -18,14 +18,14 @@
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
|
||||
ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
|
||||
ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
|
||||
if (!matmul) {
|
||||
if (!matmul || m_transformation_callback(matmul)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -150,54 +150,82 @@ ngraph::pass::ConvertMatMulToFCorGemm::ConvertMatMulToFCorGemm() {
|
||||
|
||||
ngraph::copy_runtime_info(matmul, new_ops);
|
||||
ngraph::replace_node(matmul, fc);
|
||||
} else {
|
||||
// WA for IE that Gemm must have inputs with the same length.
|
||||
if (shape_a.size() < shape_b.size()) {
|
||||
// Reshape first input (fc_input_a)
|
||||
Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
|
||||
reshape_shape.insert(reshape_shape.end(), shape_a.begin(), shape_a.end());
|
||||
fc_input_a = op::util::reshapeTo(fc_input_a, reshape_shape);
|
||||
new_ops.push_back(fc_input_a.get_node_shared_ptr());
|
||||
} else if (shape_b.size() < shape_a.size()) {
|
||||
// Reshape second input (fc_input_b)
|
||||
Shape reshape_shape;
|
||||
if (shape_b.size() == 1) {
|
||||
// In case if shape_b has only one dimension we reshape it to [...,1,X,1]
|
||||
reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
|
||||
reshape_shape.push_back(shape_b[0]); // add X dimension
|
||||
reshape_shape.push_back(1); // add last 1 dimension
|
||||
} else {
|
||||
// In this case we reshape shape_b to [...,1,1,X]
|
||||
reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
|
||||
reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
|
||||
}
|
||||
fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
|
||||
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto gemm = std::make_shared<opset1::MatMul>(fc_input_a, fc_input_b, matmul->get_transpose_a(), matmul->get_transpose_b());
|
||||
new_ops.push_back(gemm);
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFC");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
if (gemm->get_shape() != output_shape) {
|
||||
// This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
|
||||
// and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
|
||||
// So to preserve output shape we insert additional reshape operation
|
||||
auto reshape_output = op::util::reshapeTo(gemm, output_shape);
|
||||
new_ops.push_back(reshape_output);
|
||||
gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
|
||||
reshape_output->set_friendly_name(matmul->get_friendly_name());
|
||||
ngraph::copy_runtime_info(matmul, new_ops);
|
||||
ngraph::replace_node(matmul, reshape_output);
|
||||
ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
|
||||
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1});
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
|
||||
if (!matmul) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_a = matmul->input(0).get_source_output();
|
||||
auto input_b = matmul->input(1).get_source_output();
|
||||
|
||||
auto shape_a = input_a.get_shape();
|
||||
auto shape_b = input_b.get_shape();
|
||||
auto output_shape = matmul->get_shape();
|
||||
|
||||
auto fc_input_a = input_a, fc_input_b = input_b;
|
||||
NodeVector new_ops;
|
||||
|
||||
// WA for IE that Gemm must have inputs with the same length.
|
||||
if (shape_a.size() < shape_b.size()) {
|
||||
// Reshape first input (fc_input_a)
|
||||
Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
|
||||
reshape_shape.insert(reshape_shape.end(), shape_a.begin(), shape_a.end());
|
||||
fc_input_a = op::util::reshapeTo(fc_input_a, reshape_shape);
|
||||
new_ops.push_back(fc_input_a.get_node_shared_ptr());
|
||||
} else if (shape_b.size() < shape_a.size()) {
|
||||
// Reshape second input (fc_input_b)
|
||||
Shape reshape_shape;
|
||||
if (shape_b.size() == 1) {
|
||||
// In case if shape_b has only one dimension we reshape it to [...,1,X,1]
|
||||
reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
|
||||
reshape_shape.push_back(shape_b[0]); // add X dimension
|
||||
reshape_shape.push_back(1); // add last 1 dimension
|
||||
} else {
|
||||
gemm->set_friendly_name(matmul->get_friendly_name());
|
||||
ngraph::copy_runtime_info(matmul, new_ops);
|
||||
ngraph::replace_node(matmul, gemm);
|
||||
// In this case we reshape shape_b to [...,1,1,X]
|
||||
reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
|
||||
reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
|
||||
}
|
||||
fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
|
||||
new_ops.push_back(fc_input_b.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
auto gemm = std::make_shared<opset1::MatMul>(fc_input_a, fc_input_b, matmul->get_transpose_a(), matmul->get_transpose_b());
|
||||
new_ops.push_back(gemm);
|
||||
|
||||
if (gemm->get_shape() != output_shape) {
|
||||
// This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
|
||||
// and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
|
||||
// So to preserve output shape we insert additional reshape operation
|
||||
auto reshape_output = op::util::reshapeTo(gemm, output_shape);
|
||||
new_ops.push_back(reshape_output);
|
||||
gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
|
||||
reshape_output->set_friendly_name(matmul->get_friendly_name());
|
||||
ngraph::copy_runtime_info(matmul, new_ops);
|
||||
ngraph::replace_node(matmul, reshape_output);
|
||||
} else {
|
||||
gemm->set_friendly_name(matmul->get_friendly_name());
|
||||
ngraph::copy_runtime_info(matmul, new_ops);
|
||||
ngraph::replace_node(matmul, gemm);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToFCorGemm");
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matmul, "ConvertMatMulToGemm");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
@ -80,7 +80,8 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
decomp->add_matcher<ngraph::pass::ConvertDeconvolution>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
decomp->set_name("ngraph::pass::Decompositions");
|
||||
|
||||
|
@ -26,6 +26,7 @@
|
||||
#include <transformations/convert_opset2_to_opset1/convert_opset2_to_opset1.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.hpp>
|
||||
#include <vpu/ngraph/transformations/merge_subsequent_dsr_operations.hpp>
|
||||
#include <vpu/ngraph/operations/dynamic_shape_resolver.hpp>
|
||||
|
||||
namespace vpu {
|
||||
|
||||
@ -380,8 +381,10 @@ ModelPtr FrontEnd::runCommonPasses(ie::ICNNNetwork& network, const UnsupportedLa
|
||||
|
||||
auto convertNetwork = [&convertedNetwork, &originalOrConvertNetwork]() {
|
||||
// disable GeLU decomposition
|
||||
const auto transformationsPredicate = [](const std::shared_ptr<const ::ngraph::Node> &node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ::ngraph::opset3::Gelu>(node) != nullptr;
|
||||
const auto transformationsPredicate = [](const std::shared_ptr<const ngraph::Node> &node) -> bool {
|
||||
return std::dynamic_pointer_cast<const ngraph::opset3::Gelu>(node) ||
|
||||
(std::dynamic_pointer_cast<const ngraph::opset3::MatMul>(node) &&
|
||||
std::dynamic_pointer_cast<const ngraph::vpu::op::DynamicShapeResolver>(node->input_value(0).get_node_shared_ptr()));
|
||||
};
|
||||
|
||||
auto nGraphFunc = originalOrConvertNetwork->getFunction();
|
||||
|
@ -38,7 +38,8 @@ TEST(TransformationTests, ConvertMatMulTest1) {
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
@ -69,7 +70,8 @@ TEST(TransformationTests, ConvertMatMulTest2) {
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
@ -99,7 +101,8 @@ TEST(TransformationTests, ConvertMatMulTest3) {
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
@ -129,7 +132,8 @@ TEST(TransformationTests, ConvertMatMulTest4) {
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1, input2});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
@ -156,7 +160,8 @@ TEST(TransformationTests, ConvertMatMulTest5) {
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
@ -184,7 +189,8 @@ TEST(TransformationTests, ConvertMatMulTest6) {
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
@ -216,7 +222,8 @@ TEST(TransformationTests, ConvertMatMulTest7) {
|
||||
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFCorGemm>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
|
||||
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
|
||||
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
|
||||
|
||||
auto callback = [](const std::shared_ptr<const ngraph::Node> & node) -> bool {
|
||||
|
Loading…
Reference in New Issue
Block a user