[GPU] Squeeze/Unsqueeze shape infer support (#12724)
This commit is contained in:
parent
85549cb404
commit
93d82cde4a
@ -22,6 +22,12 @@ namespace cldnn {
|
||||
struct reshape : public primitive_base<reshape> {
|
||||
CLDNN_DECLARE_PRIMITIVE(reshape)
|
||||
|
||||
enum reshape_mode : uint32_t {
|
||||
base,
|
||||
squeeze,
|
||||
unsqueeze
|
||||
};
|
||||
|
||||
/// @brief Constructs reshape primitive.
|
||||
/// @param id This primitive id.
|
||||
/// @param input Input primitive id.
|
||||
@ -32,11 +38,13 @@ struct reshape : public primitive_base<reshape> {
|
||||
reshape(const primitive_id& id,
|
||||
const primitive_id& input,
|
||||
const tensor& output_shape,
|
||||
reshape_mode mode = reshape_mode::base,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {input}, output_padding)
|
||||
, output_shape(output_shape)
|
||||
, output_pattern({})
|
||||
, output_partial_shape({}) {}
|
||||
, output_partial_shape({})
|
||||
, mode(mode) {}
|
||||
|
||||
/// @brief reshape with dynamic pattern
|
||||
reshape(const primitive_id& id,
|
||||
@ -44,12 +52,14 @@ struct reshape : public primitive_base<reshape> {
|
||||
const primitive_id& pattern_id,
|
||||
bool special_zero,
|
||||
const ov::PartialShape& output_partial_shape,
|
||||
reshape_mode mode = reshape_mode::base,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {input, pattern_id}, output_padding)
|
||||
, output_shape(tensor())
|
||||
, special_zero(special_zero)
|
||||
, output_pattern({})
|
||||
, output_partial_shape(output_partial_shape) {}
|
||||
, output_partial_shape(output_partial_shape)
|
||||
, mode(mode) {}
|
||||
|
||||
/// @brief reshape with static pattern
|
||||
reshape(const primitive_id& id,
|
||||
@ -57,12 +67,14 @@ struct reshape : public primitive_base<reshape> {
|
||||
bool special_zero,
|
||||
const std::vector<int64_t>& output_pattern,
|
||||
const ov::PartialShape& output_partial_shape,
|
||||
reshape_mode mode = reshape_mode::base,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {input}, output_padding)
|
||||
, output_shape(tensor())
|
||||
, special_zero(special_zero)
|
||||
, output_pattern(output_pattern)
|
||||
, output_partial_shape(output_partial_shape) {}
|
||||
, output_partial_shape(output_partial_shape)
|
||||
, mode(mode) {}
|
||||
|
||||
/// @brief Requested memory shape.
|
||||
tensor output_shape;
|
||||
@ -72,6 +84,8 @@ struct reshape : public primitive_base<reshape> {
|
||||
std::vector<int64_t> output_pattern;
|
||||
|
||||
ov::PartialShape output_partial_shape;
|
||||
|
||||
reshape_mode mode;
|
||||
};
|
||||
|
||||
/// @}
|
||||
|
@ -33,6 +33,8 @@ public:
|
||||
return false;
|
||||
return (!this->get_output_layout().data_padding && !input().get_output_layout(false).data_padding);
|
||||
}
|
||||
|
||||
std::vector<size_t> get_shape_infer_dependencies() const override { return {1}; }
|
||||
};
|
||||
|
||||
using reshape_node = typed_program_node<reshape>;
|
||||
|
@ -63,9 +63,6 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
|
||||
return { layout{prim->output_partial_shape, input_layout.data_type, format::adjust_to_rank(input_layout.format, prim->output_partial_shape.size())} };
|
||||
}
|
||||
|
||||
ov::op::v1::Reshape op;
|
||||
op.set_special_zero(prim->special_zero);
|
||||
|
||||
ShapeType pattern_shape = impl_param.input_layouts.size() == 2 ? impl_param.get_input_layout(1).get<ShapeType>()
|
||||
: ShapeType(ov::Shape{ prim->output_pattern.size() });
|
||||
std::vector<ShapeType> output_shapes = {ShapeType()};
|
||||
@ -74,6 +71,31 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
|
||||
pattern_shape,
|
||||
};
|
||||
|
||||
std::map<size_t, ngraph::HostTensorPtr> const_data;
|
||||
|
||||
auto run_shape_infer = [&](reshape::reshape_mode mode) {
|
||||
switch (mode) {
|
||||
case reshape::reshape_mode::base: {
|
||||
ov::op::v1::Reshape op;
|
||||
op.set_special_zero(prim->special_zero);
|
||||
shape_infer(&op, input_shapes, output_shapes, const_data);
|
||||
break;
|
||||
}
|
||||
case reshape::reshape_mode::squeeze: {
|
||||
ov::op::v0::Squeeze op;
|
||||
shape_infer(&op, input_shapes, output_shapes, const_data);
|
||||
break;
|
||||
}
|
||||
case reshape::reshape_mode::unsqueeze: {
|
||||
ov::op::v0::Unsqueeze op;
|
||||
shape_infer(&op, input_shapes, output_shapes, const_data);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
OPENVINO_ASSERT("Unsupported reshape mode");
|
||||
}
|
||||
};
|
||||
|
||||
if (!memory_deps.empty()) {
|
||||
auto pattern_mem = memory_deps.at(1);
|
||||
|
||||
@ -82,19 +104,14 @@ std::vector<layout> reshape_inst::calc_output_layouts(reshape_node const& /*node
|
||||
auto pattern_ptr = pattern_lock.data();
|
||||
auto pattern_tensor = make_host_tensor(pattern_mem->get_layout(), pattern_ptr);
|
||||
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {
|
||||
{1, pattern_tensor},
|
||||
};
|
||||
|
||||
shape_infer(&op, input_shapes, output_shapes, const_data);
|
||||
const_data.emplace(1, pattern_tensor);
|
||||
run_shape_infer(prim->mode);
|
||||
} else {
|
||||
auto pattern_data = prim->output_pattern;
|
||||
auto pattern_tensor = make_host_tensor({pattern_shape, data_types::i64, format::bfyx}, static_cast<void*>(pattern_data.data()));
|
||||
std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>> const_data = {
|
||||
{1, pattern_tensor},
|
||||
};
|
||||
|
||||
shape_infer(&op, input_shapes, output_shapes, const_data);
|
||||
const_data.emplace(1, pattern_tensor);
|
||||
run_shape_infer(prim->mode);
|
||||
}
|
||||
|
||||
return { layout{output_shapes[0], input_layout.data_type, format::adjust_to_rank(input_layout.format, output_shapes[0].size())} };
|
||||
|
@ -15,7 +15,7 @@
|
||||
namespace ov {
|
||||
namespace intel_gpu {
|
||||
|
||||
static void CreateCommonReshapeOp(Program& p, const std::shared_ptr<ngraph::Node>& op) {
|
||||
static void CreateCommonReshapeOp(Program& p, const std::shared_ptr<ngraph::Node>& op, cldnn::reshape::reshape_mode mode) {
|
||||
validate_inputs_count(op, {1, 2});
|
||||
auto inputPrimitives = p.GetInputPrimitiveIDs(op);
|
||||
std::string layerName = layer_type_name_ID(op);
|
||||
@ -50,21 +50,22 @@ static void CreateCommonReshapeOp(Program& p, const std::shared_ptr<ngraph::Node
|
||||
|
||||
auto reshapePrim = cldnn::reshape(layerName,
|
||||
reshapeInputId,
|
||||
outTensor);
|
||||
outTensor,
|
||||
mode);
|
||||
|
||||
p.add_primitive(*op, reshapePrim);
|
||||
}
|
||||
|
||||
static void CreateReshapeOp(Program& p, const std::shared_ptr<ngraph::op::v1::Reshape>& op) {
|
||||
CreateCommonReshapeOp(p, op);
|
||||
CreateCommonReshapeOp(p, op, cldnn::reshape::reshape_mode::base);
|
||||
}
|
||||
|
||||
static void CreateSqueezeOp(Program& p, const std::shared_ptr<ngraph::op::v0::Squeeze>& op) {
|
||||
CreateCommonReshapeOp(p, op);
|
||||
CreateCommonReshapeOp(p, op, cldnn::reshape::reshape_mode::squeeze);
|
||||
}
|
||||
|
||||
static void CreateUnsqueezeOp(Program& p, const std::shared_ptr<ngraph::op::v0::Unsqueeze>& op) {
|
||||
CreateCommonReshapeOp(p, op);
|
||||
CreateCommonReshapeOp(p, op, cldnn::reshape::reshape_mode::unsqueeze);
|
||||
}
|
||||
|
||||
REGISTER_FACTORY_IMPL(v1, Reshape);
|
||||
|
@ -29,7 +29,7 @@ struct reshape_test_params {
|
||||
layout expected_layout;
|
||||
};
|
||||
|
||||
class reshape_test_two_inputs : public testing::TestWithParam<reshape_test_params> { };
|
||||
class reshape_test_two_inputs : public testing::TestWithParam<reshape_test_params> {};
|
||||
TEST_P(reshape_test_two_inputs, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
@ -83,7 +83,7 @@ INSTANTIATE_TEST_SUITE_P(smoke, reshape_test_two_inputs,
|
||||
},
|
||||
}));
|
||||
|
||||
class reshape_test_single_input : public testing::TestWithParam<reshape_test_params> { };
|
||||
class reshape_test_single_input : public testing::TestWithParam<reshape_test_params> {};
|
||||
TEST_P(reshape_test_single_input, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
@ -127,4 +127,112 @@ INSTANTIATE_TEST_SUITE_P(smoke, reshape_test_single_input,
|
||||
},
|
||||
}));
|
||||
|
||||
struct squeeze_unsqueeze_test_params {
|
||||
layout in_layout;
|
||||
layout indices_layout;
|
||||
std::vector<int64_t> indices_data;
|
||||
ov::PartialShape output_partial_shape;
|
||||
layout expected_layout;
|
||||
};
|
||||
|
||||
class squeeze_test : public testing::TestWithParam<squeeze_unsqueeze_test_params> {};
|
||||
TEST_P(squeeze_test, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input_prim = std::make_shared<input_layout>("input", p.in_layout);
|
||||
auto indices_prim = std::make_shared<input_layout>("pattern", p.indices_layout);
|
||||
auto squeeze_prim = std::make_shared<reshape>("output", "input", "pattern",
|
||||
false, p.output_partial_shape,
|
||||
reshape::reshape_mode::squeeze);
|
||||
cldnn::program prog(engine);
|
||||
|
||||
auto indices_mem = engine.allocate_memory(p.indices_layout);
|
||||
set_values(indices_mem, p.indices_data);
|
||||
|
||||
auto& input_node = prog.get_or_create(input_prim);
|
||||
auto& indices_node = prog.get_or_create(indices_prim);
|
||||
auto& squeeze_node = prog.get_or_create(squeeze_prim);
|
||||
program_wrapper::add_connection(prog, input_node, squeeze_node);
|
||||
program_wrapper::add_connection(prog, indices_node, squeeze_node);
|
||||
auto params = squeeze_node.get_kernel_impl_params();
|
||||
|
||||
auto res_wo_data = reshape_inst::calc_output_layouts<ov::PartialShape>(squeeze_node, *params);
|
||||
|
||||
params->memory_deps = {{1, indices_mem}};
|
||||
auto res_w_data = reshape_inst::calc_output_layouts<ov::PartialShape>(squeeze_node, *params);
|
||||
|
||||
layout expected_layout_wo_data{p.output_partial_shape, p.expected_layout.data_type, p.expected_layout.format};
|
||||
ASSERT_EQ(res_wo_data.size(), 1);
|
||||
ASSERT_EQ(res_wo_data[0], expected_layout_wo_data);
|
||||
|
||||
ASSERT_EQ(res_w_data.size(), 1);
|
||||
ASSERT_EQ(res_w_data[0], p.expected_layout);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke, squeeze_test,
|
||||
testing::ValuesIn(std::vector<squeeze_unsqueeze_test_params>{
|
||||
{
|
||||
layout{ov::PartialShape{1, 3, 1, 2}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{2}, data_types::i64, format::bfyx}, {0, 2}, ov::PartialShape::dynamic(2),
|
||||
layout{ov::PartialShape{3, 2}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape{1}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{1}, data_types::i64, format::bfyx}, {0}, ov::PartialShape::dynamic(0),
|
||||
layout{ov::PartialShape{}, data_types::f32, format::bfyx}
|
||||
}
|
||||
}));
|
||||
|
||||
class unsqueeze_test : public testing::TestWithParam<squeeze_unsqueeze_test_params> { };
|
||||
TEST_P(unsqueeze_test, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto input_prim = std::make_shared<input_layout>("input", p.in_layout);
|
||||
auto indices_prim = std::make_shared<input_layout>("pattern", p.indices_layout);
|
||||
auto unsqueeze_prim = std::make_shared<reshape>("output", "input", "pattern",
|
||||
false, p.output_partial_shape,
|
||||
reshape::reshape_mode::unsqueeze);
|
||||
cldnn::program prog(engine);
|
||||
|
||||
auto indices_mem = engine.allocate_memory(p.indices_layout);
|
||||
set_values(indices_mem, p.indices_data);
|
||||
|
||||
auto& input_node = prog.get_or_create(input_prim);
|
||||
auto& indices_node = prog.get_or_create(indices_prim);
|
||||
auto& unsqueeze_node = prog.get_or_create(unsqueeze_prim);
|
||||
program_wrapper::add_connection(prog, input_node, unsqueeze_node);
|
||||
program_wrapper::add_connection(prog, indices_node, unsqueeze_node);
|
||||
auto params = unsqueeze_node.get_kernel_impl_params();
|
||||
|
||||
auto res_wo_data = reshape_inst::calc_output_layouts<ov::PartialShape>(unsqueeze_node, *params);
|
||||
|
||||
params->memory_deps = {{1, indices_mem}};
|
||||
auto res_w_data = reshape_inst::calc_output_layouts<ov::PartialShape>(unsqueeze_node, *params);
|
||||
|
||||
layout expected_layout_wo_data{p.output_partial_shape, p.expected_layout.data_type, p.expected_layout.format};
|
||||
ASSERT_EQ(res_wo_data.size(), 1);
|
||||
ASSERT_EQ(res_wo_data[0], expected_layout_wo_data);
|
||||
|
||||
ASSERT_EQ(res_w_data.size(), 1);
|
||||
ASSERT_EQ(res_w_data[0], p.expected_layout);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke, unsqueeze_test,
|
||||
testing::ValuesIn(std::vector<squeeze_unsqueeze_test_params>{
|
||||
{
|
||||
layout{ov::PartialShape{2, 3}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{2}, data_types::i64, format::bfyx}, {0, 3}, ov::PartialShape::dynamic(4),
|
||||
layout{ov::PartialShape{1, 2, 3, 1}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape{}, data_types::f32, format::bfyx},
|
||||
layout{ov::PartialShape{1}, data_types::i64, format::bfyx}, {0}, ov::PartialShape::dynamic(1),
|
||||
layout{ov::PartialShape{1}, data_types::f32, format::bfyx}
|
||||
}
|
||||
}));
|
||||
|
||||
} // shape_infer_tests
|
||||
|
@ -62,7 +62,7 @@ void generic_reshape_test(format fmt, tensor const& input_size, tensor const& re
|
||||
tpl.add(reorder("reorder", "input", padded_input_layout));
|
||||
reshape_input = "reorder";
|
||||
}
|
||||
tpl.add(reshape("reshape", reshape_input, reshape_size, output_padd));
|
||||
tpl.add(reshape("reshape", reshape_input, reshape_size, cldnn::reshape::reshape_mode::base, output_padd));
|
||||
|
||||
build_options bo;
|
||||
bo.set_option(build_option::outputs({reshape_input, "reshape"}));
|
||||
@ -525,7 +525,7 @@ TEST(reshape_gpu_f32, basic_bfwzyx) {
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input", input->get_layout()));
|
||||
topology.add(reshape("reshape", "input", tensor(batch(1), feature(1), spatial(2, 2, 3, 3)), padding({0, 0, 0, 0, 0, 1}, 0.f)));
|
||||
topology.add(reshape("reshape", "input", tensor(batch(1), feature(1), spatial(2, 2, 3, 3)), cldnn::reshape::reshape_mode::base, padding({0, 0, 0, 0, 0, 1}, 0.f)));
|
||||
|
||||
// clang-format off
|
||||
std::vector<float> input_data = {
|
||||
|
Loading…
Reference in New Issue
Block a user