Enables SharedOpOptimization for v1 and v3 Broadcasts (#21635)
This commit is contained in:
parent
0649865372
commit
e5a6077877
@ -6,6 +6,7 @@
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
@ -163,6 +164,24 @@ bool converts_are_equal(const Node* lhs, const Node* rhs) {
|
||||
inputs_from_same_source_or_equal_constants(lhs, rhs);
|
||||
}
|
||||
|
||||
bool broadcasts_1_are_equal(const Node* lhs, const Node* rhs) {
|
||||
const auto l_broadcast = as_type<const v1::Broadcast>(lhs);
|
||||
const auto r_broadcast = as_type<const v1::Broadcast>(rhs);
|
||||
if (!l_broadcast || !r_broadcast)
|
||||
return false;
|
||||
return l_broadcast->get_broadcast_spec() == r_broadcast->get_broadcast_spec() &&
|
||||
inputs_from_same_source_or_equal_constants(lhs, rhs);
|
||||
}
|
||||
|
||||
bool broadcasts_3_are_equal(const Node* lhs, const Node* rhs) {
|
||||
const auto l_broadcast = as_type<const v3::Broadcast>(lhs);
|
||||
const auto r_broadcast = as_type<const v3::Broadcast>(rhs);
|
||||
if (!l_broadcast || !r_broadcast)
|
||||
return false;
|
||||
return l_broadcast->get_broadcast_spec() == r_broadcast->get_broadcast_spec() &&
|
||||
inputs_from_same_source_or_equal_constants(lhs, rhs);
|
||||
}
|
||||
|
||||
bool shape_of_upgrade(const shared_ptr<Model>& model) {
|
||||
bool rewritten = false;
|
||||
for (const auto& op : model->get_ordered_ops()) {
|
||||
@ -199,6 +218,8 @@ bool pass::SharedOpOptimization::run_on_model(const shared_ptr<Model>& model) {
|
||||
RECORD_NO_ATTRIBUTES(v3::ScatterUpdate),
|
||||
|
||||
// with attributes
|
||||
RECORD(v1::Broadcast, broadcasts_1_are_equal),
|
||||
RECORD(v3::Broadcast, broadcasts_3_are_equal),
|
||||
RECORD(v0::Concat, concats_are_equal),
|
||||
RECORD(v0::Convert, converts_are_equal),
|
||||
RECORD(v1::Gather, gathers_are_equal),
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
@ -339,6 +340,48 @@ TEST_F(SharedTransformationTestsF, SharedShapeOfTestI64Only) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SharedTransformationTestsF, Sharedv1Broadcasts) {
|
||||
{
|
||||
auto input = std::make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto target_shape = std::make_shared<v0::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v1_0 = std::make_shared<v1::Broadcast>(input, target_shape);
|
||||
auto broadcast_v1_1 = std::make_shared<v1::Broadcast>(input, target_shape, AutoBroadcastType::PDPD);
|
||||
auto broadcast_v1_2 = std::make_shared<v1::Broadcast>(input, target_shape);
|
||||
auto concat = std::make_shared<v0::Concat>(OutputVector{broadcast_v1_0, broadcast_v1_1, broadcast_v1_2}, 0);
|
||||
model = std::make_shared<Model>(NodeVector{concat}, ParameterVector{input, target_shape});
|
||||
manager.register_pass<pass::SharedOpOptimization>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto target_shape = std::make_shared<v0::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v1_0 = std::make_shared<v1::Broadcast>(input, target_shape);
|
||||
auto broadcast_v1_1 = std::make_shared<v1::Broadcast>(input, target_shape, AutoBroadcastType::PDPD);
|
||||
auto concat = std::make_shared<v0::Concat>(OutputVector{broadcast_v1_0, broadcast_v1_1, broadcast_v1_0}, 0);
|
||||
model_ref = std::make_shared<Model>(NodeVector{concat}, ParameterVector{input, target_shape});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SharedTransformationTestsF, Sharedv3Broadcasts) {
|
||||
{
|
||||
auto input = std::make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto target_shape = std::make_shared<v0::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v1_0 = std::make_shared<v3::Broadcast>(input, target_shape);
|
||||
auto broadcast_v1_1 = std::make_shared<v3::Broadcast>(input, target_shape, BroadcastType::BIDIRECTIONAL);
|
||||
auto broadcast_v1_2 = std::make_shared<v3::Broadcast>(input, target_shape);
|
||||
auto concat = std::make_shared<v0::Concat>(OutputVector{broadcast_v1_0, broadcast_v1_1, broadcast_v1_2}, 0);
|
||||
model = std::make_shared<Model>(NodeVector{concat}, ParameterVector{input, target_shape});
|
||||
manager.register_pass<pass::SharedOpOptimization>();
|
||||
}
|
||||
{
|
||||
auto input = std::make_shared<v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
auto target_shape = std::make_shared<v0::Parameter>(element::i64, PartialShape::dynamic());
|
||||
auto broadcast_v1_0 = std::make_shared<v3::Broadcast>(input, target_shape);
|
||||
auto broadcast_v1_1 = std::make_shared<v3::Broadcast>(input, target_shape, BroadcastType::BIDIRECTIONAL);
|
||||
auto concat = std::make_shared<v0::Concat>(OutputVector{broadcast_v1_0, broadcast_v1_1, broadcast_v1_0}, 0);
|
||||
model_ref = std::make_shared<Model>(NodeVector{concat}, ParameterVector{input, target_shape});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SharedTransformationTestsF, SharedShapeOfTestI32Only) {
|
||||
Shape input_shape{120, 4};
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user