add input0(cond) broadcast in select (#14267)

This commit is contained in:
Wilson Seok
2022-11-29 17:42:23 +09:00
committed by GitHub
parent 03583d93cb
commit dd0e76b384
3 changed files with 63 additions and 1 deletion

View File

@@ -25,6 +25,9 @@ layout select_inst::calc_output_layout(select_node const& node, kernel_impl_para
auto input1_size = impl_param.get_input_layout(1).get_tensor();
auto input2_size = impl_param.get_input_layout(2).get_tensor();
output_size = tensor::max(input1_size, input2_size);
// Cond input0 also can be broadcasted.
auto input0_size = impl_param.get_input_layout(0).get_tensor();
output_size = tensor::max(input0_size, output_size);
}
return layout(in_layout.data_type, in_layout.format, output_size);
@@ -139,6 +142,10 @@ select_inst::typed_primitive_inst(network& network, select_node const& node) : p
auto dep1_size = deps[1]->get_output_layout().get_tensor();
auto dep2_size = deps[2]->get_output_layout().get_tensor();
cldnn::tensor output_tensor = tensor::max(dep1_size, dep2_size);
// Cond input0 also can be broadcasted.
auto dep0_size = deps[0]->get_output_layout().get_tensor();
output_tensor = tensor::max(dep0_size, output_tensor);
auto max_dim_count = output_tensor.raw.size();
for (size_t i = 0; i < deps.size(); i++) {

View File

@@ -2240,3 +2240,57 @@ TEST(select_gpu_u8, select_basic_mask_i8_1x1x2x2) {
EXPECT_EQ(answers[i], output_ptr[i]);
}
}
TEST(select_gpu_fp32, select_numpy_broadcast_mask_u8_1x1x3) {
    // Checks numpy-style broadcasting when the cond (mask) input itself must be
    // broadcast: mask is 1x1x3 while the data inputs are 1x3x1 and 3x1x1, so the
    // three inputs broadcast together into a full 3x3x3 output.
    auto& engine = get_test_engine();

    auto in0_mem  = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 3, 1, 1 } });
    auto in1_mem  = engine.allocate_memory({ data_types::f32, format::bfyx, { 3, 1, 1, 1 } });
    auto mask_mem = engine.allocate_memory({ data_types::u8,  format::bfyx, { 1, 1, 3, 1 } });

    topology tp;
    tp.add(input_layout("input", in0_mem->get_layout()));
    tp.add(input_layout("input2", in1_mem->get_layout()));
    tp.add(input_layout("mask", mask_mem->get_layout()));
    tp.add(cldnn::select("select", "mask", "input", "input2"));

    set_values(in0_mem, { 1.f, 0.f, 2.f });
    set_values(in1_mem, { 0.5f, 2.5f, 5.f });
    set_values<unsigned char>(mask_mem, { 1, 0, 1 });

    network net(engine, tp);
    net.set_input_data("input", in0_mem);
    net.set_input_data("input2", in1_mem);
    net.set_input_data("mask", mask_mem);

    auto outputs = net.execute();
    auto out_mem = outputs.at("select").get_memory();

    // Expected pattern (derived from the table below):
    // out[b][f][x] = mask[x] ? input[f] : input2[b].
    const float expected[27] = {
        1.f, 0.5f, 1.f,
        0.f, 0.5f, 0.f,
        2.f, 0.5f, 2.f,

        1.f, 2.5f, 1.f,
        0.f, 2.5f, 0.f,
        2.f, 2.5f, 2.f,

        1.f, 5.f, 1.f,
        0.f, 5.f, 0.f,
        2.f, 5.f, 2.f
    };

    cldnn::mem_lock<float> out_ptr(out_mem, get_test_stream());
    for (int i = 0; i < 27; ++i) {
        EXPECT_EQ(expected[i], out_ptr[i]);
    }
}

View File

@@ -53,7 +53,8 @@ const std::vector<std::vector<std::vector<size_t>>> numpyShapes = {
{{1, 3, 1}, {8, 2, 3, 1}, {3, 9}},
{{5, 1, 8}, {2, 1, 9, 8}, {2, 5, 9, 8}},
{{6, 1, 1, 8}, {6, 7, 1, 8}, {2, 1}},
{{5, 1, 1, 1}, {5, 7, 8, 6}, {1, 8, 6}}
{{5, 1, 1, 1}, {5, 7, 8, 6}, {1, 8, 6}},
{{1, 1, 3}, {1, 3, 1}, {3, 1, 1}}
};
const auto numpyCases = ::testing::Combine(