Add support for input dynamic shape in ONNX LpNorm operator (#4613)
This commit is contained in:
parent
fc589572a1
commit
c99b6feea2
@ -42,8 +42,6 @@ namespace ngraph
|
||||
const auto data_shape = data.get_partial_shape();
|
||||
const auto data_rank = data_shape.rank();
|
||||
|
||||
CHECK_VALID_NODE(
|
||||
node, data_shape.is_static(), "Data shape must be static for lp_norm op");
|
||||
const auto data_rank_value = data_rank.get_length();
|
||||
const std::int64_t p_norm{node.get_attribute_value<std::int64_t>("p", 2)};
|
||||
|
||||
@ -62,8 +60,7 @@ namespace ngraph
|
||||
std::shared_ptr<ngraph::Node> norm = ngraph::builder::opset1::lp_norm(
|
||||
data, normalize_axis_const, static_cast<std::size_t>(p_norm));
|
||||
|
||||
const auto target_shape = default_opset::Constant::create(
|
||||
element::i64, Shape{size_t(data_rank_value)}, data_shape.to_shape());
|
||||
const auto target_shape = std::make_shared<default_opset::ShapeOf>(data);
|
||||
|
||||
// Create a default axes order matching the data tensor rank and erase the
|
||||
// element at the 'normalize_axis' position. The erased element indicates the
|
||||
|
50
ngraph/test/models/onnx/lp_norm_default_dynamic.prototxt
Normal file
50
ngraph/test/models/onnx/lp_norm_default_dynamic.prototxt
Normal file
@ -0,0 +1,50 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
input: "x"
|
||||
output: "y"
|
||||
op_type: "LpNormalization"
|
||||
}
|
||||
name: "lp_norm_graph"
|
||||
input {
|
||||
name: "x"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 4
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 1
|
||||
}
|
@ -2886,6 +2886,26 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_lp_norm_default)
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_lp_norm_default_dynamic)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/lp_norm_default_dynamic.prototxt"));
|
||||
|
||||
Shape data_shape{2, 3, 4};
|
||||
std::vector<float> data(shape_size(data_shape));
|
||||
std::iota(std::begin(data), std::end(data), 1);
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||
test_case.add_input<float>(data_shape, data);
|
||||
test_case.add_expected_output<float>(
|
||||
data_shape, {0.18257418f, 0.36514837f, 0.5477225f, 0.73029673f, 0.37904903f, 0.45485884f,
|
||||
0.5306686f, 0.60647845f, 0.42616236f, 0.47351375f, 0.5208651f, 0.5682165f,
|
||||
0.4469492f, 0.48132992f, 0.51571065f, 0.5500913f, 0.45862272f, 0.48560053f,
|
||||
0.5125783f, 0.53955615f, 0.46609157f, 0.4882864f, 0.51048124f, 0.5326761f});
|
||||
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_instance_normalization)
|
||||
{
|
||||
const auto function = onnx_import::import_onnx_model(
|
||||
|
@ -202,6 +202,7 @@ onnx_model_range_positive_step
|
||||
onnx_model_range_negative_step
|
||||
onnx_dyn_shapes_slice_1_3d_input_21_axes_ends_max
|
||||
onnx_model_max_pool_dyn_rank_without_default_attrs
|
||||
onnx_model_lp_norm_default_dynamic
|
||||
|
||||
# (Constant W, R inputs are required) Ticket: 49207
|
||||
# Function references undeclared parameters
|
||||
|
Loading…
Reference in New Issue
Block a user