[GPU] Add condition check for dynamic shape and onednn_impl in concat_in_place_optimization::match() (#18034)

* add dynamic shape support for dgpu in prepare_buffer_fusing

* add unit test

* add space between test cases

* update condition of impl create() for concat dynamic shape

* update unit test

* add comment and update unit test

* add impl_param.is_type() function
Wilson Seok 2023-06-28 15:39:00 +09:00 committed by GitHub
parent 4fc0b22012
commit 1efb9eafae
4 changed files with 186 additions and 16 deletions


@@ -126,6 +126,13 @@ struct kernel_impl_params {
     template <class PType>
     std::shared_ptr<const PType> typed_desc() const { return std::static_pointer_cast<const PType>(desc); }
+    template <class PType>
+    bool is_type() const {
+        return std::static_pointer_cast<const PType>(desc)->type == PType::type_id();
+    }
     virtual primitive_type_id type() const { return desc->type; }
     void save(BinaryOutputBuffer& ob) const;
     void load(BinaryInputBuffer& ib);
     const program& get_program() const {
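The new is_type<PType>() helper lets code that only has a kernel_impl_params (and no program_node) ask which primitive the parameters describe. A minimal sketch of the intended use, mirroring the check added to typed_primitive_impl_ocl::create() below; the wrapper name here is hypothetical and not part of the commit:

// Sketch only: a convenience wrapper built on the new kernel_impl_params::is_type<>()
// and the existing is_dynamic() query.
static bool is_dynamic_concat(const kernel_impl_params& impl_param) {
    // True when the params describe a concatenation primitive with dynamic shapes.
    return impl_param.is_type<concatenation>() && impl_param.is_dynamic();
}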


@@ -206,21 +206,26 @@ bool concat_in_place_optimization::match(const program_node& concat_node,
         layout concat_out_l = concat_params.get_output_layout();
         if (!use_usm)
             return false;
-        if (concat_out_l.batch() > 1)
-            return false;
-        // TODO: cldnn cases should be updated. This logic is working for onednn only.
-        // white list for support fusing formats.
-        const std::vector<format> white_list = {
-            format::bfyx,
-            format::bfzyx,
-            format::b_fs_yx_fsv16,
-            format::b_fs_zyx_fsv16,
-            format::b_fs_yx_fsv32,
-            format::b_fs_zyx_fsv32,
-            format::b_fs_yx_fsv4,
-        };
-        if (std::find_if(white_list.begin(), white_list.end(), [&concat_out_l](format fmt){ return (fmt == concat_out_l.format); }) == std::end(white_list))
-            return false;
+        if (concat_node.is_dynamic() && !is_runtime) {
+            // Return true in build time, it will be checked again in runtime
+            return true;
+        } else {
+            if (concat_out_l.batch() > 1)
+                return false;
+            // TODO: cldnn cases should be updated. This logic is working for onednn only.
+            // white list for support fusing formats.
+            const std::vector<format> white_list = {
+                format::bfyx,
+                format::bfzyx,
+                format::b_fs_yx_fsv16,
+                format::b_fs_zyx_fsv16,
+                format::b_fs_yx_fsv32,
+                format::b_fs_zyx_fsv32,
+                format::b_fs_yx_fsv4,
+            };
+            if (std::find_if(white_list.begin(), white_list.end(), [&concat_out_l](format fmt){ return (fmt == concat_out_l.format); }) == std::end(white_list))
+                return false;
+        }
     }
     return true;
 }
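The effect of this change is a two-phase decision: at build time a dynamic-shape concat is optimistically accepted (the node is marked as optimizable and a dynamic impl is still built, see the next file), and the concrete batch/format checks run again through the same match() call at runtime once real shapes are known. A condensed sketch of that control flow, with simplified names; is_whitelisted_format() stands in for the white_list lookup and is not a real function in the codebase:

// Illustrative only: condensed form of the branch above.
bool fusible(const program_node& concat_node, const layout& concat_out_l, bool is_runtime) {
    if (concat_node.is_dynamic() && !is_runtime)
        return true;   // build time: defer the decision, re-checked at runtime
    if (concat_out_l.batch() > 1)
        return false;  // the onednn path only fuses batch-1 concats
    return is_whitelisted_format(concat_out_l.format);  // hypothetical helper for the format list
}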


@@ -17,6 +17,7 @@
 #include "kernel_selector_helper.h"
 #include "register.hpp"
 #include "implementation_map.hpp"
+#include "concatenation_inst.h"
 #include <vector>
 #include <list>
@@ -86,7 +87,8 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
     template<typename ImplType>
     static std::unique_ptr<primitive_impl> create(const typed_program_node<PType>& arg, const kernel_impl_params& impl_param) {
-        if (impl_param.can_be_optimized()) {
+        // concat buffer fusing for dynamic shape is adaptively applied at runtime. So we need to build dynamic impl at build time.
+        if (impl_param.can_be_optimized() && !(impl_param.is_type<concatenation>() && impl_param.is_dynamic())) {
             return make_unique<ImplType>(kernel_selector::kernel_data{});
         }
         auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param));
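Previously, any primitive flagged as optimized out received an "empty" impl with no compiled kernel. For a dynamic concat that is no longer safe: in-place fusing is only confirmed at runtime, so the node may still have to execute and needs a real shape-agnostic kernel. An equivalent, slightly more explicit form of the added condition (illustrative rewrite, not the commit's code):

// Illustrative rewrite of the guard above: skip kernel compilation only when the
// primitive is guaranteed to stay optimized out.
const bool dynamic_concat = impl_param.is_type<concatenation>() && impl_param.is_dynamic();
if (impl_param.can_be_optimized() && !dynamic_concat) {
    return make_unique<ImplType>(kernel_selector::kernel_data{});
}
// Otherwise fall through and compile a dynamic-shape kernel so the concat can run
// if buffer fusing is rejected at runtime.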


@@ -12,6 +12,7 @@
 #include "reshape_inst.h"
 #include "fully_connected_inst.h"
 #include "permute_inst.h"
+#include "reorder_inst.h"
 #include "intel_gpu/graph/network.hpp"
 #include "pass_manager.h"
 #include "to_string_utils.h"
@@ -258,6 +259,161 @@ TEST(prepare_buffer_fusing, in_place_concat_dynamic) {
}
}
TEST(prepare_buffer_fusing, in_place_concat_dynamic_onednn_batch1) {
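// Check in-place buffer fusing is applied at runtime for dynamic onednn concat with b=1.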
auto& engine = get_test_engine();
if (!engine.get_device_info().supports_immad)
return;
auto in_layout1_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
auto in_layout2_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
auto in_layout1 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
auto in_layout2 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
topology topology;
topology.add(input_layout("input1", in_layout1));
topology.add(input_layout("input2", in_layout2));
topology.add(reorder("reorder1", input_info("input1"), format::bfyx, data_types::f16));
topology.add(reorder("reorder2", input_info("input2"), format::bfyx, data_types::f16));
topology.add(concatenation("concat", { input_info("reorder1"), input_info("reorder2") }, 1));
topology.add(permute("output", input_info("concat"), {0, 2, 3, 1}));
ExecutionConfig config;
config.set_property(ov::intel_gpu::optimize_data(true));
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
auto prog = program::build_program(engine, topology, config, false, false);
ASSERT_NE(prog, nullptr);
auto& concat_node_p = prog->get_node("concat");
ASSERT_TRUE(concat_node_p.can_be_optimized());
cldnn::network net(prog, 0);
auto input_memory1 = engine.allocate_memory(in_layout1);
auto input_memory2 = engine.allocate_memory(in_layout2);
set_values<FLOAT16>(input_memory1,
{FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f)});
set_values<FLOAT16>(input_memory2,
{FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)});
net.set_input_data("input1", input_memory1);
net.set_input_data("input2", input_memory2);
std::vector<FLOAT16> ref_output = {
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)};
std::map<cldnn::primitive_id, cldnn::network_output> output;
EXPECT_NO_THROW(output = net.execute());
auto out_l = net.get_output_layout("output");
auto out_mem = output.at("output").get_memory();
cldnn::mem_lock<FLOAT16> output_ptr(out_mem, get_test_stream());
cldnn::mem_lock<FLOAT16> input1_ptr(input_memory1, get_test_stream());
cldnn::mem_lock<FLOAT16> input2_ptr(input_memory2, get_test_stream());
const auto& concat_inst = net.get_primitive("concat");
const auto& concat_node_n = concat_inst->get_node();
auto concat_mem = net.get_primitive("concat")->output_memory_ptr();
auto reorder1_mem = net.get_primitive("reorder1")->output_memory_ptr();
auto reorder2_mem = net.get_primitive("reorder2")->output_memory_ptr();
ASSERT_EQ(concat_mem.get(), reorder1_mem.get());
ASSERT_EQ(concat_mem.get(), reorder2_mem.get());
ASSERT_TRUE(concat_inst->can_be_optimized());
ASSERT_TRUE(concat_node_n.can_be_optimized());
for (size_t x = 0; x < out_l.count(); ++x) {
ASSERT_EQ(ref_output[x], output_ptr[x]);
}
}
TEST(prepare_buffer_fusing, in_place_concat_dynamic_onednn_batch2) {
// Check no buffer fusing when onednn concat with b=2. It is not supported.
auto& engine = get_test_engine();
if (!engine.get_device_info().supports_immad)
return;
auto in_layout1_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
auto in_layout2_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
auto in_layout1 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
auto in_layout2 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
topology topology;
topology.add(input_layout("input1", in_layout1_0));
topology.add(input_layout("input2", in_layout2_0));
topology.add(reorder("reorder1", input_info("input1"), format::bfyx, data_types::f16));
topology.add(reorder("reorder2", input_info("input2"), format::bfyx, data_types::f16));
topology.add(concatenation("concat", { input_info("reorder1"), input_info("reorder2") }, 0));
topology.add(permute("output", input_info("concat"), {0, 2, 3, 1}));
ExecutionConfig config;
config.set_property(ov::intel_gpu::optimize_data(true));
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
auto prog = program::build_program(engine, topology, config, false, false);
ASSERT_NE(prog, nullptr);
auto& concat_node_p = prog->get_node("concat");
ASSERT_TRUE(concat_node_p.can_be_optimized());
cldnn::network net(prog, 0);
auto input_memory1 = engine.allocate_memory(in_layout1);
auto input_memory2 = engine.allocate_memory(in_layout2);
set_values<FLOAT16>(input_memory1,
{FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f)});
set_values<FLOAT16>(input_memory2,
{FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)});
net.set_input_data("input1", input_memory1);
net.set_input_data("input2", input_memory2);
std::vector<FLOAT16> ref_output = {
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)};
std::map<cldnn::primitive_id, cldnn::network_output> output;
EXPECT_NO_THROW(output = net.execute());
auto out_l = net.get_output_layout("output");
auto out_mem = output.at("output").get_memory();
cldnn::mem_lock<FLOAT16> output_ptr(out_mem, get_test_stream());
cldnn::mem_lock<FLOAT16> input1_ptr(input_memory1, get_test_stream());
cldnn::mem_lock<FLOAT16> input2_ptr(input_memory2, get_test_stream());
const auto& concat_inst = net.get_primitive("concat");
const auto& concat_node_n = concat_inst->get_node();
auto concat_mem = net.get_primitive("concat")->output_memory_ptr();
auto reorder1_mem = net.get_primitive("reorder1")->output_memory_ptr();
auto reorder2_mem = net.get_primitive("reorder2")->output_memory_ptr();
ASSERT_NE(concat_mem.get(), reorder1_mem.get());
ASSERT_NE(concat_mem.get(), reorder2_mem.get());
ASSERT_FALSE(concat_inst->can_be_optimized());
ASSERT_TRUE(concat_node_n.can_be_optimized());
for (size_t x = 0; x < out_l.count(); ++x) {
ASSERT_EQ(ref_output[x], output_ptr[x]);
}
}
TEST(prepare_buffer_fusing, in_place_concat_dynamic__static_dim_dyn_pad) {
auto& engine = get_test_engine();
auto in_layout1_0 = layout{ ov::PartialShape{-1, 2, -1, -1}, data_types::f32, format::bfyx }; // => {-1, -1, -1, 2}