[GPU] Dynamic input support for Transpose (#12739)
This commit is contained in:
parent
bf2d6a72a4
commit
e7e6d7883c
@ -36,6 +36,7 @@ public:
|
||||
}
|
||||
return true;
|
||||
}
|
||||
std::vector<size_t> get_shape_infer_dependencies() const override { return {}; }
|
||||
};
|
||||
|
||||
using permute_node = typed_program_node<permute>;
|
||||
@ -45,6 +46,8 @@ class typed_primitive_inst<permute> : public typed_primitive_inst_base<permute>
|
||||
using parent = typed_primitive_inst_base<permute>;
|
||||
|
||||
public:
|
||||
template <typename ShapeType>
|
||||
static std::vector<layout> calc_output_layouts(permute_node const& /*node*/, kernel_impl_params const& impl_param);
|
||||
static layout calc_output_layout(permute_node const& node, kernel_impl_params const& impl_param);
|
||||
static std::string to_string(permute_node const& node);
|
||||
|
||||
|
@ -47,6 +47,40 @@ layout permute_inst::calc_output_layout(permute_node const& node, kernel_impl_pa
|
||||
return layout(input_layout.data_type, input_layout.format, output_size, op);
|
||||
}
|
||||
|
||||
template<typename ShapeType>
|
||||
std::vector<layout> permute_inst::calc_output_layouts(permute_node const& /*node*/, kernel_impl_params const& impl_param) {
|
||||
auto desc = impl_param.typed_desc<permute>();
|
||||
auto input_layout = impl_param.get_input_layout();
|
||||
|
||||
auto output_type = input_layout.data_type;
|
||||
if (impl_param.has_fused_primitives()) {
|
||||
output_type = impl_param.get_fused_output_layout().data_type;
|
||||
}
|
||||
|
||||
ShapeType input_shape = input_layout.get<ShapeType>();
|
||||
ShapeType output_shape;
|
||||
ov::Rank input_rank = input_shape.rank();
|
||||
|
||||
if (input_rank.is_dynamic()) {
|
||||
output_shape = ShapeType::dynamic(desc->permute_order.size());
|
||||
return { layout{output_shape, output_type, input_layout.format} };
|
||||
}
|
||||
|
||||
int64_t input_static_rank = input_rank.get_length();
|
||||
auto permute_order = desc->permute_order;
|
||||
if (permute_order.empty()) {
|
||||
for (int64_t i = 1; i <= input_static_rank; ++i) {
|
||||
permute_order.emplace_back(input_static_rank - i);
|
||||
}
|
||||
}
|
||||
|
||||
for (int64_t i = 0; i < input_static_rank; ++i) {
|
||||
output_shape.push_back(input_shape[permute_order[i]]);
|
||||
}
|
||||
|
||||
return { layout{output_shape, output_type, input_layout.format, desc->output_padding} };
|
||||
}
|
||||
|
||||
std::string permute_inst::to_string(permute_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
auto node_info = node.desc_to_json();
|
||||
|
@ -0,0 +1,75 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/input_layout.hpp>
|
||||
#include <intel_gpu/primitives/permute.hpp>
|
||||
#include <intel_gpu/primitives/data.hpp>
|
||||
|
||||
#include "permute_inst.h"
|
||||
|
||||
#include "program_wrapper.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
namespace shape_infer_tests {
|
||||
|
||||
struct transpose_test_params {
|
||||
layout data_layout;
|
||||
std::vector<uint16_t> permute_order_data;
|
||||
layout expected_layout;
|
||||
};
|
||||
|
||||
class transpose_test : public testing::TestWithParam<transpose_test_params> { };
|
||||
|
||||
TEST_P(transpose_test, shape_infer) {
|
||||
auto p = GetParam();
|
||||
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
auto data_layout_prim = std::make_shared<input_layout>("data", p.data_layout);
|
||||
auto permute_prim = std::make_shared<permute>("output", "data", p.permute_order_data);
|
||||
|
||||
cldnn::program prog(engine);
|
||||
|
||||
auto& data_node = prog.get_or_create(data_layout_prim);
|
||||
auto& permute_node = prog.get_or_create(permute_prim);
|
||||
program_wrapper::add_connection(prog, data_node, permute_node);
|
||||
|
||||
auto res = permute_inst::calc_output_layouts<ov::PartialShape>(permute_node, *permute_node.get_kernel_impl_params());
|
||||
|
||||
ASSERT_EQ(res.size(), 1);
|
||||
ASSERT_EQ(res[0], p.expected_layout);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke, transpose_test,
|
||||
testing::ValuesIn(std::vector<transpose_test_params>{
|
||||
{
|
||||
layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
|
||||
{2, 0, 1},
|
||||
layout{ov::PartialShape{4, 2, 3}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape{2, 3, 4}, data_types::f32, format::bfyx},
|
||||
{},
|
||||
layout{ov::PartialShape{4, 3, 2}, data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape::dynamic(), data_types::f32, format::bfyx},
|
||||
{0, 1, 2},
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}
|
||||
},
|
||||
{
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx},
|
||||
{},
|
||||
layout{ov::PartialShape::dynamic(3), data_types::f32, format::bfyx}
|
||||
}
|
||||
}));
|
||||
|
||||
} // shape_infer_tests
|
Loading…
Reference in New Issue
Block a user