[GPU] Added support mixed input formats for Select (#16009)
This commit is contained in:
@@ -108,22 +108,6 @@ select_inst::typed_primitive_inst(network& network, select_node const& node) : p
|
||||
3,
|
||||
"");
|
||||
|
||||
if (deps[1].first->get_output_layout().get_tensor() != cldnn::tensor(1))
|
||||
CLDNN_ERROR_NOT_EQUAL(node.id(),
|
||||
"Mask format",
|
||||
deps[0].first->get_output_layout().format,
|
||||
"Positive input format",
|
||||
deps[1].first->get_output_layout().format,
|
||||
"");
|
||||
|
||||
if (deps[2].first->get_output_layout().get_tensor() != cldnn::tensor(1))
|
||||
CLDNN_ERROR_NOT_EQUAL(node.id(),
|
||||
"Mask format",
|
||||
deps[0].first->get_output_layout().format,
|
||||
"Positive input format",
|
||||
deps[2].first->get_output_layout().format,
|
||||
"");
|
||||
|
||||
if (node.get_primitive()->broadcast_spec.m_type == ov::op::AutoBroadcastType::NONE) {
|
||||
CLDNN_ERROR_LAYOUT_MISMATCH(node.id(),
|
||||
"Positive input layout",
|
||||
@@ -139,14 +123,6 @@ select_inst::typed_primitive_inst(network& network, select_node const& node) : p
|
||||
deps[1].first->get_output_layout().get_tensor(),
|
||||
"");
|
||||
} else if (node.get_primitive()->broadcast_spec.m_type == ov::op::AutoBroadcastType::NUMPY) {
|
||||
if (deps[1].first->get_output_layout().get_tensor() != cldnn::tensor(1) && deps[2].first->get_output_layout().get_tensor() != cldnn::tensor(1))
|
||||
CLDNN_ERROR_NOT_EQUAL(node.id(),
|
||||
"Positive input format",
|
||||
deps[1].first->get_output_layout().format,
|
||||
"Negative input format",
|
||||
deps[2].first->get_output_layout().format,
|
||||
"");
|
||||
|
||||
CLDNN_ERROR_DATA_TYPES_MISMATCH(node.id(),
|
||||
"Positive input data type",
|
||||
deps[1].first->get_output_layout().data_type,
|
||||
|
||||
@@ -4,46 +4,25 @@
|
||||
|
||||
#include "include/batch_headers/fetch_data.cl"
|
||||
|
||||
#ifdef IS_DYNAMIC
|
||||
#define GET_INDEX(prefix) GET_DATA_INDEX_SAFE(prefix, d4, d3, d2, d1)
|
||||
#else
|
||||
#define GET_INDEX(prefix) \
|
||||
CAT(prefix, _OFFSET) + \
|
||||
(d1 % CAT(prefix, _SIZES)[0])*CAT(prefix, _PITCHES)[0] + \
|
||||
(d2 % CAT(prefix, _SIZES)[1])*CAT(prefix, _PITCHES)[1] + \
|
||||
(d3 % CAT(prefix, _SIZES)[2])*CAT(prefix, _PITCHES)[2] + \
|
||||
(d4 % CAT(prefix, _SIZES)[3])*CAT(prefix, _PITCHES)[3]
|
||||
#endif
|
||||
|
||||
#define INPUT_0 input0[GET_INDEX(INPUT0)]
|
||||
#define INPUT_1 input1[GET_INDEX(INPUT1)]
|
||||
#define INPUT_2 input2[GET_INDEX(INPUT2)]
|
||||
#define INPUT_0 input0[INPUT0_GET_INDEX_SAFE(b, f, y, x)]
|
||||
#define INPUT_1 input1[INPUT1_GET_INDEX_SAFE(b, f, y, x)]
|
||||
#define INPUT_2 input2[INPUT2_GET_INDEX_SAFE(b, f, y, x)]
|
||||
|
||||
KERNEL(select)(
|
||||
OPTIONAL_SHAPE_INFO_ARG
|
||||
INPUTS_DECLS
|
||||
__global OUTPUT_TYPE* output)
|
||||
{
|
||||
const uint x = (uint)get_global_id(0);
|
||||
const uint y = (uint)get_global_id(1);
|
||||
const uint bf = (uint)get_global_id(2);
|
||||
|
||||
const uint d1 = (uint) get_global_id(0);
|
||||
const uint d2 = (uint) get_global_id(1);
|
||||
const uint d34 = (uint) get_global_id(2);
|
||||
const uint b = bf % OUTPUT_BATCH_NUM;
|
||||
const uint f = bf / OUTPUT_BATCH_NUM;
|
||||
|
||||
#ifdef IS_DYNAMIC
|
||||
const uint d3 = d34 % OUTPUT_FEATURE_NUM;
|
||||
const uint d4 = d34 / OUTPUT_FEATURE_NUM;
|
||||
#else
|
||||
const uint d3 = d34 % OUTPUT_SIZES[2];
|
||||
const uint d4 = d34 / OUTPUT_SIZES[2];
|
||||
#endif
|
||||
uint output_offset = OUTPUT_GET_INDEX(b, f, y, x);
|
||||
|
||||
#ifdef IS_DYNAMIC
|
||||
uint output_offset = OUTPUT_GET_INDEX(d4, d3, d2, d1);
|
||||
#else
|
||||
uint output_offset = GET_DATA_INDEX_RAW(OUTPUT, d1, d2, d3, d4);
|
||||
#endif
|
||||
const OUTPUT_TYPE res = select(INPUT_2, INPUT_1, MASK);
|
||||
|
||||
const OUTPUT_TYPE res = select(INPUT_2, INPUT_1, MASK);
|
||||
|
||||
output[output_offset] = res;
|
||||
output[output_offset] = res;
|
||||
}
|
||||
|
||||
@@ -95,22 +95,16 @@ JitConstants SelectKernelBase::GetJitConstants(const select_params& params) cons
|
||||
|
||||
SelectKernelBase::DispatchData SelectKernelBase::SetDefault(const select_params& params) const {
|
||||
DispatchData dispatchData;
|
||||
|
||||
const auto& out = params.outputs[0];
|
||||
const auto& in = params.inputs[0];
|
||||
|
||||
std::vector<size_t> gws;
|
||||
for (const auto& o : out.GetDims()) {
|
||||
gws.push_back(o.v);
|
||||
}
|
||||
dispatchData.gws = { out.X().v, out.Y().v, out.Feature().v * out.Batch().v };
|
||||
|
||||
for (size_t i = gws.size(); i < 4; i++) {
|
||||
gws.push_back(1U);
|
||||
}
|
||||
std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws = {{ Tensor::DataChannelName::X },
|
||||
{ Tensor::DataChannelName::Y },
|
||||
{ Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH }};
|
||||
|
||||
dispatchData.gws[0] = gws[0];
|
||||
dispatchData.gws[1] = gws[1];
|
||||
dispatchData.gws[2] = gws[2] * gws[3];
|
||||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo);
|
||||
dispatchData.lws = GetOptimalLocalWorkGroupSizes(dispatchData.gws, params.engineInfo, in.GetLayout(), out.GetLayout(), dims_by_gws);
|
||||
|
||||
return dispatchData;
|
||||
}
|
||||
|
||||
@@ -1264,22 +1264,6 @@ TEST(select_gpu_f32, select_basic_error_input_types) {
|
||||
EXPECT_ANY_THROW(network(engine, topology));
|
||||
}
|
||||
|
||||
TEST(select_gpu_f32, select_basic_error_input_formats) {
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input = engine.allocate_memory({ data_types::f32, format::yxfb,{ 2, 2, 2, 2 } });
|
||||
auto input2 = engine.allocate_memory({ data_types::f32, format::yxfb,{ 2, 2, 2, 2 } });
|
||||
auto mask = engine.allocate_memory({ data_types::f32, format::bfyx,{ 2, 2, 2, 2 } });
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input", input->get_layout()));
|
||||
topology.add(input_layout("input2", input2->get_layout()));
|
||||
topology.add(input_layout("mask", mask->get_layout()));
|
||||
topology.add(cldnn::select("select", input_info("mask"), input_info("input"), input_info("input2")));
|
||||
|
||||
EXPECT_ANY_THROW(network(engine, topology));
|
||||
}
|
||||
|
||||
TEST(select_gpu_f32, select_basic_byxf) {
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
@@ -2314,6 +2298,64 @@ TEST(select_gpu_fp32, select_numpy_broadcast_mask_u8_1x1x3) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(select_gpu_f32, select_different_formats) {
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input1 = engine.allocate_memory({ data_types::f32, format::bfyx, { 2, 1, 2, 2 } });
|
||||
auto input2 = engine.allocate_memory({ data_types::f32, format::byxf, { 2, 1, 2, 2 } });
|
||||
auto mask = engine.allocate_memory({ data_types::f32, format::yxfb, { 1, 1, 2, 2 } });
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input1", input1->get_layout()));
|
||||
topology.add(input_layout("input2", input2->get_layout()));
|
||||
topology.add(input_layout("mask", mask->get_layout()));
|
||||
topology.add(cldnn::select("select", input_info("mask"), input_info("input1"), input_info("input2")));
|
||||
|
||||
set_values(input1, {
|
||||
1.f, 2.f,
|
||||
3.f, 4.f,
|
||||
|
||||
5.f, 6.f,
|
||||
7.f, 8.f
|
||||
});
|
||||
|
||||
set_values(input2, {
|
||||
9.f, 10.f,
|
||||
11.f, 12.f,
|
||||
|
||||
13.f, 14.f,
|
||||
15.f, 16.f
|
||||
});
|
||||
|
||||
set_values(mask, {
|
||||
0.f, 0.f,
|
||||
1.f, 1.f
|
||||
});
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
network.set_input_data("input1", input1);
|
||||
network.set_input_data("input2", input2);
|
||||
network.set_input_data("mask", mask);
|
||||
auto outputs = network.execute();
|
||||
|
||||
auto output = outputs.at("select").get_memory();
|
||||
|
||||
std::vector<float> answers {
|
||||
9.f, 10.f,
|
||||
3.f, 4.f,
|
||||
|
||||
13.f, 14.f,
|
||||
7.f, 8.f
|
||||
};
|
||||
|
||||
cldnn::mem_lock<float> output_ptr(output, get_test_stream());
|
||||
|
||||
for (size_t i = 0; i < answers.size(); ++i) {
|
||||
ASSERT_EQ(answers[i], output_ptr[i]);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(select_gpu_f32, dynamic) {
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user