[GPU] Adjust dynamic eltw fusion to keep correct shape infer (#14198)

* [GPU] Adjust dynamic eltw fusion to keep correct shape infer

* [GPU] Fixed eltwise shape infer
This commit is contained in:
Vladimir Paramuzov
2022-11-28 09:08:19 +04:00
committed by GitHub
parent 03b677b10b
commit 7933cc7e0b
5 changed files with 103 additions and 67 deletions

View File

@@ -11,6 +11,9 @@
#include <vector>
#include <algorithm>
#include "openvino/op/add.hpp"
#include "utils.hpp"
namespace cldnn {
GPU_DEFINE_PRIMITIVE_TYPE_ID(eltwise)
@@ -111,70 +114,37 @@ std::vector<layout> eltwise_inst::calc_output_layouts(eltwise_node const& /*node
auto out_data_type = desc->output_data_type.value_or(input_layout.data_type);
auto get_output_layout = [&]() {
const auto& autob = desc->broadcast_spec;
auto out_pshape = input_layout.get<ShapeType>();
cldnn::format out_format = input_layout.format;
// We create dummy Add op as shape infer is exactly the same for any eltwise op type, so there is no need to have correct op type
ov::op::v1::Add op;
op.set_autob(desc->broadcast_spec);
std::vector<ShapeType> output_shapes = {ShapeType()};
std::vector<ShapeType> input_shapes;
for (size_t i = 0; i < desc->input_size(); i++) {
input_shapes.push_back(impl_param.get_input_layout(i).get<ShapeType>());
}
eltwise_shape_infer(&op, input_shapes, output_shapes);
if (input_layout.format == format::b_fs_zyx_fsv16) // use optimized 5D
out_format = format::b_fs_zyx_fsv16;
else if (input_layout.format == format::bs_fs_zyx_bsv16_fsv16)
out_format = format::bs_fs_zyx_bsv16_fsv16;
for (size_t i = 0; i < impl_param.input_layouts.size(); i++) {
for (size_t i = 0; i < desc->input_size(); i++) {
if (impl_param.primary_input_idx == i)
continue;
auto l = impl_param.get_non_padded_input_layout(i);
auto in_pshape = l.get<ShapeType>();
if (autob.m_type == ov::op::AutoBroadcastType::NONE) {
OPENVINO_ASSERT(ShapeType::merge_into(out_pshape, in_pshape), desc->id + ": Argument shapes are inconsistent.\n");
} else if (autob.m_type == ov::op::AutoBroadcastType::NUMPY || autob.m_type == ov::op::AutoBroadcastType::PDPD) {
auto origin_out_pshape = out_pshape;
// For out_pshape{2,3,15,1} and in_pshape{1,3},
// expected output shape for NUMPY should be out_pshape{2,3,15,1} but the actual output will be {2,3,15,3}
// So, fill the rank with default dim(1) for shape which has smaller rank.
if (autob.m_type == ov::op::AutoBroadcastType::NUMPY
&& out_pshape.rank().is_static() && in_pshape.rank().is_static()
&& out_pshape.rank() != in_pshape.rank()) {
ov::Dimension default_dim(1);
const auto in_pshape_rank = in_pshape.rank().get_length();
const auto out_pshape_rank = out_pshape.rank().get_length();
auto new_rank = std::max(in_pshape_rank, out_pshape_rank);
for (auto i = in_pshape_rank; i < new_rank; i++) {
in_pshape.push_back(default_dim);
}
for (auto i = out_pshape_rank; i < new_rank; i++) {
out_pshape.push_back(default_dim);
}
}
if (!ShapeType::broadcast_merge_into(out_pshape, in_pshape, autob)) {
// Temporary workaround: derive the output shape from the max value of each dimension so that some legacy functional tests pass.
// IE_THROW() << desc->id << ": incorrect input shapes (" << out_pshape << " & " << in_pshape << ")\n" << str_endline;
out_pshape = origin_out_pshape;
if (out_pshape.is_static() && in_pshape.is_static()) {
auto in_shape = in_pshape.to_shape();
auto out_shape = out_pshape.to_shape();
for (size_t i = 0; i < in_shape.size(); i++) {
out_shape[i] = std::max(out_shape[i], in_shape[i]);
}
out_pshape = ShapeType(out_shape);
} else {
if (in_pshape.rank().is_static()) {
out_pshape = ShapeType::dynamic(in_pshape.rank());
}
}
}
} else {
OPENVINO_ASSERT(false, desc->id + ": Unsupported auto broadcast specification\n");
}
if (l.format == format::b_fs_zyx_fsv16) // use optimized 5D
out_format = format::b_fs_zyx_fsv16;
else if (l.format == format::bs_fs_zyx_bsv16_fsv16)
out_format = format::bs_fs_zyx_bsv16_fsv16;
}
return layout(out_pshape, out_data_type, out_format);
return layout(output_shapes[0], out_data_type, out_format);
};
auto output_layout = get_output_layout();

View File

@@ -998,6 +998,36 @@ void prepare_primitive_fusing::fuse_simple_primitives(program &p) {
can_fuse_parents[1] = false;
}
}
} else {
// In case of dynamic shapes we check that parent & peer shapes are compatible to allow merge
// This is required to avoid an issue when shape is partially defined and incorrectly propagated to further nodes
// which may ruin shape inference
// E.g. parent1 [?,?,768], parent2 [?,?,1]
// expected eltw out shape: [?,?,768]
// but w/o this check we can fuse eltwise to parent2 and return [?,?,1] as output shape which is unexpected
auto parent1_pshape = parent1->get_output_layout().get_partial_shape();
auto parent2_pshape = parent2->get_output_layout().get_partial_shape();
auto out_pshape = node.get_output_layout().get_partial_shape();
// Decides whether a parent with shape `in_shape` may absorb the eltwise whose output shape is `out_shape`.
// Rules, applied per dimension:
//   * ranks must both be static and equal;
//   * if both dims are static they must be equal;
//   * if only the input dim is static it must not be 1 (a static 1 against an unknown output dim
//     may be a broadcast axis, so fusing there could propagate a wrong shape);
//   * a dynamic input dim imposes no constraint.
auto are_compatible = [](const ov::PartialShape& out_shape, const ov::PartialShape& in_shape) -> bool {
    // get_length() is only meaningful for a static rank, so refuse fusion for rank-dynamic shapes
    // instead of querying their length.
    if (out_shape.rank().is_dynamic() || in_shape.rank().is_dynamic())
        return false;
    if (out_shape.rank().get_length() != in_shape.rank().get_length())
        return false;

    for (size_t i = 0; i < out_shape.size(); i++) {
        const auto& od = out_shape[i];
        const auto& id = in_shape[i];
        if (!id.is_static())
            continue;

        if (od.is_static()) {
            if (od.get_length() != id.get_length())
                return false;
        } else if (id.get_length() == 1) {
            return false;
        }
    }
    return true;
};
can_fuse_parents[0] = can_fuse_parents[0] && are_compatible(out_pshape, parent1_pshape);
can_fuse_parents[1] = can_fuse_parents[1] && are_compatible(out_pshape, parent2_pshape);
}
// We should have at least one node to fuse

View File

@@ -6,10 +6,11 @@
#include "intel_gpu/runtime/engine.hpp"
#include "intel_gpu/graph/network.hpp"
#include "intel_gpu/graph/program.hpp"
#include "data_inst.h"
#include "eltwise_inst.h"
#include "intel_gpu/graph/network.hpp"
#include "reduce_inst.h"
#include "pass_manager.h"
#include "to_string_utils.h"
@@ -44,6 +45,32 @@ TEST(prepare_primitive_fusing, fuse_activation_to_fc_dyn) {
ASSERT_FALSE(has_node_with_type<activation>(*prog));
}
// Verifies that prepare_primitive_fusing leaves the "eltw" primitive in the program for this topology:
//   input [?, ?, 10] + const [1, 1, 1] -> eltw_pre -> reduce(max, {2}, true) -> eltw(input, reduce) -> reorder
// NOTE(review): the reduce arguments `{2}, true` look like "axis 2, keep dims", which would make the
// reduce output [?, ?, 1] and thus incompatible with the eltw output - confirm against the reduce primitive.
TEST(prepare_primitive_fusing, dont_fuse_incompatible_eltwise) {
    auto& engine = get_test_engine();
    auto in_layout = layout{ ov::PartialShape{-1, -1, 10}, data_types::f32, format::bfyx };
    auto const_layout = layout{ ov::PartialShape{1, 1, 1}, data_types::f32, format::bfyx };

    auto const_mem = engine.allocate_memory(const_layout);

    topology topology;
    topology.add(input_layout("input", in_layout));
    topology.add(data("const", const_mem));
    topology.add(eltwise("eltw_pre", {"input", "const"}, eltwise_mode::sum));
    topology.add(reduce("reduce", "eltw_pre", reduce_mode::max, {2}, true));
    topology.add(eltwise("eltw", {"input", "reduce"}, eltwise_mode::sum));
    topology.add(reorder("reorder", "eltw", format::bfyx, data_types::f32));

    build_options build_opts;
    build_opts.set_option(build_option::allow_new_shape_infer(true));
    auto prog = program::build_program(engine, topology, build_opts, false, true);
    // Check the program handle before it is dereferenced by the fusing pass below.
    ASSERT_NE(prog, nullptr);

    layout_optimizer lo(true);
    program_wrapper::apply_opt_pass<prepare_primitive_fusing>(*prog, lo);

    // The eltwise must survive the pass as a standalone node, i.e. it was not fused into a parent.
    ASSERT_TRUE(has_node(*prog, "eltw"));
}
TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_legal) {
auto& engine = get_test_engine();
auto weights = engine.allocate_memory({ ov::PartialShape{ 16, 20 }, data_types::u8, format::bfyx });
@@ -86,11 +113,11 @@ TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_legal) {
TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal) {
auto& engine = get_test_engine();
auto weights = engine.allocate_memory({ ov::PartialShape{ 1, 10 }, data_types::u8, format::bfyx });
auto weights = engine.allocate_memory({ ov::PartialShape{ 2, 10 }, data_types::u8, format::bfyx });
auto in_layout = layout{ ov::PartialShape::dynamic(2), data_types::u8, format::bfyx };
auto in_eltw_layout = layout{ ov::PartialShape::dynamic(2), data_types::f32, format::bfyx };
set_values<uint8_t>(weights, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
set_values<uint8_t>(weights, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
topology topology;
topology.add(data("weights", weights));
@@ -140,11 +167,11 @@ TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal) {
TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal_const) {
auto& engine = get_test_engine();
auto weights = engine.allocate_memory({ ov::PartialShape{ 1, 10 }, data_types::u8, format::bfyx });
auto weights = engine.allocate_memory({ ov::PartialShape{ 2, 10 }, data_types::u8, format::bfyx });
auto in_layout = layout{ ov::PartialShape::dynamic(2), data_types::u8, format::bfyx };
auto in_eltw_layout = layout{ ov::PartialShape{2, 2}, data_types::f32, format::bfyx };
set_values<uint8_t>(weights, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
set_values<uint8_t>(weights, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
auto extra_input_memory = engine.allocate_memory(in_eltw_layout);
set_values<float>(extra_input_memory, {10, 20, 30, 40});
@@ -245,11 +272,11 @@ TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_legal_scalar_const_broadca
TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal_1) {
auto& engine = get_test_engine();
auto weights = engine.allocate_memory({ ov::PartialShape{ 1, 10 }, data_types::u8, format::bfyx });
auto weights = engine.allocate_memory({ ov::PartialShape{ 2, 10 }, data_types::u8, format::bfyx });
auto in_layout = layout{ ov::PartialShape::dynamic(2), data_types::u8, format::bfyx };
auto in_eltw_layout = layout{ ov::PartialShape::dynamic(2), data_types::f32, format::bfyx };
set_values<uint8_t>(weights, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
set_values<uint8_t>(weights, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
// The topology below is intended to check the following tricky things:
// 1. Cases where original eltw input is also optimized (act_e2 is fused into act_e1)
@@ -308,12 +335,12 @@ TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal_1) {
TEST(prepare_primitive_fusing, fuse_eltwise_to_fc_dyn_illegal_2) {
auto& engine = get_test_engine();
auto weights0 = engine.allocate_memory({ ov::PartialShape{ 2, 10 }, data_types::i8, format::bfyx });
auto weights1 = engine.allocate_memory({ ov::PartialShape{ 1, 2 }, data_types::i8, format::bfyx });
auto weights1 = engine.allocate_memory({ ov::PartialShape{ 4, 2 }, data_types::i8, format::bfyx });
auto in_layout = layout{ ov::PartialShape::dynamic(2), data_types::i8, format::bfyx };
auto in_eltw_layout = layout{ ov::PartialShape::dynamic(2), data_types::f32, format::bfyx };
set_values<uint8_t>(weights0, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
set_values<uint8_t>(weights1, {1, 1});
set_values<uint8_t>(weights1, {1, 1, 1, 1, 1, 1, 1, 1});
// The topology below is intended to check the following tricky things:

View File

@@ -83,31 +83,31 @@ INSTANTIATE_TEST_SUITE_P(smoke, eltwise_si_test,
testing::ValuesIn(std::vector<eltwise_test_params>{
{{{2, 1, 5}, data_types::f32, format::bfyx}, {{2, 1, 5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NONE}, {{2, 1, 5}, data_types::f32, format::bfyx}, {}},
{{{2, 1, 5}, data_types::f32, format::bfyx}, {{1, 4, 1}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{2, 4, 5}, data_types::f32, format::bfyx}, {}},
{{{1, 1, 5}, data_types::f32, format::bfyx}, {{5, 2, 1, 3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{5, 2, 5, 3}, data_types::f32, format::bfyx}, {}},
{{{1, 5, 1}, data_types::f32, format::bfyx}, {{5, 2, 1, 3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{5, 2, 5, 3}, data_types::f32, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{4, 5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD, -1}, {{2, 3, 4, 5}, data_types::f32, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{1, 3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD}, {{2, 3, 4, 5}, data_types::f32, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{2, 3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD}, {{2, 3, 4, 5}, data_types::f32, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD, 1}, {{2, 3, 4, 5}, data_types::f32, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{3, 3, 4, 5}, data_types::f32, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{2, 3, 4, 5}, data_types::f32, format::bfyx}, {}},
// test for dynamic shape
{{{1, 1, 5}, data_types::f32, format::bfyx}, {{5, 2, 1, 3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{5, 2, 5, 3}, data_types::f32, format::bfyx}, {}},
{{PartialShape::dynamic(3), data_types::f32, format::bfyx}, {{2, 3, 4, 5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD}, {PartialShape::dynamic(4), data_types::f32, format::bfyx}, {}},
{{{1, 5, 1}, data_types::f32, format::bfyx}, {{5, 2, 1, 3}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{5, 2, 5, 3}, data_types::f32, format::bfyx}, {}},
{{{2, -1, 5}, data_types::f32, format::bfyx}, {{1, 4, 1}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{2, 4, 5}, data_types::f32, format::bfyx}, {}},
{{PartialShape::dynamic(3), data_types::f32, format::bfyx}, {{1, 4, 1}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{-1, 4, -1}, data_types::f32, format::bfyx}, {}},
{{PartialShape::dynamic(3), data_types::f32, format::bfyx}, {{2, 1, 5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{2, -1, 5}, data_types::f32, format::bfyx}, {}},
{{PartialShape::dynamic(3), data_types::f32, format::bfyx}, {{1, 4, 1}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD}, {PartialShape::dynamic(3), data_types::f32, format::bfyx}, {}},
{{{-1, -1, 1024, 512}, data_types::f32, format::bfyx}, {{1, 1, 512, 1}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {ov::PartialShape::dynamic(4), data_types::f32, format::bfyx}, {}},
{{{-1, -1, 1024, 512}, data_types::f32, format::bfyx}, {{1, 1, 512}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{-1,-1,1024,512}, data_types::f32, format::bfyx}, {}},
{{{-1, -1, 768}, data_types::f32, format::bfyx}, {{768}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{-1,-1,768}, data_types::f32, format::bfyx}, {}},
// test for output data type of logic and comparison operations
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{3}, data_types::f32, format::bfyx}, eltwise_mode::eq, {AutoBroadcastType::NUMPY}, {{3, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f16, format::bfyx}, {{3}, data_types::f16, format::bfyx}, eltwise_mode::ne, {AutoBroadcastType::NUMPY}, {{3, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f16, format::bfyx}, {{3}, data_types::f16, format::bfyx}, eltwise_mode::lt, {AutoBroadcastType::NUMPY}, {{3, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::i32, format::bfyx}, {{3}, data_types::i32, format::bfyx}, eltwise_mode::le, {AutoBroadcastType::NUMPY}, {{3, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::i64, format::bfyx}, {{3}, data_types::i64, format::bfyx}, eltwise_mode::gt, {AutoBroadcastType::NUMPY}, {{3, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{5}, data_types::f32, format::bfyx}, eltwise_mode::eq, {AutoBroadcastType::NUMPY}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f16, format::bfyx}, {{5}, data_types::f16, format::bfyx}, eltwise_mode::ne, {AutoBroadcastType::NUMPY}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f16, format::bfyx}, {{5}, data_types::f16, format::bfyx}, eltwise_mode::lt, {AutoBroadcastType::NUMPY}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::i32, format::bfyx}, {{5}, data_types::i32, format::bfyx}, eltwise_mode::le, {AutoBroadcastType::NUMPY}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::i64, format::bfyx}, {{5}, data_types::i64, format::bfyx}, eltwise_mode::gt, {AutoBroadcastType::NUMPY}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::u8, format::bfyx}, {{3}, data_types::u8, format::bfyx}, eltwise_mode::ge, {AutoBroadcastType::PDPD, 1}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::i8, format::bfyx}, {{3}, data_types::i8, format::bfyx}, eltwise_mode::logic_and,{AutoBroadcastType::PDPD, 1}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{3}, data_types::f32, format::bfyx}, eltwise_mode::logic_or, {AutoBroadcastType::PDPD, 1}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
{{{2, 3, 4, 5}, data_types::f32, format::bfyx}, {{3}, data_types::f32, format::bfyx}, eltwise_mode::logic_xor,{AutoBroadcastType::PDPD, 1}, {{2, 3, 4, 5}, data_types::i8, format::bfyx}, {}},
// test stride
{{{5, 2, 1, 20}, data_types::f32, format::bfyx}, {{1, 1, 40}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{5, 2, 1, 5}, data_types::f32, format::bfyx}, {{1,3,4,2}}},
{{{5, 2, 1, 20}, data_types::f32, format::bfyx}, {{1, 40, 1}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {{5, 2, 1, 5}, data_types::f32, format::bfyx}, {{1,3,4,2}}},
{{{2, 3, 40,50}, data_types::f32, format::bfyx}, {{40, 50}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD, -1}, {{2, 3, 5, 10}, data_types::f32, format::bfyx}, {{1,1,5,8}}},
{{PartialShape::dynamic(4), data_types::f32, format::bfyx}, {{2, 1, 5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::NUMPY}, {PartialShape::dynamic(4), data_types::f32, format::bfyx}, {{1,1,5,8}}},
{{PartialShape::dynamic(4), data_types::f32, format::bfyx}, {{2, 1, 5}, data_types::f32, format::bfyx}, eltwise_mode::sum, {AutoBroadcastType::PDPD, 1}, {PartialShape::dynamic(4), data_types::f32, format::bfyx}, {{1,1,3,8}}},

View File

@@ -71,6 +71,15 @@ bool has_node_with_type(cldnn::program& prog) {
return false;
}
// Reports whether the program's processing order contains a node whose id equals `id`.
// Returns true on the first match and false when no node matches.
inline bool has_node(cldnn::program& prog, primitive_id id) {
    for (const auto& candidate : prog.get_processing_order()) {
        if (candidate->id() != id)
            continue;
        return true;
    }
    return false;
}
#define USE_RANDOM_SEED 0
#if USE_RANDOM_SEED
std::random_device rnd_device;