Enable tensor offset to GemmKernelRef for input padding support (#12133)
Signed-off-by: Andrew Park <andrew.park@intel.com>
This commit is contained in:
committed by
GitHub
parent
99dbb35835
commit
dc374ca1bf
@@ -26,6 +26,7 @@ ParamsKey GemmKernelRef::GetSupportedKey() const {
|
||||
k.EnableBatching();
|
||||
k.EnableDifferentTypes();
|
||||
k.EnableTensorPitches();
|
||||
k.EnableTensorOffset();
|
||||
k.EnableQuantization(QuantizationType::SYMMETRIC);
|
||||
|
||||
return k;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/gemm.hpp>
|
||||
#include <intel_gpu/primitives/crop.hpp>
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
@@ -111,6 +112,61 @@ TEST(gemm_gpu, basic_bfyx_t2) {
|
||||
}
|
||||
}
|
||||
|
||||
// Runs gemm on a cropped slice of a larger input, with graph optimization
// enabled, and checks the result against hand-computed values.
// NOTE(review): judging by the test name, optimize_data(true) is expected to
// turn the crop into an in-place view, so gemm reads "input" through a
// non-zero offset/padding — confirm against the crop optimization pass.
TEST(gemm_gpu, basic_bfyx_t2_inplace_crop_with_pad) {
    auto& engine = get_test_engine();
    // NOTE(review): dims are assumed to follow cldnn tensor order
    // (batch, feature, x, y), i.e. two 3x4 feature planes — confirm.
    auto input = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 2, 4, 3 } });
    auto input2 = engine.allocate_memory({ data_types::f32, format::bfyx, { 1, 1, 4, 1 } });

    std::vector<float> input_data = {
        // First 12 values: the part the crop is expected to skip. They are the
        // negation of the second half, so a kernel that ignores the crop
        // offset would produce -8/-22/-20 and fail instead of silently passing.
        -1.f, 2.f, -3.f, 4.f,
        -5.f, -6.f, -1.f, -2.f,
        -3.f, -3.f, -2.f, 1.f,

        // Last 12 values: the part selected by the crop offset below.
        1.f, -2.f, 3.f, -4.f,
        5.f, 6.f, 1.f, 2.f,
        3.f, 3.f, 2.f, -1.f,
    };

    std::vector<float> input_data2 = {
        2.f, 5.f, -4.f, -7.f,
    };
    set_values(input, input_data);
    set_values(input2, input_data2);

    // Each value is the dot product of one 4-element row of the second half of
    // input_data with input_data2:
    //   1*2 - 2*5 + 3*(-4) - 4*(-7) = 8
    //   5*2 + 6*5 + 1*(-4) + 2*(-7) = 22
    //   3*2 + 3*5 + 2*(-4) - 1*(-7) = 20
    const std::vector<float> out_data = {
        8.f, 22.f, 20.f
    };

    topology topology;
    topology.add(
        input_layout("input", input->get_layout())
    );
    topology.add(
        input_layout("input2", input2->get_layout())
    );
    topology.add(
        // Reference size { 1, 1, 4, 3 } with offsets { 0, 1, 0, 0 }.
        crop("crop.1", "input", { 1, 1, 4, 3 }, { 0, 1, 0, 0 })
    );
    topology.add(
        // NOTE(review): the trailing (false, true) flags presumably mean
        // "do not transpose input 0, transpose input 1" — confirm against gemm.hpp.
        gemm("output", { "crop.1", "input2" }, data_types::f32, false, true)
    );

    build_options options;
    options.set_option(build_option::optimize_data(true));
    network network(engine, topology, options);
    network.set_input_data("input", input);
    network.set_input_data("input2", input2);
    auto outputs = network.execute();

    auto output = outputs.at("output").get_memory();
    cldnn::mem_lock<float> output_ptr(output, get_test_stream());

    // Fatal assertion: the loop below indexes output_ptr with out_data's
    // bounds, so continuing after a size mismatch could read out of range.
    ASSERT_EQ(output_ptr.size(), out_data.size());
    for (size_t i = 0; i < out_data.size(); ++i) {
        EXPECT_FLOAT_EQ(output_ptr[i], out_data[i]);
    }
}
|
||||
|
||||
TEST(gemm_gpu, basic_bfyx_t1t2) {
|
||||
auto& engine = get_test_engine();
|
||||
auto input = engine.allocate_memory({ data_types::f32, format::bfyx, { 2, 1, 3, 4 } });
|
||||
|
||||
Reference in New Issue
Block a user