[GPU] Small fix for gather_nonzero (#16858)
This commit is contained in:
@@ -41,16 +41,17 @@ JitConstants GatherNonzeroKernelRef::GetJitConstants(const gather_nonzero_params
|
||||
JitConstants jit = MakeBaseParamsJitConstants(params);
|
||||
const auto& input = params.inputs[0];
|
||||
jit.AddConstant(MakeJitConstant("OV_INPUT_RANK", params.ov_input_rank));
|
||||
|
||||
auto max_local_mem_size = params.engineInfo.maxLocalMemSize / (params.outputs[0].ElementSize());
|
||||
jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size));
|
||||
|
||||
if (input.is_dynamic()) {
|
||||
DimensionAccessHelper dims(input, 0);
|
||||
const std::string total_data_size = toVectorMulString({dims.x, dims.y, dims.z, dims.w, dims.f, dims.b});
|
||||
jit.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", total_data_size));
|
||||
jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size));
|
||||
} else {
|
||||
jit.AddConstant(MakeJitConstant("TOTAL_DATA_SIZE", params.inputs[0].LogicalSize()));
|
||||
if (params.inputs[0].LogicalSize() * params.ov_input_rank < max_local_mem_size) {
|
||||
jit.AddConstant(MakeJitConstant("MAX_LOCAL_MEM_SIZE", max_local_mem_size));
|
||||
jit.AddConstant(MakeJitConstant("USE_LOCAL_MEM", 1));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,3 +448,37 @@ TEST(test_non_zero, 6d_fp16_2_2_2_1_5_1) {
|
||||
};
|
||||
test_non_zero<int32_t>(layout{ov::PartialShape{2, 2, 2, 1, 5, 1}, data_types::i32, format::bfwzyx}, in_data);
|
||||
}
|
||||
|
||||
TEST(test_gather_non_zero, not_use_local_mem) {
    // Exercise the gather_nonzero fallback path that does NOT use local memory:
    // the input element count equals max_local_mem_size, so
    // count * ov_input_rank < max_local_mem_size is false (strict '<'),
    // and the kernel must produce results through global memory instead.
    auto& engine = get_test_engine();
    const auto max_local_mem_size = engine.get_device_info().max_local_mem_size;

    // 1-D input with max_local_mem_size elements, every one of them non-zero.
    auto input_lt = layout{ov::PartialShape{ov::Dimension(max_local_mem_size)}, data_types::f32, format::bfyx};
    auto input_buf = engine.allocate_memory(input_lt);
    std::vector<float> values(max_local_mem_size, 1.f);
    set_values(input_buf, values);

    // Second input carries the precomputed non-zero count (all elements qualify).
    auto count_lt = layout{ov::PartialShape{1}, data_types::i32, format::bfyx};
    auto count_buf = engine.allocate_memory(count_lt);
    set_values(count_buf, {static_cast<int32_t>(max_local_mem_size)});

    topology topo;
    topo.add(input_layout("input", input_lt));
    topo.add(data("output_shape", count_buf));
    topo.add(gather_nonzero("gather_nonzero", input_info("input"), input_info("output_shape")));

    network net(engine, topo, get_test_default_config(engine));
    net.set_input_data("input", input_buf);

    auto results = net.execute();
    auto out_mem = results.at("gather_nonzero").get_memory();
    cldnn::mem_lock<int32_t> out_ptr(out_mem, get_test_stream());

    // Reference indices computed with ngraph's non_zero implementation.
    std::vector<int32_t> expected(max_local_mem_size);
    ngraph::runtime::reference::non_zero<float, int32_t>(values.data(), expected.data(), input_lt.get_shape());

    for (size_t idx = 0; idx < expected.size(); ++idx) {
        ASSERT_EQ(expected[idx], out_ptr[idx]);
    }
}
|
||||
|
||||
Reference in New Issue
Block a user