Update Transformations For Dynamic cases (#9803)
* Dynamic support
* Update transformations
parent 68ef22a53d
commit 61762fbaf0
@@ -5,6 +5,7 @@
 #include "itt.hpp"
 #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp"

+#include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset5.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>

@@ -13,13 +14,26 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::BroadcastElementwiseFusion, "BroadcastEleme
 namespace {

 bool can_eliminate_broadcast(const ngraph::Output<ngraph::Node>& eltwise,
-                             const ngraph::PartialShape & input_shape,
-                             const ngraph::PartialShape & broadcast_shape) {
+                             const ngraph::Output<ngraph::Node>& eltwise_input,
+                             const ngraph::Output<ngraph::Node>& broadcast) {
     auto b = std::dynamic_pointer_cast<ngraph::op::util::BinaryElementwiseArithmetic>(eltwise.get_node_shared_ptr());
     if (!b || b->get_autob() == ngraph::op::AutoBroadcastType::NONE) {
         return false;
     }

+    // Check that eltwise_input is the same input that feeds the ShapeOf which supplies
+    // the Broadcast output shape target. In that case the Broadcast can be eliminated,
+    // since eltwise_input will broadcast the other eltwise input automatically.
+    auto broadcast_input = broadcast.get_node()->get_input_node_shared_ptr(1);
+    if ((ov::is_type<ngraph::opset5::ShapeOf>(broadcast_input) ||
+         ov::is_type<ngraph::opset1::ShapeOf>(broadcast_input)) &&
+        broadcast_input->input_value(0) == eltwise_input) {
+        return true;
+    }
+
+    const auto & input_shape = eltwise_input.get_partial_shape();
+    const auto & broadcast_shape = broadcast.get_partial_shape();
+
     if (input_shape.rank().is_dynamic() || broadcast_shape.rank().is_dynamic()) {
         return false;
     }
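A note on the comparison in the new early return: ngraph's Output<Node>::operator== compares the producing node and output port, so the check only fires when the Broadcast target shape is ShapeOf of the very same output that feeds the eltwise. A minimal standalone sketch (illustrative only, not part of this commit):

#include <iostream>
#include <memory>
#include <ngraph/opsets/opset5.hpp>

int main() {
    auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32,
                                                             ngraph::PartialShape{1, 3});
    auto shape_of = std::make_shared<ngraph::opset5::ShapeOf>(input);
    // Same producer and same output port => Output<Node> equality holds.
    std::cout << (shape_of->input_value(0) == input->output(0)) << std::endl;  // prints 1
    return 0;
}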
@@ -71,8 +85,7 @@ ngraph::pass::BroadcastElementwiseFusion::BroadcastElementwiseFusion() {
     const auto & m_broadcast_input = pattern_value.at(broadcast_input);
     auto & m_broadcast = pattern_value.at(broadcast);

-    if (!can_eliminate_broadcast(m_eltwise, m_eltwise_input.get_partial_shape(),
-                                 m_broadcast.get_partial_shape())) {
+    if (!can_eliminate_broadcast(m_eltwise, m_eltwise_input, m_broadcast)) {
         return false;
     }
@@ -20,8 +20,8 @@ static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>&
                                                    const op::AutoBroadcastSpec& autob) {
     auto const_shape = mul_const->get_shape();
     auto const_rank = static_cast<int64_t>(const_shape.size());
-    const auto& weights_shape = weights.get_shape();
-    int64_t weights_rank = static_cast<int64_t>(weights_shape.size());
+    const auto& weights_shape = weights.get_partial_shape();
+    int64_t weights_rank = static_cast<int64_t>(weights_shape.rank().get_length());

     // Fuse if const is a scalar
     if (ngraph::is_scalar(const_shape)) {
@@ -61,10 +61,12 @@ static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>&
     if (const_shape.back() > 1) {
         // Check if const's last dimension matches last weights dimension
         if (matmul_casted->get_transpose_b()) {
-            if (weights_rank > 1 && const_shape.back() != weights_shape[weights_rank - 2]) {
+            if (weights_shape[weights_rank - 2].is_dynamic() ||
+                (weights_rank > 1 && const_shape.back() != static_cast<size_t>(weights_shape[weights_rank - 2].get_length()))) {
                 return nullptr;
             }
-        } else if (const_shape.back() != weights_shape.back()) {
+        } else if (weights_shape[weights_rank - 1].is_dynamic() ||
+                   const_shape.back() != static_cast<size_t>(weights_shape[weights_rank - 1].get_length())) {
             return nullptr;
         }
     }
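For readers new to the dynamic-shape API these hunks move to: a Dimension must be tested with is_dynamic() before get_length() may be called on it, which is exactly what the per-dimension guards above do. A minimal standalone sketch (illustrative only, not part of this commit):

#include <iostream>
#include <ngraph/partial_shape.hpp>

int main() {
    // Static rank (3), but the middle dimension is unknown until runtime.
    ngraph::PartialShape weights{2, ngraph::Dimension::dynamic(), 8};
    for (int64_t i = 0; i < weights.rank().get_length(); ++i) {
        if (weights[i].is_dynamic())
            std::cout << "dim " << i << ": dynamic" << std::endl;
        else
            std::cout << "dim " << i << ": " << weights[i].get_length() << std::endl;
    }
    return 0;
}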
@@ -139,7 +141,7 @@ static std::shared_ptr<Node> fuse_const_to_weights(const std::shared_ptr<Node>&
 pass::MatMulMultiplyFusion::MatMulMultiplyFusion() {
     MATCHER_SCOPE(MatMulMultiplyFusion);
     auto input_pattern = pattern::any_input();
-    auto weights_pattern = pattern::any_input(pattern::has_static_shape());
+    auto weights_pattern = pattern::any_input(pattern::has_static_rank());
     auto mul_const_pattern = pattern::wrap_type<opset8::Constant>();
     auto matmul_pattern = pattern::wrap_type<opset8::MatMul>({input_pattern, weights_pattern});
     auto mul_pattern = pattern::wrap_type<opset8::Multiply>({matmul_pattern, mul_const_pattern});
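The switch from has_static_shape() to has_static_rank() is what lets the matcher admit weights with dynamic dimensions at all; the per-dimension guards added above then cope with the unknown lengths. A quick sketch of the difference between the two predicates (illustrative only):

#include <memory>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/pattern.hpp>

int main() {
    // Rank is known (2), but one dimension is dynamic.
    auto weights = std::make_shared<ngraph::opset8::Parameter>(
        ngraph::element::f32, ngraph::PartialShape{ngraph::Dimension::dynamic(), 8});
    bool static_shape = ngraph::pattern::has_static_shape()(weights->output(0));  // false
    bool static_rank = ngraph::pattern::has_static_rank()(weights->output(0));    // true
    return static_rank && !static_shape ? 0 : 1;
}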
@@ -58,11 +58,10 @@ bool isConvertableToPowerStatic(const std::shared_ptr<BaseOp> &node) {
 template <>
 bool isConvertableToPowerStatic(const std::shared_ptr<ngraph::opset1::Power> &node) {
     auto input_rank = node->get_input_partial_shape(0).rank();
-    auto const_shape = node->get_input_shape(1);
     if (input_rank.is_dynamic())
         return false;
-    return std::dynamic_pointer_cast<ngraph::opset1::Constant>(node->get_input_node_shared_ptr(1)) != nullptr &&
-           input_rank.get_length() >= const_shape.size() && ngraph::shape_size(const_shape) == 1;
+    auto const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(node->get_input_node_shared_ptr(1));
+    return const_node && input_rank.get_length() >= const_node->get_shape().size() && ngraph::shape_size(const_node->get_shape()) == 1;
 }

 template <class BaseOp>
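The reordering in this hunk is not just cosmetic: the old code called node->get_input_shape(1) before knowing whether that input even had a static shape, and get_input_shape() throws when the shape is dynamic. Reading the shape from the Constant after the cast succeeds avoids that, since a Constant's shape is static by construction. A hedged sketch of the safe ordering (the function name is mine, not from the commit):

#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>

// Query the exponent shape only after the Constant cast succeeds; a Constant's
// get_shape() cannot throw, unlike Node::get_input_shape() on a dynamic input.
bool exponent_is_scalar_constant(const std::shared_ptr<ngraph::Node>& power) {
    auto const_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(
        power->get_input_node_shared_ptr(1));
    return const_node && ngraph::shape_size(const_node->get_shape()) == 1;
}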
@@ -13,7 +13,9 @@
 NGRAPH_RTTI_DEFINITION(MKLDNNPlugin::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);

 MKLDNNPlugin::ReshapeFullyConnectedFusion::ReshapeFullyConnectedFusion() {
-    auto m_reshape = ngraph::pattern::wrap_type<ngraph::opset1::Reshape>(ngraph::pattern::has_static_shape());
+    auto m_reshape = ngraph::pattern::wrap_type<ngraph::opset1::Reshape>({ngraph::pattern::any_input(ov::pass::pattern::has_static_shape()),
+                                                                          ngraph::pattern::any_input()},
+                                                                         ngraph::pattern::has_static_shape());
     ngraph::OutputVector twoInputs = {m_reshape, ngraph::pattern::any_input()};
     ngraph::OutputVector threeInputs = {m_reshape, ngraph::pattern::any_input(), ngraph::pattern::any_input()};
     auto fcTwoInputs = ngraph::pattern::wrap_type<MKLDNNPlugin::FullyConnectedNode>(twoInputs, ngraph::pattern::has_static_shape());
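wrap_type takes an optional list of input sub-patterns in addition to the predicate on the node's own output, which is how the rewritten pattern pins the Reshape's data input to a static shape while leaving the target-shape input unconstrained. Minimal sketch (illustrative only, not part of this commit):

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/pattern.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>

int main() {
    // Data input must be static; the second (target shape) input may be anything.
    auto data = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto target_shape = ngraph::pattern::any_input();
    auto reshape = ngraph::pattern::wrap_type<ngraph::opset1::Reshape>(
        {data, target_shape}, ngraph::pattern::has_static_shape());
    return reshape != nullptr ? 0 : 1;
}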
@@ -267,3 +267,22 @@ INSTANTIATE_TEST_SUITE_P(EliminateDynamicBroadcast, EliminateDynamicBroadcastTes
 INSTANTIATE_TEST_SUITE_P(NoEliminateDynamicBroadcast, NoEliminateDynamicBroadcastTest,
                          testing::Values(std::make_tuple(InputShape{2, 1, 4}, InputShape{2, DYN, 4}, InputShape{2, DYN, 4}),
                                          std::make_tuple(InputShape{2, DYN, 4}, InputShape{2, DYN, 4}, InputShape{2, DYN, 4})));
+
+TEST_F(TransformationTestsF, BroadcastElementwiseFusionWithShapeOf) {
+    {
+        auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, Shape{1, 3});
+        auto shape_of = std::make_shared<ngraph::opset5::ShapeOf>(input);
+        auto broadcast = std::make_shared<ngraph::opset5::Broadcast>(input, shape_of);
+        auto elementwise = std::make_shared<ngraph::opset5::Multiply>(input, broadcast);
+        function = std::make_shared<ngraph::Function>(ngraph::NodeVector{elementwise}, ngraph::ParameterVector{input});
+
+        manager.register_pass<pass::BroadcastElementwiseFusion>();
+    }
+
+    {
+        auto input = std::make_shared<ngraph::opset5::Parameter>(ngraph::element::f32, Shape{1, 3});
+        auto elementwise = std::make_shared<ngraph::opset5::Multiply>(input, input);
+        function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{elementwise}, ngraph::ParameterVector{input});
+    }
+}
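For context on the new test (a reading aid, not part of the diff): in the TransformationTestsF fixture, the passes registered on manager run against function, and the result is compared with function_ref on teardown, so this test asserts that a Broadcast whose target shape is ShapeOf of the same tensor feeding the Multiply collapses away, leaving Multiply(input, input).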