[ONNX] don't hardcode shapes in Interpolate and Shape operator (#3778)
This commit is contained in:
parent
00d37aaa62
commit
7be7a8fb30
@ -148,29 +148,6 @@ namespace ngraph
|
||||
calculate_output_shape_based_on_scales(const Output<ngraph::Node>& data,
|
||||
const Output<ngraph::Node>& scales)
|
||||
{
|
||||
const auto& data_shape = data.get_partial_shape();
|
||||
const auto& scales_shape = scales.get_partial_shape();
|
||||
|
||||
if (ngraph::op::is_constant(scales.get_node()) && data_shape.is_static())
|
||||
{
|
||||
const auto scales_const =
|
||||
as_type_ptr<default_opset::Constant>(scales.get_node_shared_ptr());
|
||||
|
||||
const auto scales_vector = scales_const->cast_vector<float>();
|
||||
const auto data_static_shape = data_shape.to_shape();
|
||||
|
||||
std::vector<int64_t> output_shape;
|
||||
for (size_t i = 0; i < data_static_shape.size(); ++i)
|
||||
{
|
||||
output_shape.push_back(
|
||||
std::floor(data_static_shape.at(i) * scales_vector.at(i)));
|
||||
}
|
||||
auto output_shape_const = default_opset::Constant::create(
|
||||
element::u64, Shape({output_shape.size()}), output_shape);
|
||||
|
||||
return output_shape_const;
|
||||
}
|
||||
|
||||
const auto shape_of_data = std::make_shared<default_opset::Convert>(
|
||||
std::make_shared<default_opset::ShapeOf>(data), scales.get_element_type());
|
||||
const auto multiply =
|
||||
@ -185,33 +162,7 @@ namespace ngraph
|
||||
calculate_scales_based_on_sizes(const Output<ngraph::Node>& data,
|
||||
const Output<ngraph::Node>& sizes)
|
||||
{
|
||||
const auto& data_shape = data.get_partial_shape();
|
||||
const auto& sizes_shape = sizes.get_partial_shape();
|
||||
|
||||
const float epsilon = 1.0e-5;
|
||||
|
||||
if (ngraph::op::is_constant(sizes.get_node()) && data_shape.is_static())
|
||||
{
|
||||
const auto sizes_const =
|
||||
as_type_ptr<default_opset::Constant>(sizes.get_node_shared_ptr());
|
||||
|
||||
const auto sizes_vector = sizes_const->cast_vector<int64_t>();
|
||||
const auto data_static_shape = data_shape.to_shape();
|
||||
|
||||
std::vector<float> scales;
|
||||
for (size_t i = 0; i < data_static_shape.size(); ++i)
|
||||
{
|
||||
float scale = static_cast<float>(sizes_vector.at(i)) /
|
||||
static_cast<float>(data_static_shape.at(i)) +
|
||||
epsilon;
|
||||
scales.push_back(scale);
|
||||
}
|
||||
auto scales_const = default_opset::Constant::create(
|
||||
element::f32, Shape({scales.size()}), scales);
|
||||
|
||||
return scales_const;
|
||||
}
|
||||
|
||||
const auto shape_of_data = std::make_shared<default_opset::Convert>(
|
||||
std::make_shared<default_opset::ShapeOf>(data), ngraph::element::f32);
|
||||
const auto converted_sizes =
|
||||
|
@ -33,20 +33,7 @@ namespace ngraph
|
||||
OutputVector shape(const Node& node)
|
||||
{
|
||||
const auto data = node.get_ng_inputs().at(0);
|
||||
const auto data_shape = data.get_partial_shape();
|
||||
|
||||
if (data_shape.is_static())
|
||||
{
|
||||
const auto static_data_shape = data_shape.to_shape();
|
||||
|
||||
return {default_opset::Constant::create(ngraph::element::i64,
|
||||
Shape{static_data_shape.size()},
|
||||
static_data_shape)};
|
||||
}
|
||||
else
|
||||
{
|
||||
return {std::make_shared<default_opset::ShapeOf>(data)};
|
||||
}
|
||||
return {std::make_shared<default_opset::ShapeOf>(data)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
@ -1,80 +0,0 @@
|
||||
ir_version: 6
|
||||
producer_name: "test_model"
|
||||
graph {
|
||||
node {
|
||||
output: "scales"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 4
|
||||
data_type: 1
|
||||
float_data: 4.0
|
||||
float_data: 3.0
|
||||
float_data: 2.0
|
||||
float_data: 1.0
|
||||
name: "scales_const"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "scales"
|
||||
output: "output"
|
||||
op_type: "Resize"
|
||||
attribute {
|
||||
name: "mode"
|
||||
s: "nearest"
|
||||
type: STRING
|
||||
}
|
||||
}
|
||||
name: "test_model"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "output"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
dim {
|
||||
dim_value: 6
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 10
|
||||
}
|
@ -1122,21 +1122,6 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_reduce_sum_square)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_resize10_import_only)
|
||||
{
|
||||
const auto resize_fn = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/resize_opset10.prototxt"));
|
||||
|
||||
// Input data shape (1, 2, 3, 4)
|
||||
// Scales input constant values {4, 3, 2, 1}
|
||||
|
||||
Shape expected_output_shape{4, 6, 6, 4};
|
||||
EXPECT_EQ(resize_fn->get_output_size(), 1);
|
||||
EXPECT_EQ(resize_fn->get_output_shape(0), expected_output_shape);
|
||||
EXPECT_EQ(count_ops_of_type<op::v0::Interpolate>(resize_fn), 1);
|
||||
EXPECT_EQ(count_ops_of_type<onnx_import::default_opset::Constant>(resize_fn), 1);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input)
|
||||
{
|
||||
// this model contains a Constant node with an empty underlying tensor
|
||||
|
Loading…
Reference in New Issue
Block a user