[GPU] Small fix for gather_nonzero (#16858)

This commit is contained in:
Roman Lyamin
2023-04-12 09:15:49 +04:00
committed by GitHub
parent 0312d8cf1b
commit f8aacf3b19
2 changed files with 37 additions and 2 deletions

View File

@@ -41,16 +41,17 @@ JitConstants GatherNonzeroKernelRef::GetJitConstants(const gather_nonzero_params
JitConstants jit = MakeBaseParamsJitConstants(params);
const auto& input = params.inputs[0];
jit.AddConstant(MakeJitConstant("OV_INPUT_RANK", params.ov_input_rank));
auto max_local_mem_size = params.engineInfo.maxLocalMemSize / (params.outputs[0].ElementSize());
jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size));
if (input.is_dynamic()) {
DimensionAccessHelper dims(input, 0);
const std::string total_data_size = toVectorMulString({dims.x, dims.y, dims.z, dims.w, dims.f, dims.b});
jit.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", total_data_size));
jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size));
} else {
jit.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", params.inputs[0].LogicalSize()));
if (params.inputs[0].LogicalSize() * params.ov_input_rank < max_local_mem_size) {
jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size));
jit.AddConstant(MakeJitConstant("USE_LOCAL_MEM", 1));
}
}

View File

@@ -448,3 +448,37 @@ TEST(test_non_zero, 6d_fp16_2_2_2_1_5_1) {
};
test_non_zero<int32_t>(layout{ov::PartialShape{2, 2, 2, 1, 5, 1}, data_types::i32, format::bfwzyx}, in_data);
}
TEST(test_gather_non_zero, not_use_local_mem) {
auto& engine = get_test_engine();
auto max_local_mem_size = engine.get_device_info().max_local_mem_size;
auto in_layout = layout{ov::PartialShape{ov::Dimension(max_local_mem_size)}, data_types::f32, format::bfyx};
auto input_mem = engine.allocate_memory(in_layout);
auto in_data = std::vector<float>(max_local_mem_size, 1.f);
set_values(input_mem, in_data);
auto output_shape_layout = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
auto output_shape_mem = engine.allocate_memory(output_shape_layout);
set_values(output_shape_mem, {static_cast<int32_t>(max_local_mem_size)});
topology topology;
topology.add(input_layout("input", in_layout));
topology.add(data("output_shape", output_shape_mem));
topology.add(gather_nonzero("gather_nonzero", input_info("input"), input_info("output_shape")));
network network(engine, topology, get_test_default_config(engine));
network.set_input_data("input", input_mem);
auto outputs = network.execute();
auto output = outputs.at("gather_nonzero").get_memory();
cldnn::mem_lock<int32_t> output_ptr(output, get_test_stream());
std::vector<int32_t> expected_results(max_local_mem_size);
ngraph::runtime::reference::non_zero<float, int32_t>(in_data.data(), expected_results.data(), in_layout.get_shape());
for (size_t i = 0; i < expected_results.size(); ++i) {
ASSERT_EQ(expected_results[i], output_ptr[i]);
}
}