concatenation_onednn migration (#7756)
- Update the concatenation_onednn implementation.
- Add a unit test using the oneDNN concat primitive.
Signed-off-by: hyunback <hyunback.kim@intel.com>
This commit is contained in:
parent
bd2b346c62
commit
21c70f8eb5
@ -17,11 +17,110 @@
|
||||
namespace cldnn {
|
||||
namespace onednn {
|
||||
|
||||
// oneDNN-backed implementation of the cldnn concatenation primitive.
// The `void` descriptor slot reflects that dnnl::concat carries no extra
// typed descriptor payload beyond its primitive_desc.
struct concatenation_onednn : typed_primitive_onednn_impl<concatenation, void, dnnl::concat::primitive_desc, dnnl::concat> {
    using parent = typed_primitive_onednn_impl<concatenation, void, dnnl::concat::primitive_desc, dnnl::concat>;
    using parent::parent;

protected:
    // Deep-copies this implementation (used when a program/network is cloned).
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<concatenation_onednn>(*this);
    }

    // Binds each input memory to consecutive DNNL_ARG_MULTIPLE_SRC slots and
    // the output memory to DNNL_ARG_DST, as required by dnnl::concat.
    std::unordered_map<int, dnnl::memory> get_arguments(concatenation_inst& instance) const override {
        std::unordered_map<int, dnnl::memory> args;

        for (size_t idx = 0; idx < instance.inputs_memory_count(); ++idx) {
            auto& input_mem = instance.input_memory(idx);
            args.insert({DNNL_ARG_MULTIPLE_SRC + static_cast<int>(idx),
                         input_mem.get_onednn_memory(_pd.src_desc(static_cast<int>(idx)))});
        }

        auto& output_mem = instance.output_memory();
        args.insert({DNNL_ARG_DST, output_mem.get_onednn_memory(_pd.dst_desc())});

        // TODO post operation
        // configure_post_ops_arguments(instance, args);

        return args;
    }

    // Builds the oneDNN concat primitive descriptor from the node's input
    // layouts, output layout, and concatenation axis.
    static std::shared_ptr<dnnl::concat::primitive_desc> get_concatenation_descriptor(const concatenation_node& arg) {
        auto prim = arg.get_primitive();
        auto& engine = arg.get_program().get_engine();

        std::vector<dnnl::memory::desc> input_mds;
        for (auto& dep : arg.get_dependencies()) {
            input_mds.push_back(onednn::layout_to_memory_desc(dep->get_output_layout()));
        }
        auto output_md = onednn::layout_to_memory_desc(arg.get_output_layout());

        // Translate the cldnn concat axis to the matching oneDNN dimension
        // index of a 4D bfyx tensor; other axes are not supported here.
        int axis;
        if (prim->axis == concatenation::concatenation_axis::along_b) {
            axis = 0;
        } else if (prim->axis == concatenation::concatenation_axis::along_f) {
            axis = 1;
        } else if (prim->axis == concatenation::concatenation_axis::along_y) {
            axis = 2;
        } else if (prim->axis == concatenation::concatenation_axis::along_x) {
            axis = 3;
        } else {
            throw std::runtime_error("unsupported concat axis");
        }

        return std::make_shared<dnnl::concat::primitive_desc>(
            output_md,
            axis,
            input_mds,
            engine.get_onednn_engine());
    }

public:
    // Factory entry point registered in the implementation map.
    static primitive_impl* create(const concatenation_node& arg) {
        auto descriptor = get_concatenation_descriptor(arg);
        auto attributes = get_primitive_attributes(arg);

        // No typed descriptor payload for concat — pass an empty placeholder.
        std::shared_ptr<void> dummy = nullptr;

        return new concatenation_onednn(arg, dummy, attributes, *descriptor);
    }
};
|
||||
|
||||
namespace detail {

// Registers concatenation_onednn::create in the implementation map for every
// (data type, format) pair the oneDNN concat path is allowed to handle.
// NOTE(review): the braced list is passed straight to implementation_map::add;
// whether entry order matters depends on that container's type — preserved as-is.
attach_concatenation_onednn::attach_concatenation_onednn() {
    implementation_map<concatenation>::add(impl_types::onednn, concatenation_onednn::create, {
        // Plain planar layout.
        std::make_tuple(data_types::f32, format::bfyx),
        std::make_tuple(data_types::f16, format::bfyx),
        std::make_tuple(data_types::u8, format::bfyx),
        std::make_tuple(data_types::i8, format::bfyx),

        // Feature-blocked layouts.
        std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),

        std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),

        // Batch- and feature-blocked layouts.
        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),

        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
    });
}

} // namespace detail
|
||||
|
@ -1038,6 +1038,16 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
|
||||
// }
|
||||
|
||||
preferred_impl = impl_candidate;
|
||||
} else if (node.is_type<concatenation>()) {
|
||||
if (!_optimization_attributes.use_onednn_impls)
|
||||
return impl_types::ocl;
|
||||
|
||||
for (auto& dep : node.get_dependencies()) {
|
||||
if (dep->is_in_data_flow() && dep->get_preferred_impl_type() == impl_types::onednn) {
|
||||
preferred_impl = impl_types::onednn;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return preferred_impl;
|
||||
|
@ -783,3 +783,77 @@ INSTANTIATE_TEST_SUITE_P(smoke_low_precision,
|
||||
TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
|
||||
),
|
||||
concat_gpu::PrintToStringParamName);
|
||||
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
// Concatenates five f32 inputs along the feature axis through the forced
// oneDNN implementation and checks the output element-by-element.
TEST(concat_gpu_onednn, basic_input_types) {
    auto& engine = get_onednn_test_engine();

    // Each input (b=1, f=1, y=3, x=4) holds the same 12-element pattern
    // shifted up by 10 * input_index.
    const std::vector<float> base_pattern = {
        1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 2.0f, 3.0f, 4.0f, 3.0f, 3.0f, 3.0f, 5.0f };

    auto input0 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input1 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input2 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input3 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input4 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    std::vector<decltype(input0)> inputs = { input0, input1, input2, input3, input4 };

    // Fill every input and build the expected feature-concatenated output.
    VF<float> output_vec;
    for (size_t i = 0; i < inputs.size(); ++i) {
        VF<float> vals;
        for (auto v : base_pattern)
            vals.push_back(v + 10.0f * static_cast<float>(i));
        set_values<float>(inputs[i], vals);
        output_vec.insert(output_vec.end(), vals.begin(), vals.end());
    }

    topology topology(
        input_layout("input0", input0->get_layout()),
        input_layout("input1", input1->get_layout()),
        input_layout("input2", input2->get_layout()),
        input_layout("input3", input3->get_layout()),
        input_layout("input4", input4->get_layout()),
        concatenation("concat",
                      { "input0", "input1", "input2", "input3", "input4" },
                      concatenation::concatenation_axis::along_f,
                      data_types::f32,
                      "",
                      padding{ { 0,0,0,0 }, 0 })
    );

    // Force the oneDNN implementation for the concat node.
    build_options options_target;
    options_target.set_option(build_option::outputs({ "concat" }));
    implementation_desc impl = { format::bfyx, std::string(""), impl_types::onednn };
    options_target.set_option(build_option::force_implementations({ {"concat", impl} }));

    network network(engine, topology, options_target);
    for (size_t i = 0; i < inputs.size(); ++i)
        network.set_input_data("input" + std::to_string(i), inputs[i]);

    auto outputs = network.execute();
    EXPECT_EQ(outputs.size(), size_t(1));
    EXPECT_EQ(outputs.begin()->first, "concat");

    auto output_memory = outputs.at("concat").get_memory();
    auto output_layout = output_memory->get_layout();
    cldnn::mem_lock<float> output_ptr(output_memory, get_test_stream());

    // Spatial extents and batch are preserved; features are concatenated.
    int y_size = output_layout.size.spatial[1];
    int x_size = output_layout.size.spatial[0];
    int f_size = output_layout.size.feature[0];
    int b_size = output_layout.size.batch[0];
    EXPECT_EQ(output_layout.format, format::bfyx);
    EXPECT_EQ(y_size, 3);
    EXPECT_EQ(x_size, 4);
    EXPECT_EQ(f_size, 5);
    EXPECT_EQ(b_size, 1);

    for (size_t idx = 0; idx < output_layout.count(); ++idx) {
        EXPECT_EQ(output_vec[idx], output_ptr[idx]);
    }
}
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user