Update Transformations For Dynamic cases (#9803)

* Dynamic support

* Update transformations
Gleb Kazantaev 2022-01-20 19:48:29 +03:00 committed by GitHub
parent 68ef22a53d
commit 61762fbaf0
5 changed files with 48 additions and 13 deletions

View File

@@ -5,6 +5,7 @@
 #include "itt.hpp"
 #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"
 
+#include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset5.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>
@@ -13,13 +14,26 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::BroadcastElementwiseFusion, "BroadcastEleme
 namespace {
 
 bool can_eliminate_broadcast(const ngraph::Output<ngraph::Node>& eltwise,
-                             const ngraph::PartialShape & input_shape,
-                             const ngraph::PartialShape & broadcast_shape) {
+                             const ngraph::Output<ngraph::Node>& eltwise_input,
+                             const ngraph::Output<ngraph::Node>& broadcast) {
     auto b = std::dynamic_pointer_cast<ngraph::op::util::BinaryElementwiseArithmetic>(eltwise.get_node_shared_ptr());
     if (!b || b->get_autob() == ngraph::op::AutoBroadcastType::NONE) {
         return false;
     }
 
+    // Check that eltwise_input is the same input that feeds the ShapeOf which
+    // provides the Broadcast output shape target. In that case the Broadcast can be
+    // eliminated, since eltwise_input will broadcast the other eltwise input automatically.
+    auto broadcast_input = broadcast.get_node()->get_input_node_shared_ptr(1);
+    if ((ov::is_type<ngraph::opset5::ShapeOf>(broadcast_input) ||
+         ov::is_type<ngraph::opset1::ShapeOf>(broadcast_input)) &&
+        broadcast_input->input_value(0) == eltwise_input) {
+        return true;
+    }
+
+    const auto & input_shape = eltwise_input.get_partial_shape();
+    const auto & broadcast_shape = broadcast.get_partial_shape();
     if (input_shape.rank().is_dynamic() || broadcast_shape.rank().is_dynamic()) {
         return false;
     }
@@ -71,8 +85,7 @@ ngraph::pass::BroadcastElementwiseFusion::BroadcastElementwiseFusion() {
     const auto & m_broadcast_input = pattern_value.at(broadcast_input);
     auto & m_broadcast = pattern_value.at(broadcast);
 
-    if (!can_eliminate_broadcast(m_eltwise, m_eltwise_input.get_partial_shape(),
-                                 m_broadcast.get_partial_shape())) {
+    if (!can_eliminate_broadcast(m_eltwise, m_eltwise_input, m_broadcast)) {
         return false;
     }
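
For context, a minimal sketch (not part of the diff; standalone pass-manager usage assumed) of the case the new ShapeOf branch targets. With a fully dynamic input the old rank/shape comparisons bail out, but the new branch still proves the Broadcast redundant, because its target shape is literally ShapeOf of the other eltwise input:

    // Illustrative sketch, not from the commit.
    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset5.hpp>
    #include <ngraph/pass/manager.hpp>
    #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"

    void example() {
        auto input = std::make_shared<ngraph::opset5::Parameter>(
            ngraph::element::f32, ngraph::PartialShape::dynamic());
        auto shape_of = std::make_shared<ngraph::opset5::ShapeOf>(input);
        // The Broadcast target shape comes from ShapeOf(input), so the result
        // always has the shape of input itself, whatever that shape is at runtime.
        auto broadcast = std::make_shared<ngraph::opset5::Broadcast>(input, shape_of);
        auto mul = std::make_shared<ngraph::opset5::Multiply>(input, broadcast);
        auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{mul},
                                                    ngraph::ParameterVector{input});

        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::BroadcastElementwiseFusion>();
        manager.run_passes(f);  // expected result: Multiply(input, input), Broadcast gone
    }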

View File

@@ -20,8 +20,8 @@ static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>&
                                                    const op::AutoBroadcastSpec& autob) {
     auto const_shape = mul_const->get_shape();
     auto const_rank = static_cast<int64_t>(const_shape.size());
-    const auto& weights_shape = weights.get_shape();
-    int64_t weights_rank = static_cast<int64_t>(weights_shape.size());
+    const auto& weights_shape = weights.get_partial_shape();
+    int64_t weights_rank = static_cast<int64_t>(weights_shape.rank().get_length());
 
     // Fuse if const is a scalar
     if (ngraph::is_scalar(const_shape)) {
@@ -61,10 +61,12 @@ static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>&
     if (const_shape.back() > 1) {
         // Check if const's last dimension matches last weights dimension
         if (matmul_casted->get_transpose_b()) {
-            if (weights_rank > 1 && const_shape.back() != weights_shape[weights_rank - 2]) {
+            if (weights_shape[weights_rank - 2].is_dynamic() ||
+                (weights_rank > 1 && const_shape.back() != static_cast<size_t>(weights_shape[weights_rank - 2].get_length()))) {
                 return nullptr;
             }
-        } else if (const_shape.back() != weights_shape.back()) {
+        } else if (weights_shape[weights_rank - 1].is_dynamic() ||
+                   const_shape.back() != static_cast<size_t>(weights_shape[weights_rank - 1].get_length())) {
             return nullptr;
         }
     }
@@ -139,7 +141,7 @@ static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>&
 pass::MatMulMultiplyFusion::MatMulMultiplyFusion() {
     MATCHER_SCOPE(MatMulMultiplyFusion);
     auto input_pattern = pattern::any_input();
-    auto weights_pattern = pattern::any_input(pattern::has_static_shape());
+    auto weights_pattern = pattern::any_input(pattern::has_static_rank());
     auto mul_const_pattern = pattern::wrap_type<opset8::Constant>();
     auto matmul_pattern = pattern::wrap_type<opset8::MatMul>({input_pattern, weights_pattern});
     auto mul_pattern = pattern::wrap_type<opset8::Multiply>({matmul_pattern, mul_const_pattern});
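
Read together, the two rewritten branches express a single rule: since the pattern now only guarantees a static rank, the one weights dimension that is compared against the const's last dimension must itself be static and equal to it. A hypothetical helper (names invented, not in the commit) restating that rule:

    // Restatement of the updated check; assumes weights_rank >= 2, which the
    // caller's transpose_b branch guards with its weights_rank > 1 test.
    #include <ngraph/ngraph.hpp>

    bool const_matches_weights_dim(const ngraph::PartialShape& weights_shape,
                                   int64_t weights_rank,
                                   const ngraph::Shape& const_shape,
                                   bool transpose_b) {
        // transpose_b swaps the two innermost weights dimensions.
        const auto& dim = transpose_b ? weights_shape[weights_rank - 2]
                                      : weights_shape[weights_rank - 1];
        return dim.is_static() &&
               const_shape.back() == static_cast<size_t>(dim.get_length());
    }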

View File

@@ -58,11 +58,10 @@ bool isConvertableToPowerStatic(const std::shared_ptr<BaseOp> &node) {
 template <>
 bool isConvertableToPowerStatic(const std::shared_ptr<ngraph::opset1::Power> &node) {
     auto input_rank = node->get_input_partial_shape(0).rank();
-    auto const_shape = node->get_input_shape(1);
     if (input_rank.is_dynamic())
         return false;
-    return std::dynamic_pointer_cast<ngraph::opset1::Constant>(node->get_input_node_shared_ptr(1)) != nullptr &&
-           input_rank.get_length() >= const_shape.size() && ngraph::shape_size(const_shape) == 1;
+    auto const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(node->get_input_node_shared_ptr(1));
+    return const_node && input_rank.get_length() >= const_node->get_shape().size() && ngraph::shape_size(const_node->get_shape()) == 1;
 }
 
 template <class BaseOp>
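
The point of the reordering: Node::get_input_shape() requires a static shape and throws on a dynamic one, so the old code could throw before ever reaching its early returns. The new code reads the shape from the Constant itself, which is always static. A minimal illustrative repro (not from the commit):

    #include <ngraph/opsets/opset1.hpp>

    void hazard_example() {
        using namespace ngraph;
        // A non-constant exponent with a dynamic shape.
        auto base = std::make_shared<opset1::Parameter>(element::f32, PartialShape::dynamic());
        auto exponent = std::make_shared<opset1::Parameter>(element::f32, PartialShape::dynamic());
        auto power = std::make_shared<opset1::Power>(base, exponent);

        // Old code: power->get_input_shape(1) throws here, because input 1 has
        // a dynamic shape. New code: the dynamic_pointer_cast to Constant fails
        // and the predicate returns false without ever touching the shape.
    }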

View File

@@ -13,7 +13,9 @@
 NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);
 
 MKLDNNPlugin::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() {
-    auto m_reshape = ngraph::pattern::wrap_type<ngraph::opset1::Reshape>(ngraph::pattern::has_static_shape());
+    auto m_reshape = ngraph::pattern::wrap_type<ngraph::opset1::Reshape>({ngraph::pattern::any_input(ov::pass::pattern::has_static_shape()),
+                                                                          ngraph::pattern::any_input()},
+                                                                         ngraph::pattern::has_static_shape());
     ngraph::OutputVector twoInputs = {m_reshape, ngraph::pattern::any_input()};
     ngraph::OutputVector threeInputs = {m_reshape, ngraph::pattern::any_input(), ngraph::pattern::any_input()};
     auto fcTwoInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(twoInputs, ngraph::pattern::has_static_shape());
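
The rewritten one-liner is dense, so here is an equivalent spelled-out form (an illustrative restatement with named sub-patterns, not the commit's code). Where the old pattern only constrained the Reshape's output shape, the new one also pins the Reshape's data input to a static shape, presumably because the fusion reads that input shape later:

    // Match Reshape(data, target_shape) where both the data input and the
    // Reshape output have static shapes.
    auto data = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto target_shape = ngraph::pattern::any_input();
    auto m_reshape = ngraph::pattern::wrap_type<ngraph::opset1::Reshape>(
        {data, target_shape}, ngraph::pattern::has_static_shape());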

View File

@@ -267,3 +267,22 @@ INSTANTIATE_TEST_SUITE_P(EliminateDynamicBroadcast, EliminateDynamicBroadcastTes
 INSTANTIATE_TEST_SUITE_P(NoEliminateDynamicBroadcast, NoEliminateDynamicBroadcastTest,
                          testing::Values(std::make_tuple(InputShape{2, 1, 4}, InputShape{2, DYN, 4}, InputShape{2, DYN, 4}),
                                          std::make_tuple(InputShape{2, DYN, 4}, InputShape{2, DYN, 4}, InputShape{2, DYN, 4})));
+
+TEST_F(TransformationTestsF, BroadcastElementwiseFusionWithShapeOf) {
+    {
+        auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, Shape{1, 3});
+        auto shape_of = std::make_shared<ngraph::opset5::ShapeOf>(input);
+        auto broadcast = std::make_shared<ngraph::opset5::Broadcast>(input, shape_of);
+        auto elementwise = std::make_shared<ngraph::opset5::Multiply>(input, broadcast);
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{elementwise}, ngraph::ParameterVector{input});
+
+        manager.register_pass<pass::BroadcastElementwiseFusion>();
+    }
+    {
+        auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, Shape{1, 3});
+        auto elementwise = std::make_shared<ngraph::opset5::Multiply>(input, input);
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{elementwise}, ngraph::ParameterVector{input});
+    }
+}
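
For context: the TransformationTestsF fixture runs the passes registered on `manager` against `function` and then compares the result with `function_ref`, so this test asserts exactly the ShapeOf-driven elimination added to can_eliminate_broadcast above: Multiply(input, Broadcast(input, ShapeOf(input))) collapses to Multiply(input, input).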