From 21c70f8eb55ca4b51e5331b77787577ff5250155 Mon Sep 17 00:00:00 2001
From: hyunback kim
Date: Tue, 5 Oct 2021 09:16:32 +0900
Subject: [PATCH] concatenation_onednn migration (#7756)

- update concatenation_onednn
- add unit test using onednn concat

Signed-off-by: hyunback
---
 .../src/impls/onednn/concatenation_onednn.cpp | 99 +++++++++++++++++++
 .../thirdparty/clDNN/src/layout_optimizer.cpp | 10 ++
 .../test_cases/concatenation_gpu_test.cpp     | 74 ++++++++++++++
 3 files changed, 183 insertions(+)

diff --git a/inference-engine/thirdparty/clDNN/src/impls/onednn/concatenation_onednn.cpp b/inference-engine/thirdparty/clDNN/src/impls/onednn/concatenation_onednn.cpp
index fffc039aefb..5ba38256a00 100644
--- a/inference-engine/thirdparty/clDNN/src/impls/onednn/concatenation_onednn.cpp
+++ b/inference-engine/thirdparty/clDNN/src/impls/onednn/concatenation_onednn.cpp
@@ -17,11 +17,110 @@
 namespace cldnn {
 namespace onednn {
 
+struct concatenation_onednn : typed_primitive_onednn_impl {  // oneDNN implementation of the concatenation primitive. NOTE(review): template argument lists look stripped here and on the unique_ptr / unordered_map / shared_ptr / vector / make_unique / make_shared / static_cast uses below - confirm against upstream
+    using parent = typed_primitive_onednn_impl;
+    using parent::parent;  // reuse the parent's constructors; create() below relies on them
+protected:
+    std::unique_ptr clone() const override {  // returns a new impl built from *this
+        return make_unique(*this);
+    }
+
+    std::unordered_map get_arguments(concatenation_inst& instance) const override {  // builds the argument map: one entry per input memory plus DNNL_ARG_DST for the output
+        std::unordered_map args;
+
+        int input_idx = DNNL_ARG_MULTIPLE_SRC;  // id of the first source; incremented once per input in the loop below
+        for (size_t i = 0; i < instance.inputs_memory_count(); i++) {
+            auto& input = instance.input_memory(i);
+            args.insert({ input_idx++, input.get_onednn_memory(_pd.src_desc(static_cast(i))) });  // i-th input is wrapped using the i-th source descriptor of _pd
+        }
+
+        {
+            auto& output = instance.output_memory();
+            args.insert({DNNL_ARG_DST, output.get_onednn_memory(_pd.dst_desc())});  // single destination, wrapped using _pd's destination descriptor
+        }
+
+        // TODO post operation
+        // configure_post_ops_arguments(instance, args);
+
+        return args;
+    }
+
+    static std::shared_ptr get_concatenation_descriptor(const concatenation_node& arg) {  // builds the descriptor object from the node's dependency layouts, output layout and concat axis
+        auto prim = arg.get_primitive();
+
+        auto& engine = arg.get_program().get_engine();
+        std::vector input_mds;  // one memory descriptor per dependency, in dependency order
+        for (auto& input : arg.get_dependencies()) {
+            input_mds.push_back(onednn::layout_to_memory_desc(input->get_output_layout()));
+        }
+        auto output_md = onednn::layout_to_memory_desc(arg.get_output_layout());
+        int axis = 0;
+        switch (prim->axis) {  // translate the primitive's concat axis enum into a numeric index: b=0, f=1, y=2, x=3
+            case concatenation::concatenation_axis::along_b: axis = 0; break;
+            case concatenation::concatenation_axis::along_f: axis = 1; break;
+            case concatenation::concatenation_axis::along_y: axis = 2; break;
+            case concatenation::concatenation_axis::along_x: axis = 3; break;
+            default: throw std::runtime_error("unsupported concat axis");  // every other axis value is rejected
+        }
+
+        return std::make_shared(
+            output_md,
+            axis,
+            input_mds,
+            engine.get_onednn_engine());
+    }
+
+public:
+    static primitive_impl* create(const concatenation_node& arg) {  // factory registered with implementation_map in attach_concatenation_onednn(); returns a raw pointer from new
+        auto desc = get_concatenation_descriptor(arg);
+        auto attr = get_primitive_attributes(arg);
+
+        std::shared_ptr dummy = nullptr;  // always-null placeholder forwarded as the second constructor argument
+
+        return new concatenation_onednn(arg, dummy, attr, *desc);
+    }
+};
 
 namespace detail {
 
 attach_concatenation_onednn::attach_concatenation_onednn() {
 
+    implementation_map::add(impl_types::onednn, concatenation_onednn::create, {  // registers create() for impl_types::onednn with the (data type, format) pairs listed below. NOTE(review): implementation_map's template argument looks stripped - confirm
+        std::make_tuple(data_types::f32, format::bfyx),
+        std::make_tuple(data_types::f16, format::bfyx),
+        std::make_tuple(data_types::u8, format::bfyx),
+        std::make_tuple(data_types::i8, format::bfyx),
+
+        std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
+        std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
+        std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
+        std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
+
+        std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
+        std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
+        std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),
+        std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),
+
+        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
+        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
+        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),
+        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),
+
+        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
+        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
+        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),
+        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),
+
+        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
+        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
+        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
+        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),
+
+        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
+        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
+        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
+        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
+    });
 }
 
 }  // namespace detail
diff --git a/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp b/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp
index 03413c27144..a806303a278 100644
--- a/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp
+++ b/inference-engine/thirdparty/clDNN/src/layout_optimizer.cpp
@@ -1038,6 +1038,16 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
         // }
 
         preferred_impl = impl_candidate;
+    } else if (node.is_type()) {  // branch added for concatenation nodes. NOTE(review): is_type's template argument looks stripped - confirm
+        if (!_optimization_attributes.use_onednn_impls)
+            return impl_types::ocl;  // oneDNN impls are switched off: return OCL right away
+
+        for (auto& dep : node.get_dependencies()) {
+            if (dep->is_in_data_flow() && dep->get_preferred_impl_type() == impl_types::onednn) {
+                preferred_impl = impl_types::onednn;  // one in-data-flow dependency preferring oneDNN is enough
+                break;
+            }
+        }
     }
 
     return preferred_impl;
diff --git a/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp b/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp
index 444e1cde280..b649dc0ef51 100644
--- a/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp
+++ b/inference-engine/thirdparty/clDNN/tests/test_cases/concatenation_gpu_test.cpp
@@ -783,3 +783,77 @@ INSTANTIATE_TEST_SUITE_P(smoke_low_precision,
                             TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
                         ),
                         concat_gpu::PrintToStringParamName);
+
+
+#ifdef ENABLE_ONEDNN_FOR_GPU  // the test below is compiled only when this macro is defined
+TEST(concat_gpu_onednn, basic_input_types) {  // concatenates five f32 bfyx inputs along the feature axis with the oneDNN impl forced, then compares the output element by element
+    auto& engine = get_onednn_test_engine();
+
+    auto input0 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });  // all five inputs share this layout; 12 values each
+    auto input1 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
+    auto input2 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
+    auto input3 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
+    auto input4 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
+
+    set_values(input0, { 1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 2.0f, 3.0f, 4.0f, 3.0f, 3.0f, 3.0f, 5.0f });
+    set_values(input1, { 11.0f, 12.0f, 13.0f, 14.0f, 12.0f, 12.0f, 13.0f, 14.0f, 13.0f, 13.0f, 13.0f, 15.0f });
+    set_values(input2, { 21.0f, 22.0f, 23.0f, 24.0f, 22.0f, 22.0f, 23.0f, 24.0f, 23.0f, 23.0f, 23.0f, 25.0f });
+    set_values(input3, { 31.0f, 32.0f, 33.0f, 34.0f, 32.0f, 32.0f, 33.0f, 34.0f, 33.0f, 33.0f, 33.0f, 35.0f });
+    set_values(input4, { 41.0f, 42.0f, 43.0f, 44.0f, 42.0f, 42.0f, 43.0f, 44.0f, 43.0f, 43.0f, 43.0f, 45.0f });
+
+    VF output_vec = {  // expected result: the five input buffers back to back, in input order. NOTE(review): VF's template argument looks stripped - confirm
+        1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 2.0f, 3.0f, 4.0f, 3.0f, 3.0f, 3.0f, 5.0f,
+        11.0f, 12.0f, 13.0f, 14.0f, 12.0f, 12.0f, 13.0f, 14.0f, 13.0f, 13.0f, 13.0f, 15.0f,
+        21.0f, 22.0f, 23.0f, 24.0f, 22.0f, 22.0f, 23.0f, 24.0f, 23.0f, 23.0f, 23.0f, 25.0f,
+        31.0f, 32.0f, 33.0f, 34.0f, 32.0f, 32.0f, 33.0f, 34.0f, 33.0f, 33.0f, 33.0f, 35.0f,
+        41.0f, 42.0f, 43.0f, 44.0f, 42.0f, 42.0f, 43.0f, 44.0f, 43.0f, 43.0f, 43.0f, 45.0f };
+
+    topology topology(
+        input_layout("input0", input0->get_layout()),
+        input_layout("input1", input1->get_layout()),
+        input_layout("input2", input2->get_layout()),
+        input_layout("input3", input3->get_layout()),
+        input_layout("input4", input4->get_layout()),
+        concatenation("concat",
+                      { "input0", "input1", "input2", "input3", "input4" },
+                      concatenation::concatenation_axis::along_f,
+                      data_types::f32,
+                      "",
+                      padding{ { 0,0,0,0 }, 0 })
+    );
+
+    build_options options_target;
+    options_target.set_option(build_option::outputs({ "concat" }));
+    implementation_desc impl = { format::bfyx, std::string(""), impl_types::onednn };
+    options_target.set_option(build_option::force_implementations({ {"concat", impl} }));  // force the "concat" node onto the bfyx / oneDNN implementation described by impl
+
+    network network(engine, topology, options_target);
+    network.set_input_data("input0", input0);
+    network.set_input_data("input1", input1);
+    network.set_input_data("input2", input2);
+    network.set_input_data("input3", input3);
+    network.set_input_data("input4", input4);
+
+    auto outputs = network.execute();
+    EXPECT_EQ(outputs.size(), size_t(1));
+    EXPECT_EQ(outputs.begin()->first, "concat");
+
+    auto output_memory = outputs.at("concat").get_memory();
+    auto output_layout = output_memory->get_layout();
+    cldnn::mem_lock output_ptr(output_memory, get_test_stream());  // NOTE(review): mem_lock's template argument looks stripped - confirm
+
+    int y_size = output_layout.size.spatial[1];
+    int x_size = output_layout.size.spatial[0];
+    int f_size = output_layout.size.feature[0];
+    int b_size = output_layout.size.batch[0];
+    EXPECT_EQ(output_layout.format, format::bfyx);
+    EXPECT_EQ(y_size, 3);
+    EXPECT_EQ(x_size, 4);
+    EXPECT_EQ(f_size, 5);  // five single-feature inputs joined along f
+    EXPECT_EQ(b_size, 1);
+
+    for (size_t x = 0; x < output_layout.count(); ++x) {
+        EXPECT_EQ(output_vec[x], output_ptr[x]);  // exact equality against the expected buffer
+    }
+}
+#endif