concatenation_onednn migration (#7756)

- update the concatenation_onednn implementation
- add a unit test that runs concat through the oneDNN implementation

Signed-off-by: hyunback <hyunback.kim@intel.com>
hyunback kim 2021-10-05 09:16:32 +09:00 committed by GitHub
parent bd2b346c62
commit 21c70f8eb5
3 changed files with 183 additions and 0 deletions

@@ -17,11 +17,110 @@
namespace cldnn {
namespace onednn {

struct concatenation_onednn : typed_primitive_onednn_impl<concatenation, void, dnnl::concat::primitive_desc, dnnl::concat> {
    using parent = typed_primitive_onednn_impl<concatenation, void, dnnl::concat::primitive_desc, dnnl::concat>;
    using parent::parent;

protected:
    std::unique_ptr<primitive_impl> clone() const override {
        return make_unique<concatenation_onednn>(*this);
    }
    std::unordered_map<int, dnnl::memory> get_arguments(concatenation_inst& instance) const override {
        std::unordered_map<int, dnnl::memory> args;
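        // oneDNN concat takes a variable number of inputs, passed at
        // consecutive argument ids starting from DNNL_ARG_MULTIPLE_SRC.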
        int input_idx = DNNL_ARG_MULTIPLE_SRC;
        for (size_t i = 0; i < instance.inputs_memory_count(); i++) {
            auto& input = instance.input_memory(i);
            args.insert({ input_idx++, input.get_onednn_memory(_pd.src_desc(static_cast<int>(i))) });
        }

        {
            auto& output = instance.output_memory();
            args.insert({DNNL_ARG_DST, output.get_onednn_memory(_pd.dst_desc())});
        }

        // TODO post operation
        // configure_post_ops_arguments(instance, args);

        return args;
    }
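    // Builds a dnnl::concat primitive descriptor from the node's input/output
    // layouts and its concatenation axis.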
    static std::shared_ptr<dnnl::concat::primitive_desc> get_concatenation_descriptor(const concatenation_node& arg) {
        auto prim = arg.get_primitive();
        auto& engine = arg.get_program().get_engine();

        std::vector<dnnl::memory::desc> input_mds;
        for (auto& input : arg.get_dependencies()) {
            input_mds.push_back(onednn::layout_to_memory_desc(input->get_output_layout()));
        }
        auto output_md = onednn::layout_to_memory_desc(arg.get_output_layout());
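        // Map the cldnn concatenation axis to the matching logical dimension
        // (b, f, y, x) of the memory descriptors built above.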
        int axis = 0;
        switch (prim->axis) {
            case concatenation::concatenation_axis::along_b: axis = 0; break;
            case concatenation::concatenation_axis::along_f: axis = 1; break;
            case concatenation::concatenation_axis::along_y: axis = 2; break;
            case concatenation::concatenation_axis::along_x: axis = 3; break;
            default: throw std::runtime_error("unsupported concat axis");
        }

        return std::make_shared<dnnl::concat::primitive_desc>(
            output_md,
            axis,
            input_mds,
            engine.get_onednn_engine());
    }
public:
    static primitive_impl* create(const concatenation_node& arg) {
        auto desc = get_concatenation_descriptor(arg);
        auto attr = get_primitive_attributes(arg);

        std::shared_ptr<void> dummy = nullptr;

        return new concatenation_onednn(arg, dummy, attr, *desc);
    }
};
namespace detail {

attach_concatenation_onednn::attach_concatenation_onednn() {
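    // Register the oneDNN concat implementation for the plain and blocked
    // layouts it currently supports.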
    implementation_map<concatenation>::add(impl_types::onednn, concatenation_onednn::create, {
        std::make_tuple(data_types::f32, format::bfyx),
        std::make_tuple(data_types::f16, format::bfyx),
        std::make_tuple(data_types::u8, format::bfyx),
        std::make_tuple(data_types::i8, format::bfyx),
        std::make_tuple(data_types::f32, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv16),
        std::make_tuple(data_types::f32, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::f16, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::u8, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::i8, format::b_fs_yx_fsv32),
        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv16_fsv16),
        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv16),
        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv32_fsv32),
        std::make_tuple(data_types::f32, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::f16, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::u8, format::bs_fs_yx_bsv4_fsv4),
        std::make_tuple(data_types::i8, format::bs_fs_yx_bsv4_fsv4),
    });
}

} // namespace detail
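
For context, the dnnl::concat API wrapped above can also be exercised standalone. The sketch below is illustrative only (not part of this commit); it targets the oneDNN v2.x C++ API, using the same primitive_desc argument order as the code above, and the CPU engine and tensor shapes are placeholder assumptions:

#include <dnnl.hpp>
#include <unordered_map>
#include <vector>

int main() {
    dnnl::engine eng(dnnl::engine::kind::cpu, 0);  // placeholder; the plugin wraps its own GPU engine
    dnnl::stream strm(eng);

    // Two 1x1x3x4 f32 tensors concatenated along the feature axis (axis = 1).
    dnnl::memory::desc src_md({1, 1, 3, 4}, dnnl::memory::data_type::f32,
                              dnnl::memory::format_tag::nchw);
    std::vector<dnnl::memory::desc> srcs{src_md, src_md};

    // As in get_concatenation_descriptor(), but letting oneDNN deduce the dst desc.
    dnnl::concat::primitive_desc pd(/*concat_dimension=*/1, srcs, eng);
    dnnl::concat prim(pd);

    dnnl::memory src0(src_md, eng), src1(src_md, eng), dst(pd.dst_desc(), eng);

    // Same argument-id convention as get_arguments() above.
    prim.execute(strm, {{DNNL_ARG_MULTIPLE_SRC + 0, src0},
                        {DNNL_ARG_MULTIPLE_SRC + 1, src1},
                        {DNNL_ARG_DST, dst}});
    strm.wait();
    return 0;
}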

@@ -1038,6 +1038,16 @@ impl_types layout_optimizer::get_preferred_impl_type(program_node& node, format
        // }
        preferred_impl = impl_candidate;
    } else if (node.is_type<concatenation>()) {
        if (!_optimization_attributes.use_onednn_impls)
            return impl_types::ocl;
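        // Pick the oneDNN concat when at least one data-flow dependency
        // already prefers oneDNN.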
        for (auto& dep : node.get_dependencies()) {
            if (dep->is_in_data_flow() && dep->get_preferred_impl_type() == impl_types::onednn) {
                preferred_impl = impl_types::onednn;
                break;
            }
        }
    }

    return preferred_impl;

@@ -783,3 +783,77 @@ INSTANTIATE_TEST_SUITE_P(smoke_low_precision,
                            TestParamType_concat(2, { 15, 2, 16, 64 }, 1, 2)
                        ),
                        concat_gpu::PrintToStringParamName);
#ifdef ENABLE_ONEDNN_FOR_GPU
TEST(concat_gpu_onednn, basic_input_types) {
    auto& engine = get_onednn_test_engine();

    auto input0 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input1 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input2 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input3 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });
    auto input4 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 3 } });

    set_values<float>(input0, { 1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 2.0f, 3.0f, 4.0f, 3.0f, 3.0f, 3.0f, 5.0f });
    set_values<float>(input1, { 11.0f, 12.0f, 13.0f, 14.0f, 12.0f, 12.0f, 13.0f, 14.0f, 13.0f, 13.0f, 13.0f, 15.0f });
    set_values<float>(input2, { 21.0f, 22.0f, 23.0f, 24.0f, 22.0f, 22.0f, 23.0f, 24.0f, 23.0f, 23.0f, 23.0f, 25.0f });
    set_values<float>(input3, { 31.0f, 32.0f, 33.0f, 34.0f, 32.0f, 32.0f, 33.0f, 34.0f, 33.0f, 33.0f, 33.0f, 35.0f });
    set_values<float>(input4, { 41.0f, 42.0f, 43.0f, 44.0f, 42.0f, 42.0f, 43.0f, 44.0f, 43.0f, 43.0f, 43.0f, 45.0f });

    VF<float> output_vec = {
        1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 2.0f, 3.0f, 4.0f, 3.0f, 3.0f, 3.0f, 5.0f,
        11.0f, 12.0f, 13.0f, 14.0f, 12.0f, 12.0f, 13.0f, 14.0f, 13.0f, 13.0f, 13.0f, 15.0f,
        21.0f, 22.0f, 23.0f, 24.0f, 22.0f, 22.0f, 23.0f, 24.0f, 23.0f, 23.0f, 23.0f, 25.0f,
        31.0f, 32.0f, 33.0f, 34.0f, 32.0f, 32.0f, 33.0f, 34.0f, 33.0f, 33.0f, 33.0f, 35.0f,
        41.0f, 42.0f, 43.0f, 44.0f, 42.0f, 42.0f, 43.0f, 44.0f, 43.0f, 43.0f, 43.0f, 45.0f };
    topology topology(
        input_layout("input0", input0->get_layout()),
        input_layout("input1", input1->get_layout()),
        input_layout("input2", input2->get_layout()),
        input_layout("input3", input3->get_layout()),
        input_layout("input4", input4->get_layout()),
        concatenation("concat",
                      { "input0", "input1", "input2", "input3", "input4" },
                      concatenation::concatenation_axis::along_f,
                      data_types::f32,
                      "",
                      padding{ { 0,0,0,0 }, 0 })
    );

    build_options options_target;
    options_target.set_option(build_option::outputs({ "concat" }));
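    // Force the oneDNN implementation for the concat node so this test
    // exercises concatenation_onednn rather than the default OCL kernel.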
    implementation_desc impl = { format::bfyx, std::string(""), impl_types::onednn };
    options_target.set_option(build_option::force_implementations({ {"concat", impl} }));

    network network(engine, topology, options_target);
    network.set_input_data("input0", input0);
    network.set_input_data("input1", input1);
    network.set_input_data("input2", input2);
    network.set_input_data("input3", input3);
    network.set_input_data("input4", input4);

    auto outputs = network.execute();
    EXPECT_EQ(outputs.size(), size_t(1));
    EXPECT_EQ(outputs.begin()->first, "concat");

    auto output_memory = outputs.at("concat").get_memory();
    auto output_layout = output_memory->get_layout();
    cldnn::mem_lock<float> output_ptr(output_memory, get_test_stream());

    int y_size = output_layout.size.spatial[1];
    int x_size = output_layout.size.spatial[0];
    int f_size = output_layout.size.feature[0];
    int b_size = output_layout.size.batch[0];
    EXPECT_EQ(output_layout.format, format::bfyx);
    EXPECT_EQ(y_size, 3);
    EXPECT_EQ(x_size, 4);
    EXPECT_EQ(f_size, 5);
    EXPECT_EQ(b_size, 1);

    for (size_t x = 0; x < output_layout.count(); ++x) {
        EXPECT_EQ(output_vec[x], output_ptr[x]);
    }
}
#endif