[GPU] Add condition check for dynamic shape and onednn_impl in concat_in_place_optimization::match() (#18034)
* add dynamic shape support for dgpu in prepare_buffer_fusing * add unit test * add space between test cases * update condition of impl create() for concat dynamic shape * update unit test * add comment and update unit test * add impl_param.is_type() function
This commit is contained in:
parent
4fc0b22012
commit
1efb9eafae
@ -126,6 +126,13 @@ struct kernel_impl_params {
|
||||
template <class PType>
|
||||
std::shared_ptr<const PType> typed_desc() const { return std::static_pointer_cast<const PType>(desc); }
|
||||
|
||||
template <class PType>
|
||||
bool is_type() const {
|
||||
return std::static_pointer_cast<const PType>(desc)->type == PType::type_id();
|
||||
}
|
||||
|
||||
virtual primitive_type_id type() const { return desc->type; }
|
||||
|
||||
void save(BinaryOutputBuffer& ob) const;
|
||||
void load(BinaryInputBuffer& ib);
|
||||
const program& get_program() const {
|
||||
|
@ -206,21 +206,26 @@ bool concat_in_place_optimization::match(const program_node& concat_node,
|
||||
layout concat_out_l = concat_params.get_output_layout();
|
||||
if (!use_usm)
|
||||
return false;
|
||||
if (concat_out_l.batch() > 1)
|
||||
return false;
|
||||
// TODO: cldnn cases should be updated. This logic is working for onednn only.
|
||||
// white list for support fusing formats.
|
||||
const std::vector<format> white_list = {
|
||||
format::bfyx,
|
||||
format::bfzyx,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::b_fs_zyx_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::b_fs_zyx_fsv32,
|
||||
format::b_fs_yx_fsv4,
|
||||
};
|
||||
if (std::find_if(white_list.begin(), white_list.end(), [&concat_out_l](format fmt){ return (fmt == concat_out_l.format); }) == std::end(white_list))
|
||||
return false;
|
||||
if (concat_node.is_dynamic() && !is_runtime) {
|
||||
// Return true in build time, it will be checked again in runtime
|
||||
return true;
|
||||
} else {
|
||||
if (concat_out_l.batch() > 1)
|
||||
return false;
|
||||
// TODO: cldnn cases should be updated. This logic is working for onednn only.
|
||||
// white list for support fusing formats.
|
||||
const std::vector<format> white_list = {
|
||||
format::bfyx,
|
||||
format::bfzyx,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::b_fs_zyx_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::b_fs_zyx_fsv32,
|
||||
format::b_fs_yx_fsv4,
|
||||
};
|
||||
if (std::find_if(white_list.begin(), white_list.end(), [&concat_out_l](format fmt){ return (fmt == concat_out_l.format); }) == std::end(white_list))
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "kernel_selector_helper.h"
|
||||
#include "register.hpp"
|
||||
#include "implementation_map.hpp"
|
||||
#include "concatenation_inst.h"
|
||||
|
||||
#include <vector>
|
||||
#include <list>
|
||||
@ -86,7 +87,8 @@ struct typed_primitive_impl_ocl : public typed_primitive_impl<PType> {
|
||||
|
||||
template<typename ImplType>
|
||||
static std::unique_ptr<primitive_impl> create(const typed_program_node<PType>& arg, const kernel_impl_params& impl_param) {
|
||||
if (impl_param.can_be_optimized()) {
|
||||
// concat buffer fusing for dynamic shape is adaptively applied at runtime. So we need to build dynamic impl at build time.
|
||||
if (impl_param.can_be_optimized() && !(impl_param.is_type<concatenation>() && impl_param.is_dynamic())) {
|
||||
return make_unique<ImplType>(kernel_selector::kernel_data{});
|
||||
}
|
||||
auto kernel_params = ImplType::get_kernel_params(ImplType::static_canonicalize_shapes(impl_param));
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include "reshape_inst.h"
|
||||
#include "fully_connected_inst.h"
|
||||
#include "permute_inst.h"
|
||||
#include "reorder_inst.h"
|
||||
#include "intel_gpu/graph/network.hpp"
|
||||
#include "pass_manager.h"
|
||||
#include "to_string_utils.h"
|
||||
@ -258,6 +259,161 @@ TEST(prepare_buffer_fusing, in_place_concat_dynamic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(prepare_buffer_fusing, in_place_concat_dynamic_onednn_batch1) {
|
||||
auto& engine = get_test_engine();
|
||||
if (!engine.get_device_info().supports_immad)
|
||||
return;
|
||||
auto in_layout1_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
|
||||
auto in_layout2_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
|
||||
auto in_layout1 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
|
||||
auto in_layout2 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input1", in_layout1));
|
||||
topology.add(input_layout("input2", in_layout2));
|
||||
topology.add(reorder("reorder1", input_info("input1"), format::bfyx, data_types::f16));
|
||||
topology.add(reorder("reorder2", input_info("input2"), format::bfyx, data_types::f16));
|
||||
|
||||
topology.add(concatenation("concat", { input_info("reorder1"), input_info("reorder2") }, 1));
|
||||
topology.add(permute("output", input_info("concat"), {0, 2, 3, 1}));
|
||||
|
||||
ExecutionConfig config;
|
||||
config.set_property(ov::intel_gpu::optimize_data(true));
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(false));
|
||||
auto prog = program::build_program(engine, topology, config, false, false);
|
||||
ASSERT_NE(prog, nullptr);
|
||||
auto& concat_node_p = prog->get_node("concat");
|
||||
ASSERT_TRUE(concat_node_p.can_be_optimized());
|
||||
cldnn::network net(prog, 0);
|
||||
|
||||
auto input_memory1 = engine.allocate_memory(in_layout1);
|
||||
auto input_memory2 = engine.allocate_memory(in_layout2);
|
||||
set_values<FLOAT16>(input_memory1,
|
||||
{FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
|
||||
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f)});
|
||||
set_values<FLOAT16>(input_memory2,
|
||||
{FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
|
||||
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)});
|
||||
net.set_input_data("input1", input_memory1);
|
||||
net.set_input_data("input2", input_memory2);
|
||||
|
||||
std::vector<FLOAT16> ref_output = {
|
||||
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
|
||||
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
|
||||
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
|
||||
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)};
|
||||
|
||||
std::map<cldnn::primitive_id, cldnn::network_output> output;
|
||||
EXPECT_NO_THROW(output = net.execute());
|
||||
auto out_l = net.get_output_layout("output");
|
||||
auto out_mem = output.at("output").get_memory();
|
||||
cldnn::mem_lock<FLOAT16> output_ptr(out_mem, get_test_stream());
|
||||
|
||||
cldnn::mem_lock<FLOAT16> input1_ptr(input_memory1, get_test_stream());
|
||||
cldnn::mem_lock<FLOAT16> input2_ptr(input_memory2, get_test_stream());
|
||||
|
||||
const auto& concat_inst = net.get_primitive("concat");
|
||||
const auto& concat_node_n = concat_inst->get_node();
|
||||
auto concat_mem = net.get_primitive("concat")->output_memory_ptr();
|
||||
auto reorder1_mem = net.get_primitive("reorder1")->output_memory_ptr();
|
||||
auto reorder2_mem = net.get_primitive("reorder2")->output_memory_ptr();
|
||||
|
||||
ASSERT_EQ(concat_mem.get(), reorder1_mem.get());
|
||||
ASSERT_EQ(concat_mem.get(), reorder2_mem.get());
|
||||
ASSERT_TRUE(concat_inst->can_be_optimized());
|
||||
ASSERT_TRUE(concat_node_n.can_be_optimized());
|
||||
|
||||
for (size_t x = 0; x < out_l.count(); ++x) {
|
||||
ASSERT_EQ(ref_output[x], output_ptr[x]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(prepare_buffer_fusing, in_place_concat_dynamic_onednn_batch2) {
|
||||
// Check no buffer fusing when onednn concat with b=2. It is not supported.
|
||||
auto& engine = get_test_engine();
|
||||
if (!engine.get_device_info().supports_immad)
|
||||
return;
|
||||
auto in_layout1_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
|
||||
auto in_layout2_0 = layout{ ov::PartialShape::dynamic(4), data_types::f16, format::b_fs_yx_fsv16 };
|
||||
auto in_layout1 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
|
||||
auto in_layout2 = layout{ ov::PartialShape{1, 16, 2, 1}, data_types::f16, format::b_fs_yx_fsv16 };
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input1", in_layout1_0));
|
||||
topology.add(input_layout("input2", in_layout2_0));
|
||||
topology.add(reorder("reorder1", input_info("input1"), format::bfyx, data_types::f16));
|
||||
topology.add(reorder("reorder2", input_info("input2"), format::bfyx, data_types::f16));
|
||||
|
||||
topology.add(concatenation("concat", { input_info("reorder1"), input_info("reorder2") }, 0));
|
||||
topology.add(permute("output", input_info("concat"), {0, 2, 3, 1}));
|
||||
|
||||
ExecutionConfig config;
|
||||
config.set_property(ov::intel_gpu::optimize_data(true));
|
||||
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
|
||||
auto prog = program::build_program(engine, topology, config, false, false);
|
||||
ASSERT_NE(prog, nullptr);
|
||||
auto& concat_node_p = prog->get_node("concat");
|
||||
ASSERT_TRUE(concat_node_p.can_be_optimized());
|
||||
cldnn::network net(prog, 0);
|
||||
|
||||
auto input_memory1 = engine.allocate_memory(in_layout1);
|
||||
auto input_memory2 = engine.allocate_memory(in_layout2);
|
||||
set_values<FLOAT16>(input_memory1,
|
||||
{FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
|
||||
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f)});
|
||||
set_values<FLOAT16>(input_memory2,
|
||||
{FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
|
||||
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)});
|
||||
net.set_input_data("input1", input_memory1);
|
||||
net.set_input_data("input2", input_memory2);
|
||||
|
||||
std::vector<FLOAT16> ref_output = {
|
||||
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
|
||||
FLOAT16(1.0f), FLOAT16(2.0f), FLOAT16(3.0f), FLOAT16(4.0f), FLOAT16(5.0f), FLOAT16(6.0f), FLOAT16(7.0f), FLOAT16(8.0f),
|
||||
FLOAT16(11.0f), FLOAT16(22.0f), FLOAT16(33.0f), FLOAT16(44.0f), FLOAT16(55.0f), FLOAT16(66.0f), FLOAT16(77.0f), FLOAT16(88.0f),
|
||||
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f),
|
||||
FLOAT16(111.0f), FLOAT16(222.0f), FLOAT16(333.0f), FLOAT16(444.0f), FLOAT16(555.0f), FLOAT16(666.0f), FLOAT16(777.0f), FLOAT16(888.0f),
|
||||
FLOAT16(1111.0f), FLOAT16(2222.0f), FLOAT16(3333.0f), FLOAT16(4444.0f), FLOAT16(5555.0f), FLOAT16(6666.0f), FLOAT16(7777.0f), FLOAT16(8888.0f)};
|
||||
|
||||
std::map<cldnn::primitive_id, cldnn::network_output> output;
|
||||
EXPECT_NO_THROW(output = net.execute());
|
||||
auto out_l = net.get_output_layout("output");
|
||||
auto out_mem = output.at("output").get_memory();
|
||||
cldnn::mem_lock<FLOAT16> output_ptr(out_mem, get_test_stream());
|
||||
|
||||
cldnn::mem_lock<FLOAT16> input1_ptr(input_memory1, get_test_stream());
|
||||
cldnn::mem_lock<FLOAT16> input2_ptr(input_memory2, get_test_stream());
|
||||
|
||||
const auto& concat_inst = net.get_primitive("concat");
|
||||
const auto& concat_node_n = concat_inst->get_node();
|
||||
auto concat_mem = net.get_primitive("concat")->output_memory_ptr();
|
||||
auto reorder1_mem = net.get_primitive("reorder1")->output_memory_ptr();
|
||||
auto reorder2_mem = net.get_primitive("reorder2")->output_memory_ptr();
|
||||
|
||||
ASSERT_NE(concat_mem.get(), reorder1_mem.get());
|
||||
ASSERT_NE(concat_mem.get(), reorder2_mem.get());
|
||||
ASSERT_FALSE(concat_inst->can_be_optimized());
|
||||
ASSERT_TRUE(concat_node_n.can_be_optimized());
|
||||
|
||||
for (size_t x = 0; x < out_l.count(); ++x) {
|
||||
ASSERT_EQ(ref_output[x], output_ptr[x]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(prepare_buffer_fusing, in_place_concat_dynamic__static_dim_dyn_pad) {
|
||||
auto& engine = get_test_engine();
|
||||
auto in_layout1_0 = layout{ ov::PartialShape{-1, 2, -1, -1}, data_types::f32, format::bfyx }; // => {-1, -1, -1, 2}
|
||||
|
Loading…
Reference in New Issue
Block a user