[IE CLDNN] Fix unsupported dims number error (#2453)

This commit is contained in:
Sergey Shlyapnikov 2020-10-02 16:51:09 +03:00 committed by GitHub
parent 20c20ad87a
commit 6b456e58a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 36 additions and 12 deletions

View File

@ -130,6 +130,7 @@ inline cldnn::format ImageFormatFromLayout(InferenceEngine::Layout l) {
inline cldnn::format defaultFormatForDims(size_t dimensions) {
switch (dimensions) {
case 0:
case 1:
case 2:
case 3:

View File

@ -85,53 +85,66 @@ attach_scale_gpu::attach_scale_gpu() {
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::yxfb), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::yxfb), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::yxfb), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::byxf), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::byxf), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::byxf), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bfwzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bfwzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bfwzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::bfwzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::bfwzyx), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_yx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_zyx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_zyx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_zyx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_zyx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_zyx_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bs_fs_zyx_bsv16_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bs_fs_zyx_bsv16_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bs_fs_zyx_bsv16_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::fs_b_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::fs_b_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::bs_fs_yx_bsv16_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::bs_fs_yx_bsv16_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::bs_fs_yx_bsv16_fsv16), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv4), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv4), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv4), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv4), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_yx_fsv4), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_yx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::u8, format::b_fs_zyx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i8, format::b_fs_zyx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f16, format::b_fs_zyx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::f32, format::b_fs_zyx_fsv32), val_fw);
implementation_map<scale>::add(std::make_tuple(engine_types::ocl, data_types::i32, format::b_fs_zyx_fsv32), val_fw);
}
} // namespace detail

View File

@ -67,12 +67,21 @@ select_inst::typed_primitive_inst(network_impl& network, select_node const& node
3,
"");
CLDNN_ERROR_NOT_EQUAL(node.id(),
"Mask format",
deps[0]->get_output_layout().format,
"Positive input format",
deps[1]->get_output_layout().format,
"");
if (deps[1]->get_output_layout().size != cldnn::tensor(1))
CLDNN_ERROR_NOT_EQUAL(node.id(),
"Mask format",
deps[0]->get_output_layout().format,
"Positive input format",
deps[1]->get_output_layout().format,
"");
if (deps[2]->get_output_layout().size != cldnn::tensor(1))
CLDNN_ERROR_NOT_EQUAL(node.id(),
"Mask format",
deps[0]->get_output_layout().format,
"Positive input format",
deps[2]->get_output_layout().format,
"");
if (node.get_primitive()->broadcast_type == "none") {
CLDNN_ERROR_LAYOUT_MISMATCH(node.id(),
@ -89,12 +98,13 @@ select_inst::typed_primitive_inst(network_impl& network, select_node const& node
deps[1]->get_output_layout().size,
"");
} else if (node.get_primitive()->broadcast_type == "numpy") {
CLDNN_ERROR_NOT_EQUAL(node.id(),
"Positive input format",
deps[1]->get_output_layout().format,
"Negative input format",
deps[2]->get_output_layout().format,
"");
if (deps[1]->get_output_layout().size != cldnn::tensor(1) && deps[2]->get_output_layout().size != cldnn::tensor(1))
CLDNN_ERROR_NOT_EQUAL(node.id(),
"Positive input format",
deps[1]->get_output_layout().format,
"Negative input format",
deps[2]->get_output_layout().format,
"");
CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(),
"Positive input data type",