OpShapeInfer (#8310)
* [ShapeInfer]impl assign/read_value * [ShapeInfer]impl tile * [ShapeInfer]impl LSTM cell * [ShapeInfer]Impl PriorGrid * [ShapeInfer]remove useless friend function * [ShapeInfer]apply review comments * [ShapeInfer]revise code * [ShapeInfer]impl copy_shape * [ShapeInfer]fix tile ci * fix onnx ci test * remove test_compatibility fail_issue_39658 * [ShapeInfer]fix reviews * [ShapeInfer]restore rnn_cell_base * [ShapeInfer]fix win build * [ShapeInfer]fix win type conversion * [ShapeInfer]fix merging * [ShapeInfer]move shape_infer to src/core * [ShapeInfer]apply review comments * [ShapeInfer]use shape_infer in tile evaluate * [ShapeInfer]fix tile ci * [ShapeInfer]enable shape_infer in mkldnn * [ShapeInfer]use shape_inference in tests * [ShapeInfer]remove useless in tile evaluate
This commit is contained in:
parent
82415f00d8
commit
ee4643d97e
@ -2,23 +2,30 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <openvino/core/node.hpp>
|
#include "shape_inference.hpp"
|
||||||
|
|
||||||
#include <ngraph/runtime/host_tensor.hpp>
|
#include <ngraph/runtime/host_tensor.hpp>
|
||||||
|
#include <openvino/core/node.hpp>
|
||||||
#include <openvino/opsets/opset1.hpp>
|
#include <openvino/opsets/opset1.hpp>
|
||||||
#include <openvino/opsets/opset2.hpp>
|
#include <openvino/opsets/opset2.hpp>
|
||||||
#include <openvino/opsets/opset4.hpp>
|
#include <openvino/opsets/opset4.hpp>
|
||||||
#include <openvino/opsets/opset5.hpp>
|
#include <openvino/opsets/opset5.hpp>
|
||||||
#include <openvino/opsets/opset6.hpp>
|
#include <openvino/opsets/opset6.hpp>
|
||||||
#include <openvino/opsets/opset8.hpp>
|
#include <openvino/opsets/opset8.hpp>
|
||||||
#include "static_shape.hpp"
|
|
||||||
#include "utils.hpp"
|
|
||||||
#include "shape_inference.hpp"
|
|
||||||
#include "convolution_shape_inference.hpp"
|
|
||||||
#include "reduce_shape_inference.hpp"
|
|
||||||
#include "shape_nodes.hpp"
|
|
||||||
#include "fake_quantize.hpp"
|
|
||||||
#include "experimental_detectron_detection_output_shape_inference.hpp"
|
|
||||||
|
|
||||||
|
#include "assign_shape_inference.hpp"
|
||||||
|
#include "convolution_shape_inference.hpp"
|
||||||
|
#include "experimental_detectron_detection_output_shape_inference.hpp"
|
||||||
|
#include "experimental_detectron_prior_grid_generator_shape_inference.hpp"
|
||||||
|
#include "fake_quantize.hpp"
|
||||||
|
#include "lstm_cell_shape_inference.hpp"
|
||||||
|
#include "read_value_shape_inference.hpp"
|
||||||
|
#include "reduce_shape_inference.hpp"
|
||||||
|
#include "shape_inference.hpp"
|
||||||
|
#include "shape_nodes.hpp"
|
||||||
|
#include "static_shape.hpp"
|
||||||
|
#include "tile_shape_inference.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
void shape_inference(ov::Node* op,
|
void shape_inference(ov::Node* op,
|
||||||
const std::vector<ov::StaticShape>& input_shapes,
|
const std::vector<ov::StaticShape>& input_shapes,
|
||||||
@ -27,44 +34,53 @@ void shape_inference(ov::Node* op,
|
|||||||
if (auto node = ov::as_type<ov::opset8::Convolution>(op)) {
|
if (auto node = ov::as_type<ov::opset8::Convolution>(op)) {
|
||||||
ov::CoordinateDiff pads_begin, pads_end;
|
ov::CoordinateDiff pads_begin, pads_end;
|
||||||
bool status = resolve_auto_pad_for_shape(node, pads_begin, pads_end, input_shapes, 2, 2);
|
bool status = resolve_auto_pad_for_shape(node, pads_begin, pads_end, input_shapes, 2, 2);
|
||||||
OPENVINO_ASSERT(status, "Convolution shape inference doesn't have enough information to calculate static shapes");
|
OPENVINO_ASSERT(status,
|
||||||
|
"Convolution shape inference doesn't have enough information to calculate static shapes");
|
||||||
shape_infer(node, pads_begin, pads_end, input_shapes, output_shapes);
|
shape_infer(node, pads_begin, pads_end, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::opset8::GroupConvolution>(op)) {
|
} else if (auto node = ov::as_type<ov::opset8::GroupConvolution>(op)) {
|
||||||
ov::CoordinateDiff pads_begin, pads_end;
|
ov::CoordinateDiff pads_begin, pads_end;
|
||||||
bool status = resolve_auto_pad_for_shape(node, pads_begin, pads_end, input_shapes, 2, 3);
|
bool status = resolve_auto_pad_for_shape(node, pads_begin, pads_end, input_shapes, 2, 3);
|
||||||
OPENVINO_ASSERT(status, "GroupConvolution shape inference doesn't have enough information to calculate static shapes");
|
OPENVINO_ASSERT(status,
|
||||||
|
"GroupConvolution shape inference doesn't have enough information to calculate static shapes");
|
||||||
shape_infer(node, pads_begin, pads_end, input_shapes, output_shapes);
|
shape_infer(node, pads_begin, pads_end, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::opset8::ConvolutionBackpropData>(op)) {
|
} else if (auto node = ov::as_type<ov::opset8::ConvolutionBackpropData>(op)) {
|
||||||
ov::CoordinateDiff pads_begin, pads_end;
|
ov::CoordinateDiff pads_begin, pads_end;
|
||||||
ov::StaticShape output_shape_input;
|
ov::StaticShape output_shape_input;
|
||||||
if (node->get_input_size() == 3)
|
if (node->get_input_size() == 3)
|
||||||
get_data_as_shape<ov::StaticShape>(2, op, output_shape_input, constant_data);
|
get_data_as_shape<ov::StaticShape>(2, op, output_shape_input, constant_data);
|
||||||
bool status = resolve_auto_pad_for_shape_back_prop(node, pads_begin, pads_end, input_shapes, output_shape_input, 2, 2);
|
bool status =
|
||||||
OPENVINO_ASSERT(status, "ConvolutionBackpropData shape inference doesn't have enough information to calculate static shapes");
|
resolve_auto_pad_for_shape_back_prop(node, pads_begin, pads_end, input_shapes, output_shape_input, 2, 2);
|
||||||
|
OPENVINO_ASSERT(
|
||||||
|
status,
|
||||||
|
"ConvolutionBackpropData shape inference doesn't have enough information to calculate static shapes");
|
||||||
shape_infer(node, pads_begin, pads_end, output_shape_input, input_shapes, output_shapes);
|
shape_infer(node, pads_begin, pads_end, output_shape_input, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::opset8::GroupConvolutionBackpropData>(op)) {
|
} else if (auto node = ov::as_type<ov::opset8::GroupConvolutionBackpropData>(op)) {
|
||||||
ov::CoordinateDiff pads_begin, pads_end;
|
ov::CoordinateDiff pads_begin, pads_end;
|
||||||
ov::StaticShape output_shape_input;
|
ov::StaticShape output_shape_input;
|
||||||
if (node->get_input_size() == 3)
|
if (node->get_input_size() == 3)
|
||||||
get_data_as_shape<ov::StaticShape>(2, op, output_shape_input, constant_data);
|
get_data_as_shape<ov::StaticShape>(2, op, output_shape_input, constant_data);
|
||||||
bool status = resolve_auto_pad_for_shape_back_prop(node, pads_begin, pads_end, input_shapes, output_shape_input, 2, 3);
|
bool status =
|
||||||
OPENVINO_ASSERT(status, "GroupConvolutionBackpropData shape inference doesn't have enough information to calculate static shapes");
|
resolve_auto_pad_for_shape_back_prop(node, pads_begin, pads_end, input_shapes, output_shape_input, 2, 3);
|
||||||
|
OPENVINO_ASSERT(
|
||||||
|
status,
|
||||||
|
"GroupConvolutionBackpropData shape inference doesn't have enough information to calculate static shapes");
|
||||||
shape_infer(node, pads_begin, pads_end, output_shape_input, input_shapes, output_shapes);
|
shape_infer(node, pads_begin, pads_end, output_shape_input, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::op::util::ArithmeticReductionKeepDims>(op)) {
|
} else if (auto node = ov::as_type<ov::op::util::ArithmeticReductionKeepDims>(op)) {
|
||||||
shape_infer(node, input_shapes, output_shapes, constant_data);
|
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||||
} else if (auto node = ov::as_type<ov::op::util::LogicalReductionKeepDims>(op)) {
|
} else if (auto node = ov::as_type<ov::op::util::LogicalReductionKeepDims>(op)) {
|
||||||
shape_infer(node, input_shapes, output_shapes, constant_data);
|
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||||
} else if (ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(op) ||
|
} else if (ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(op) || ov::is_type<ov::opset1::Convert>(op) ||
|
||||||
ov::is_type<ov::opset1::Convert>(op) || ov::is_type<ov::opset1::Clamp>(op) ||
|
ov::is_type<ov::opset1::Clamp>(op) || ov::is_type<ov::opset1::GRN>(op) ||
|
||||||
ov::is_type<ov::opset1::GRN>(op) || ov::is_type<ov::opset1::LRN>(op) ||
|
ov::is_type<ov::opset1::LRN>(op) || ov::is_type<ov::opset1::LogicalNot>(op) ||
|
||||||
ov::is_type<ov::opset1::LogicalNot>(op) || ov::is_type<ov::opset4::Mish>(op) ||
|
ov::is_type<ov::opset4::Mish>(op) || ov::is_type<ov::opset2::MVN>(op) ||
|
||||||
ov::is_type<ov::opset2::MVN>(op) || ov::is_type<ov::opset6::MVN>(op) ||
|
ov::is_type<ov::opset6::MVN>(op) || ov::is_type<ov::opset1::PRelu>(op) ||
|
||||||
ov::is_type<ov::opset1::PRelu>(op) || ov::is_type<ov::opset1::Relu>(op) ||
|
ov::is_type<ov::opset1::Relu>(op) || ov::is_type<ov::opset4::Swish>(op) ||
|
||||||
ov::is_type<ov::opset4::Swish>(op) || ov::is_type<ov::opset1::Softmax>(op) ||
|
ov::is_type<ov::opset1::Softmax>(op) || ov::is_type<ov::opset1::Elu>(op) ||
|
||||||
ov::is_type<ov::opset1::Elu>(op) || ov::is_type<ov::opset5::Round>(op)) {
|
ov::is_type<ov::opset5::Round>(op)) {
|
||||||
copy_shape_infer(node, input_shapes, output_shapes);
|
copy_shape_infer(node, input_shapes, output_shapes);
|
||||||
} else if (ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(op) ||
|
} else if (ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(op) ||
|
||||||
ov::is_type<ov::op::util::BinaryElementwiseComparison>(op) || ov::is_type<ov::op::util::BinaryElementwiseLogical>(op)) {
|
ov::is_type<ov::op::util::BinaryElementwiseComparison>(op) ||
|
||||||
|
ov::is_type<ov::op::util::BinaryElementwiseLogical>(op)) {
|
||||||
eltwise_shape_infer(op, input_shapes, output_shapes);
|
eltwise_shape_infer(op, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::opset1::FakeQuantize>(op)) {
|
} else if (auto node = ov::as_type<ov::opset1::FakeQuantize>(op)) {
|
||||||
shape_infer(node, input_shapes, output_shapes);
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
@ -80,15 +96,30 @@ void shape_inference(ov::Node* op,
|
|||||||
shape_infer(node, input_shapes, output_shapes);
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
} else if (auto node = ov::as_type<ov::opset6::ExperimentalDetectronDetectionOutput>(op)) {
|
} else if (auto node = ov::as_type<ov::opset6::ExperimentalDetectronDetectionOutput>(op)) {
|
||||||
shape_infer(node, input_shapes, output_shapes);
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset3::Assign>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset6::Assign>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset6::ExperimentalDetectronPriorGridGenerator>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset1::LSTMCell>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset6::LSTMCell>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset3::ReadValue>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset6::ReadValue>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes);
|
||||||
|
} else if (auto node = ov::as_type<ov::opset6::Tile>(op)) {
|
||||||
|
shape_infer(node, input_shapes, output_shapes, constant_data);
|
||||||
} else {
|
} else {
|
||||||
ngraph::OutputVector new_inputs;
|
ngraph::OutputVector new_inputs;
|
||||||
for (size_t i = 0; i < op->get_input_size(); ++i) {
|
for (size_t i = 0; i < op->get_input_size(); ++i) {
|
||||||
if (constant_data.count(i)) {
|
if (constant_data.count(i)) {
|
||||||
new_inputs.push_back(std::make_shared<ov::opset1::Constant>(constant_data.at(i)));
|
new_inputs.push_back(std::make_shared<ov::opset1::Constant>(constant_data.at(i)));
|
||||||
} else {
|
} else {
|
||||||
new_inputs.push_back(
|
new_inputs.push_back(std::make_shared<ov::opset1::Parameter>(op->get_input_element_type(i),
|
||||||
std::make_shared<ov::opset1::Parameter>(
|
input_shapes[i].to_partial_shape()));
|
||||||
op->get_input_element_type(i), input_shapes[i].to_partial_shape()));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const auto local_op = op->clone_with_new_inputs(new_inputs);
|
const auto local_op = op->clone_with_new_inputs(new_inputs);
|
||||||
@ -96,8 +127,10 @@ void shape_inference(ov::Node* op,
|
|||||||
|
|
||||||
output_shapes.resize(op->get_output_size());
|
output_shapes.resize(op->get_output_size());
|
||||||
for (size_t i = 0; i < output_shapes.size(); ++i) {
|
for (size_t i = 0; i < output_shapes.size(); ++i) {
|
||||||
const auto &partial_shape = local_op->get_output_partial_shape(i);
|
const auto& partial_shape = local_op->get_output_partial_shape(i);
|
||||||
OPENVINO_ASSERT(partial_shape.is_static(), "On device shape infer shouldn't support default shape infer for nodes with internal dynamism");
|
OPENVINO_ASSERT(
|
||||||
|
partial_shape.is_static(),
|
||||||
|
"On device shape infer shouldn't support default shape infer for nodes with internal dynamism");
|
||||||
output_shapes[i] = ov::StaticShape(partial_shape.to_shape());
|
output_shapes[i] = ov::StaticShape(partial_shape.to_shape());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,47 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <openvino/op/ops.hpp>
|
||||||
|
#include <openvino/op/parameter.hpp>
|
||||||
|
#include <utils/shape_inference/shape_inference.hpp>
|
||||||
|
#include <utils/shape_inference/static_shape.hpp>
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
template <class T>
|
||||||
|
std::shared_ptr<T> constructGraph();
|
||||||
|
|
||||||
|
template <>
|
||||||
|
std::shared_ptr<op::v3::Assign> constructGraph() {
|
||||||
|
auto input = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
auto read_value = std::make_shared<op::v3::ReadValue>(input, "variable_id");
|
||||||
|
return std::make_shared<op::v3::Assign>(read_value, "variable_id");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
std::shared_ptr<op::v6::Assign> constructGraph() {
|
||||||
|
auto input = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
auto variable = std::make_shared<ov::op::util::Variable>(
|
||||||
|
ov::op::util::VariableInfo{PartialShape::dynamic(), element::dynamic, "ID"});
|
||||||
|
auto read_value = std::make_shared<op::v6::Assign>(input, variable);
|
||||||
|
return std::make_shared<op::v6::Assign>(read_value, variable);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void assignTest() {
|
||||||
|
auto assign = constructGraph<T>();
|
||||||
|
|
||||||
|
// Test StaticShape
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 2, 64, 64}}, static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(assign.get(), static_input_shapes, static_output_shapes);
|
||||||
|
ASSERT_EQ(static_input_shapes[0], (StaticShape{1, 2, 64, 64}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, AssignTest) {
|
||||||
|
// Test v3 Assign
|
||||||
|
assignTest<op::v3::Assign>();
|
||||||
|
// Test v6 Assign
|
||||||
|
assignTest<op::v6::Assign>();
|
||||||
|
}
|
@ -0,0 +1,37 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <openvino/op/experimental_detectron_prior_grid_generator.hpp>
|
||||||
|
#include <openvino/op/ops.hpp>
|
||||||
|
#include <openvino/op/parameter.hpp>
|
||||||
|
#include <utils/shape_inference/shape_inference.hpp>
|
||||||
|
#include <utils/shape_inference/static_shape.hpp>
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, PriorGridGenerator) {
|
||||||
|
op::v6::ExperimentalDetectronPriorGridGenerator::Attributes attrs;
|
||||||
|
attrs.flatten = false;
|
||||||
|
attrs.h = 0;
|
||||||
|
attrs.w = 0;
|
||||||
|
attrs.stride_x = 4.0f;
|
||||||
|
attrs.stride_y = 4.0f;
|
||||||
|
|
||||||
|
auto priors = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto feature_map = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
auto im_data = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
|
||||||
|
auto grid_gen =
|
||||||
|
std::make_shared<ov::op::v6::ExperimentalDetectronPriorGridGenerator>(priors, feature_map, im_data, attrs);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4},
|
||||||
|
StaticShape{1, 256, 200, 336},
|
||||||
|
StaticShape{1, 3, 800, 1344}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(grid_gen.get(), static_input_shapes, static_output_shapes);
|
||||||
|
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({200, 336, 3, 4}));
|
||||||
|
}
|
@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <openvino/op/ops.hpp>
|
||||||
|
#include <openvino/op/parameter.hpp>
|
||||||
|
#include <utils/shape_inference/shape_inference.hpp>
|
||||||
|
#include <utils/shape_inference/static_shape.hpp>
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, LstmCellTest) {
|
||||||
|
const size_t batch_size = 2;
|
||||||
|
const size_t input_size = 3;
|
||||||
|
const size_t hidden_size = 3;
|
||||||
|
const size_t gates_count = 4;
|
||||||
|
|
||||||
|
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
const auto C_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
const auto Bias = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||||
|
const auto lstm_cell = std::make_shared<op::v4::LSTMCell>(X, H_t, C_t, W, R, Bias, hidden_size);
|
||||||
|
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{batch_size, input_size},
|
||||||
|
StaticShape{batch_size, hidden_size},
|
||||||
|
StaticShape{batch_size, hidden_size},
|
||||||
|
StaticShape{gates_count * hidden_size, input_size},
|
||||||
|
StaticShape{gates_count * hidden_size, hidden_size},
|
||||||
|
StaticShape{gates_count * hidden_size}},
|
||||||
|
static_output_shapes = {StaticShape{}, StaticShape{}};
|
||||||
|
shape_inference(lstm_cell.get(), static_input_shapes, static_output_shapes);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||||
|
ASSERT_EQ(static_output_shapes[1], StaticShape({batch_size, hidden_size}));
|
||||||
|
}
|
@ -0,0 +1,45 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <openvino/op/ops.hpp>
|
||||||
|
#include <openvino/op/parameter.hpp>
|
||||||
|
#include <utils/shape_inference/shape_inference.hpp>
|
||||||
|
#include <utils/shape_inference/static_shape.hpp>
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
std::shared_ptr<T> constructGraph();
|
||||||
|
|
||||||
|
template <>
|
||||||
|
std::shared_ptr<op::v3::ReadValue> constructGraph() {
|
||||||
|
auto input = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
return std::make_shared<op::v3::ReadValue>(input, "variable_id");
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
std::shared_ptr<op::v6::ReadValue> constructGraph() {
|
||||||
|
auto input = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||||
|
auto variable = std::make_shared<ov::op::util::Variable>(
|
||||||
|
ov::op::util::VariableInfo{PartialShape::dynamic(), element::dynamic, "ID"});
|
||||||
|
return std::make_shared<op::v6::ReadValue>(input, variable);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void readValueTest() {
|
||||||
|
auto readValue = constructGraph<T>();
|
||||||
|
|
||||||
|
// Test StaticShape
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 2, 64, 64}}, static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(readValue.get(), static_input_shapes, static_output_shapes);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], (StaticShape{1, 2, 64, 64}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, ReadValueTest) {
|
||||||
|
// Test v3 ReadValue
|
||||||
|
readValueTest<op::v3::ReadValue>();
|
||||||
|
// Test v6 ReadValue
|
||||||
|
readValueTest<op::v6::ReadValue>();
|
||||||
|
}
|
@ -0,0 +1,50 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <openvino/op/ops.hpp>
|
||||||
|
#include <openvino/op/parameter.hpp>
|
||||||
|
#include <utils/shape_inference/shape_inference.hpp>
|
||||||
|
#include <utils/shape_inference/static_shape.hpp>
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, TileTest) {
|
||||||
|
auto param0 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||||
|
auto param1 = std::make_shared<ov::op::v0::Constant>(element::i64, ov::Shape{3}, std::vector<int>{3, 4, 1});
|
||||||
|
auto tile = std::make_shared<op::v0::Tile>(param0, param1);
|
||||||
|
// Test Static Shape
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{6, 8, 10}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(tile.get(), static_input_shapes, static_output_shapes);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({18, 32, 10}));
|
||||||
|
// Test Wrong Static Shape
|
||||||
|
std::vector<StaticShape> wrong_static_input_shapes = {StaticShape{6, 8, 10}, StaticShape{}},
|
||||||
|
wrong_static_output_shapes = {StaticShape{}};
|
||||||
|
|
||||||
|
ASSERT_THROW(shape_inference(tile.get(), wrong_static_input_shapes, wrong_static_output_shapes), ov::AssertFailure);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, TileFewRepeatsTest) {
|
||||||
|
auto param0 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||||
|
auto param1 = ov::op::v0::Constant::create(element::i64, Shape{2}, {4, 1});
|
||||||
|
auto tile = std::make_shared<op::v0::Tile>(param0, param1);
|
||||||
|
// Test Static Shape
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{6, 8, 10}, StaticShape{2}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(tile.get(), static_input_shapes, static_output_shapes);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({6, 32, 10}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(StaticShapeInferenceTest, TileSmallDataRankTest) {
|
||||||
|
auto param0 = std::make_shared<ov::op::v0::Parameter>(element::f32, PartialShape{-1, -1});
|
||||||
|
auto param1 = ov::op::v0::Constant::create(element::i64, Shape{3}, {3, 4, 1});
|
||||||
|
auto tile = std::make_shared<op::v0::Tile>(param0, param1);
|
||||||
|
// Test Static Shape
|
||||||
|
std::vector<StaticShape> static_input_shapes = {StaticShape{8, 10}, StaticShape{3}},
|
||||||
|
static_output_shapes = {StaticShape{}};
|
||||||
|
shape_inference(tile.get(), static_input_shapes, static_output_shapes);
|
||||||
|
ASSERT_EQ(static_output_shapes[0], StaticShape({3, 32, 10}));
|
||||||
|
}
|
@ -34,6 +34,8 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::string m_variable_id;
|
std::string m_variable_id;
|
||||||
|
template <class T>
|
||||||
|
friend void shape_infer(const Assign* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes);
|
||||||
};
|
};
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
|
|
||||||
@ -70,6 +72,10 @@ public:
|
|||||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||||
bool has_evaluate() const override;
|
bool has_evaluate() const override;
|
||||||
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
|
bool constant_fold(OutputVector& output_values, const OutputVector& inputs_values) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <class T>
|
||||||
|
friend void shape_infer(const Assign* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes);
|
||||||
};
|
};
|
||||||
} // namespace v6
|
} // namespace v6
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -60,8 +60,10 @@ public:
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
Attributes m_attrs;
|
Attributes m_attrs;
|
||||||
|
template <class T>
|
||||||
void validate();
|
friend void shape_infer(const ExperimentalDetectronPriorGridGenerator* op,
|
||||||
|
const std::vector<T>& input_shapes,
|
||||||
|
std::vector<T>& output_shapes);
|
||||||
};
|
};
|
||||||
} // namespace v6
|
} // namespace v6
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -241,6 +241,8 @@ private:
|
|||||||
|
|
||||||
static constexpr std::size_t s_gates_count{4};
|
static constexpr std::size_t s_gates_count{4};
|
||||||
static constexpr std::size_t s_peepholes_count{3};
|
static constexpr std::size_t s_peepholes_count{3};
|
||||||
|
template <class T>
|
||||||
|
friend void shape_infer(const LSTMCell* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes);
|
||||||
};
|
};
|
||||||
} // namespace v0
|
} // namespace v0
|
||||||
|
|
||||||
@ -378,6 +380,8 @@ private:
|
|||||||
util::ActivationFunction m_activation_h;
|
util::ActivationFunction m_activation_h;
|
||||||
|
|
||||||
static constexpr std::size_t s_gates_count{4};
|
static constexpr std::size_t s_gates_count{4};
|
||||||
|
template <class T>
|
||||||
|
friend void shape_infer(const LSTMCell* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes);
|
||||||
};
|
};
|
||||||
} // namespace v4
|
} // namespace v4
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
41
src/core/shape_inference/include/assign_shape_inference.hpp
Normal file
41
src/core/shape_inference/include/assign_shape_inference.hpp
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
#include <openvino/core/graph_util.hpp>
|
||||||
|
#include <openvino/op/assign.hpp>
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace v3 {
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const Assign* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||||
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 1 && output_shapes.size() == 1);
|
||||||
|
const auto& input_shape = input_shapes[0];
|
||||||
|
const auto& variable_info = op->m_variable->get_info();
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
op->m_variable_id == variable_info.variable_id,
|
||||||
|
"Variables identifiers are inconsistent.");
|
||||||
|
const auto& arg_t = op->get_input_element_type(0);
|
||||||
|
NODE_VALIDATION_CHECK(op, arg_t == variable_info.data_type, "Variables types are inconsistent.");
|
||||||
|
|
||||||
|
if (input_shape.is_static() && variable_info.data_shape.is_static()) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
input_shape.to_shape() == variable_info.data_shape.to_shape(),
|
||||||
|
"Variables output shapes are inconsistent.");
|
||||||
|
}
|
||||||
|
copy_shape_infer(op, input_shapes, output_shapes);
|
||||||
|
}
|
||||||
|
} // namespace v3
|
||||||
|
|
||||||
|
namespace v6 {
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const Assign* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||||
|
copy_shape_infer(op, input_shapes, output_shapes);
|
||||||
|
}
|
||||||
|
} // namespace v6
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -0,0 +1,76 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
#include <openvino/op/experimental_detectron_prior_grid_generator.hpp>
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace v6 {
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const ExperimentalDetectronPriorGridGenerator* op,
|
||||||
|
const std::vector<T>& input_shapes,
|
||||||
|
std::vector<T>& output_shapes) {
|
||||||
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 3 && output_shapes.size() == 1);
|
||||||
|
const auto& priors_shape = input_shapes[0];
|
||||||
|
const auto& featmap_shape = input_shapes[1];
|
||||||
|
const auto& im_data_shape = input_shapes[2];
|
||||||
|
|
||||||
|
auto& output_shape = output_shapes[0];
|
||||||
|
size_t output_size = op->m_attrs.flatten ? 2 : 4;
|
||||||
|
|
||||||
|
output_shape.resize(output_size);
|
||||||
|
output_shape[output_size - 1] = 4;
|
||||||
|
|
||||||
|
bool prior_rank_static = priors_shape.rank().is_static();
|
||||||
|
bool featmap_rank_static = featmap_shape.rank().is_static();
|
||||||
|
bool im_data_rank_static = im_data_shape.rank().is_static();
|
||||||
|
|
||||||
|
if (prior_rank_static) {
|
||||||
|
NODE_VALIDATION_CHECK(op, priors_shape.size() == 2, "Priors rank must be equal to 2.");
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
priors_shape[1].compatible(4),
|
||||||
|
"The last dimension of the 'priors' input must be equal to 4. Got: ",
|
||||||
|
priors_shape[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (featmap_rank_static) {
|
||||||
|
NODE_VALIDATION_CHECK(op, featmap_shape.size() == 4, "Feature_map rank must be equal to 4.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (im_data_rank_static) {
|
||||||
|
NODE_VALIDATION_CHECK(op, im_data_shape.size() == 4, "Im_data rank must be equal to 4.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (featmap_rank_static && im_data_rank_static) {
|
||||||
|
const auto& num_batches_featmap = featmap_shape[0];
|
||||||
|
const auto& num_batches_im_data = im_data_shape[0];
|
||||||
|
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
num_batches_featmap.compatible(num_batches_im_data),
|
||||||
|
"The first dimension of both 'feature_map' and 'im_data' must match. "
|
||||||
|
"Feature_map: ",
|
||||||
|
num_batches_featmap,
|
||||||
|
"; Im_data: ",
|
||||||
|
num_batches_im_data);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (op->m_attrs.flatten) {
|
||||||
|
if (prior_rank_static && featmap_rank_static) {
|
||||||
|
output_shape[0] = featmap_shape[2] * featmap_shape[3] * priors_shape[0];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (featmap_rank_static) {
|
||||||
|
output_shape[0] = featmap_shape[2];
|
||||||
|
output_shape[1] = featmap_shape[3];
|
||||||
|
}
|
||||||
|
if (prior_rank_static) {
|
||||||
|
output_shape[2] = priors_shape[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace v6
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
191
src/core/shape_inference/include/lstm_cell_shape_inference.hpp
Normal file
191
src/core/shape_inference/include/lstm_cell_shape_inference.hpp
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
#include <openvino/op/lstm_cell.hpp>
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace ShapeInferLSTM {
|
||||||
|
template <class OpsType, class ShapeType>
|
||||||
|
void lstm_shape_infer(const OpsType* op,
|
||||||
|
const std::vector<ShapeType>& input_shapes,
|
||||||
|
std::vector<ShapeType>& output_shapes,
|
||||||
|
std::size_t gates_count) {
|
||||||
|
using DimType = typename std::iterator_traits<typename ShapeType::iterator>::value_type;
|
||||||
|
enum { X, initial_hidden_state, initial_cell_state, W, R, B };
|
||||||
|
std::vector<bool> input_rank_static(6, false);
|
||||||
|
bool all_rank_dynamic = false;
|
||||||
|
bool all_rank_static = true;
|
||||||
|
// Prepare OutShape
|
||||||
|
auto& hidden_shape = output_shapes[0];
|
||||||
|
auto& cell_shape = output_shapes[1];
|
||||||
|
hidden_shape.resize(2);
|
||||||
|
cell_shape.resize(2);
|
||||||
|
|
||||||
|
// If rank is dynamic, then output_shape is undefined
|
||||||
|
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||||
|
input_rank_static[i] = input_shapes[i].rank().is_static();
|
||||||
|
all_rank_dynamic &= !input_rank_static[i];
|
||||||
|
all_rank_static &= input_rank_static[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (all_rank_dynamic) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const auto& x_pshape = input_shapes[0];
|
||||||
|
const auto& w_pshape = input_shapes[3];
|
||||||
|
|
||||||
|
DimType output_batch_size;
|
||||||
|
DimType output_hidden_size;
|
||||||
|
bool is_batch_init = false;
|
||||||
|
bool is_hidden_init = false;
|
||||||
|
|
||||||
|
// deduce batch/hidden_size
|
||||||
|
for (size_t i = 0; i < input_shapes.size(); i++) {
|
||||||
|
const auto& input = input_shapes[i];
|
||||||
|
if (input_rank_static[i]) {
|
||||||
|
// batch could be deduced from x, cell_state or hidden_state
|
||||||
|
if (i == X || i == initial_cell_state || i == initial_hidden_state) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
(input.size() == 2),
|
||||||
|
"LSTMCell input rank is not correct for ",
|
||||||
|
i,
|
||||||
|
" input parameter. Current rank: ",
|
||||||
|
input.size(),
|
||||||
|
", expected: 2.");
|
||||||
|
if (!is_batch_init) {
|
||||||
|
output_batch_size = input[0];
|
||||||
|
is_batch_init = true;
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
op,
|
||||||
|
DimType::merge(output_batch_size, output_batch_size, input[0]),
|
||||||
|
"Parameter batch_size not matched for X, initial_hidden_state or initial_cell_state "
|
||||||
|
"inputs.");
|
||||||
|
}
|
||||||
|
if (i == initial_cell_state || i == initial_hidden_state) {
|
||||||
|
if (!is_hidden_init) {
|
||||||
|
output_hidden_size = input[1];
|
||||||
|
is_hidden_init = true;
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
DimType::merge(output_hidden_size, output_hidden_size, input[1]),
|
||||||
|
"Parameter hidden_size not matched for W, R, B, initial_hidden_state and "
|
||||||
|
"initial_cell_state "
|
||||||
|
"inputs.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (i == W || i == R || i == B) {
|
||||||
|
// check input dimension
|
||||||
|
if (i == B) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
(input.size() == 1),
|
||||||
|
"LSTMCell input tensor dimension is not correct for ",
|
||||||
|
i,
|
||||||
|
" input parameter. Current input length: ",
|
||||||
|
input.size(),
|
||||||
|
", expected: 1.");
|
||||||
|
if (input[0].is_static()) {
|
||||||
|
if (!is_hidden_init) {
|
||||||
|
output_hidden_size = input[0].get_length() / gates_count;
|
||||||
|
is_hidden_init = true;
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
op,
|
||||||
|
DimType::merge(output_hidden_size, output_hidden_size, input[0].get_length() / gates_count),
|
||||||
|
"Parameter hidden_size not matched for W, R, B, initial_hidden_state and "
|
||||||
|
"initial_cell_state "
|
||||||
|
"inputs.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
(input.size() == 2),
|
||||||
|
"LSTMCell input rank is not correct for ",
|
||||||
|
i,
|
||||||
|
" input parameter. Current rank: ",
|
||||||
|
input.size(),
|
||||||
|
", expected: 2.");
|
||||||
|
if (input[0].is_static()) {
|
||||||
|
if (!is_hidden_init) {
|
||||||
|
output_hidden_size = input[0].get_length() / gates_count;
|
||||||
|
is_hidden_init = true;
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(
|
||||||
|
op,
|
||||||
|
DimType::merge(output_hidden_size, output_hidden_size, input[0].get_length() / gates_count),
|
||||||
|
"Parameter hidden_size not matched for W, R, B, initial_hidden_state and "
|
||||||
|
"initial_cell_state "
|
||||||
|
"inputs.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (i == R) {
|
||||||
|
if (!is_hidden_init) {
|
||||||
|
output_hidden_size = input[1];
|
||||||
|
is_hidden_init = true;
|
||||||
|
} else {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
DimType::merge(output_hidden_size, output_hidden_size, input[1]),
|
||||||
|
"Parameter hidden_size not matched for W, R, B, initial_hidden_state "
|
||||||
|
"and initial_cell_state "
|
||||||
|
"inputs.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Check peepholes
|
||||||
|
if (input_shapes.size() == 7) {
|
||||||
|
const auto& p_pshape = input_shapes[6];
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
(p_pshape.rank().compatible(1)),
|
||||||
|
"LSTMCell input tensor P shall have dimension 1D.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// check input size
|
||||||
|
if (input_rank_static[X] && input_rank_static[W]) {
|
||||||
|
NODE_VALIDATION_CHECK(op, (x_pshape[1].compatible(w_pshape[1])), "LSTMCell mismatched input_size dimension.");
|
||||||
|
}
|
||||||
|
|
||||||
|
hidden_shape[0] = output_batch_size;
|
||||||
|
hidden_shape[1] = output_hidden_size;
|
||||||
|
cell_shape[0] = output_batch_size;
|
||||||
|
cell_shape[1] = output_hidden_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace ShapeInferLSTM
|
||||||
|
|
||||||
|
namespace v0 {
|
||||||
|
using ShapeInferLSTM::lstm_shape_infer;
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const LSTMCell* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||||
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 7 && output_shapes.size() == 2);
|
||||||
|
const auto& p_pshape = input_shapes[6];
|
||||||
|
|
||||||
|
lstm_shape_infer(op, input_shapes, output_shapes, op->s_gates_count);
|
||||||
|
const auto& hidden_size = output_shapes[0][1];
|
||||||
|
if (p_pshape[0].is_static() && hidden_size.is_static()) {
|
||||||
|
NODE_VALIDATION_CHECK(op,
|
||||||
|
p_pshape[0].compatible(hidden_size * op->s_peepholes_count),
|
||||||
|
"Parameter hidden_size mistmatched in P input. Current value is: ",
|
||||||
|
p_pshape[0].get_length(),
|
||||||
|
", expected: ",
|
||||||
|
hidden_size.get_length() * op->s_peepholes_count,
|
||||||
|
".");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace v0
|
||||||
|
|
||||||
|
namespace v4 {
|
||||||
|
using ShapeInferLSTM::lstm_shape_infer;
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const LSTMCell* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||||
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 6 && output_shapes.size() == 2);
|
||||||
|
lstm_shape_infer(op, input_shapes, output_shapes, op->s_gates_count);
|
||||||
|
}
|
||||||
|
} // namespace v4
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -0,0 +1,29 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
#include <openvino/op/read_value.hpp>
|
||||||
|
#include "utils.hpp"
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
|
||||||
|
template <class OpType, class ShapeType>
|
||||||
|
void read_value_shape_infer(const OpType* op, const std::vector<ShapeType>& input_shapes, std::vector<ShapeType>& output_shapes) {
|
||||||
|
copy_shape_infer(op, input_shapes, output_shapes);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace v3 {
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const ReadValue* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||||
|
read_value_shape_infer(op, input_shapes, output_shapes);
|
||||||
|
}
|
||||||
|
} // namespace v3
|
||||||
|
|
||||||
|
namespace v6 {
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const ReadValue* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||||
|
read_value_shape_infer(op, input_shapes, output_shapes);
|
||||||
|
}
|
||||||
|
} // namespace v6
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
52
src/core/shape_inference/include/tile_shape_inference.hpp
Normal file
52
src/core/shape_inference/include/tile_shape_inference.hpp
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
// Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
#pragma once
|
||||||
|
#include <openvino/op/tile.hpp>
|
||||||
|
|
||||||
|
#include "utils.hpp"
|
||||||
|
namespace ov {
|
||||||
|
namespace op {
|
||||||
|
namespace v0 {
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void shape_infer(const Tile* op,
|
||||||
|
const std::vector<T>& input_shapes,
|
||||||
|
std::vector<T>& output_shapes,
|
||||||
|
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {}) {
|
||||||
|
NODE_VALIDATION_CHECK(op, input_shapes.size() == 2 && output_shapes.size() == 1);
|
||||||
|
const auto& arg_shape = input_shapes[0];
|
||||||
|
auto& repeats_shape = input_shapes[1];
|
||||||
|
auto& output_shape = output_shapes[0];
|
||||||
|
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||||
|
std::vector<int64_t> axes_val;
|
||||||
|
NODE_VALIDATION_CHECK(op, repeats_shape.rank().compatible(1), "PartialShape of repeats must be of rank 1");
|
||||||
|
|
||||||
|
//Get repeats
|
||||||
|
bool axes_are_known = get_data_as_int64<T>(1, op, axes_val, constant_data);
|
||||||
|
const auto arg_rank = arg_shape.rank();
|
||||||
|
if (arg_rank.is_static() && (axes_are_known || repeats_shape[0].is_static())) {
|
||||||
|
//try to specify rank
|
||||||
|
int64_t data_rank = arg_shape.size();
|
||||||
|
int64_t repeats_rank = axes_are_known ? axes_val.size() : repeats_shape[0].get_length();
|
||||||
|
auto output_rank = std::max(data_rank, repeats_rank);
|
||||||
|
output_shape.resize(output_rank);
|
||||||
|
//if have constant axes, compute new axes
|
||||||
|
if (axes_are_known) {
|
||||||
|
auto remain_arg = output_rank - data_rank;
|
||||||
|
auto remain_axes = output_rank - repeats_rank;
|
||||||
|
for (size_t i = 0; i < output_rank; i++) {
|
||||||
|
auto data_tmp = i < remain_arg ? DimType(1) : arg_shape[i - (remain_arg)];
|
||||||
|
auto repeat_tmp =
|
||||||
|
i < remain_axes ? DimType(1) : axes_val[i - remain_axes];
|
||||||
|
output_shape[i] = data_tmp * repeat_tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
//can't deduce shape, set default value
|
||||||
|
output_shape = PartialShape::dynamic();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace v0
|
||||||
|
} // namespace op
|
||||||
|
} // namespace ov
|
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
#include "ngraph/op/assign.hpp"
|
#include "ngraph/op/assign.hpp"
|
||||||
|
|
||||||
|
#include <assign_shape_inference.hpp>
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/op/read_value.hpp"
|
#include "ngraph/op/read_value.hpp"
|
||||||
#include "ngraph/op/util/variable.hpp"
|
#include "ngraph/op/util/variable.hpp"
|
||||||
@ -26,7 +28,7 @@ void op::v3::Assign::validate_and_infer_types() {
|
|||||||
NGRAPH_OP_SCOPE(v3_Assign_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v3_Assign_validate_and_infer_types);
|
||||||
auto value = input_value(0);
|
auto value = input_value(0);
|
||||||
auto arg_t = get_input_element_type(0);
|
auto arg_t = get_input_element_type(0);
|
||||||
auto output_shape = get_input_partial_shape(0);
|
const auto& input_shape = get_input_partial_shape(0);
|
||||||
if (!m_variable) {
|
if (!m_variable) {
|
||||||
NodeVector start_nodes;
|
NodeVector start_nodes;
|
||||||
for (const auto& input : inputs()) {
|
for (const auto& input : inputs()) {
|
||||||
@ -41,20 +43,10 @@ void op::v3::Assign::validate_and_infer_types() {
|
|||||||
}
|
}
|
||||||
NODE_VALIDATION_CHECK(this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
|
NODE_VALIDATION_CHECK(this, m_variable != nullptr, "Can't find variable with id = ", m_variable_id);
|
||||||
}
|
}
|
||||||
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
auto variable_info = m_variable->get_info();
|
std::vector<ov::PartialShape> input_shapes = {input_shape};
|
||||||
NODE_VALIDATION_CHECK(this, m_variable_id == variable_info.variable_id, "Variables identifiers are inconsistent.");
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
NODE_VALIDATION_CHECK(this, arg_t == variable_info.data_type, "Variables types are inconsistent.");
|
set_output_type(0, arg_t, output_shapes[0]);
|
||||||
|
|
||||||
if (output_shape.is_static() && variable_info.data_shape.is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
output_shape == variable_info.data_shape,
|
|
||||||
"Variables output shapes are inconsistent.");
|
|
||||||
|
|
||||||
set_output_type(0, arg_t, output_shape);
|
|
||||||
} else {
|
|
||||||
set_output_type(0, arg_t, ov::PartialShape::dynamic());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const {
|
shared_ptr<Node> op::v3::Assign::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||||
@ -78,7 +70,10 @@ op::v6::Assign::Assign(const Output<Node>& new_value, const std::shared_ptr<Vari
|
|||||||
void op::v6::Assign::validate_and_infer_types() {
|
void op::v6::Assign::validate_and_infer_types() {
|
||||||
NGRAPH_OP_SCOPE(v6_Assign_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v6_Assign_validate_and_infer_types);
|
||||||
m_variable->update({get_input_partial_shape(0), get_input_element_type(0), m_variable->get_info().variable_id});
|
m_variable->update({get_input_partial_shape(0), get_input_element_type(0), m_variable->get_info().variable_id});
|
||||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
|
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0)};
|
||||||
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
|
set_output_type(0, get_input_element_type(0), output_shapes[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
shared_ptr<Node> op::v6::Assign::clone_with_new_inputs(const OutputVector& new_args) const {
|
shared_ptr<Node> op::v6::Assign::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
#include "ngraph/op/experimental_detectron_prior_grid_generator.hpp"
|
#include "ngraph/op/experimental_detectron_prior_grid_generator.hpp"
|
||||||
|
|
||||||
|
#include <experimental_detectron_prior_grid_generator_shape_inference.hpp>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
@ -49,71 +50,15 @@ static constexpr size_t priors_port = 0;
|
|||||||
static constexpr size_t featmap_port = 1;
|
static constexpr size_t featmap_port = 1;
|
||||||
static constexpr size_t im_data_port = 2;
|
static constexpr size_t im_data_port = 2;
|
||||||
|
|
||||||
void op::v6::ExperimentalDetectronPriorGridGenerator::validate() {
|
|
||||||
auto priors_shape = get_input_partial_shape(priors_port);
|
|
||||||
auto featmap_shape = get_input_partial_shape(featmap_port);
|
|
||||||
auto im_data_shape = get_input_partial_shape(im_data_port);
|
|
||||||
|
|
||||||
if (priors_shape.rank().is_dynamic() || featmap_shape.rank().is_dynamic()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this, priors_shape.rank().get_length() == 2, "Priors rank must be equal to 2.");
|
|
||||||
|
|
||||||
if (priors_shape[1].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
priors_shape[1].is_static() && priors_shape[1].get_length() == 4u,
|
|
||||||
"The last dimension of the 'priors' input must be equal to 4. Got: ",
|
|
||||||
priors_shape[1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this, featmap_shape.rank().get_length() == 4, "Feature_map rank must be equal to 4.");
|
|
||||||
|
|
||||||
if (im_data_shape.rank().is_dynamic()) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this, im_data_shape.rank().get_length() == 4, "Im_data rank must be equal to 4.");
|
|
||||||
|
|
||||||
const auto num_batches_featmap = featmap_shape[0];
|
|
||||||
const auto num_batches_im_data = im_data_shape[0];
|
|
||||||
const auto batches_intersection = num_batches_featmap & num_batches_im_data;
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
!batches_intersection.get_interval().empty(),
|
|
||||||
"The first dimension of both 'feature_map' and 'im_data' must match. "
|
|
||||||
"Feature_map: ",
|
|
||||||
num_batches_featmap,
|
|
||||||
"; Im_data: ",
|
|
||||||
num_batches_im_data);
|
|
||||||
}
|
|
||||||
|
|
||||||
void op::v6::ExperimentalDetectronPriorGridGenerator::validate_and_infer_types() {
|
void op::v6::ExperimentalDetectronPriorGridGenerator::validate_and_infer_types() {
|
||||||
NGRAPH_OP_SCOPE(v6_ExperimentalDetectronPriorGridGenerator_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v6_ExperimentalDetectronPriorGridGenerator_validate_and_infer_types);
|
||||||
auto priors_shape = get_input_partial_shape(priors_port);
|
const auto& priors_shape = get_input_partial_shape(priors_port);
|
||||||
auto featmap_shape = get_input_partial_shape(featmap_port);
|
const auto& featmap_shape = get_input_partial_shape(featmap_port);
|
||||||
auto input_et = get_input_element_type(0);
|
const auto& input_et = get_input_element_type(0);
|
||||||
|
|
||||||
validate();
|
|
||||||
|
|
||||||
set_output_size(1);
|
set_output_size(1);
|
||||||
ov::PartialShape out_shape = {Dimension::dynamic(), Dimension::dynamic(), Dimension::dynamic(), 4};
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
if (m_attrs.flatten) {
|
std::vector<ov::PartialShape> input_shapes = {priors_shape, featmap_shape, get_input_partial_shape(im_data_port)};
|
||||||
out_shape = ov::PartialShape{Dimension::dynamic(), 4};
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
}
|
set_output_type(0, input_et, output_shapes[0]);
|
||||||
|
|
||||||
if (priors_shape.rank().is_dynamic() || featmap_shape.rank().is_dynamic()) {
|
|
||||||
set_output_type(0, input_et, out_shape);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto num_priors = priors_shape[0];
|
|
||||||
auto featmap_height = featmap_shape[2];
|
|
||||||
auto featmap_width = featmap_shape[3];
|
|
||||||
|
|
||||||
if (m_attrs.flatten) {
|
|
||||||
out_shape = ov::PartialShape{featmap_height * featmap_width * num_priors, 4};
|
|
||||||
} else {
|
|
||||||
out_shape = ov::PartialShape{featmap_height, featmap_width, num_priors, 4};
|
|
||||||
}
|
|
||||||
set_output_type(0, input_et, out_shape);
|
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <lstm_cell_shape_inference.hpp>
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/attribute_visitor.hpp"
|
#include "ngraph/attribute_visitor.hpp"
|
||||||
@ -139,30 +140,7 @@ void op::v0::LSTMCell::validate_and_infer_types() {
|
|||||||
set_argument(6, get_default_peepholes_input());
|
set_argument(6, get_default_peepholes_input());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto& input : inputs()) {
|
|
||||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
|
||||||
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic());
|
|
||||||
set_output_type(1, get_input_element_type(0), ov::PartialShape::dynamic());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<ov::PartialShape> input_param{};
|
|
||||||
|
|
||||||
auto merged_batch_size = Dimension::dynamic();
|
|
||||||
auto merged_hidden_size = Dimension::dynamic();
|
|
||||||
auto result_et = element::dynamic;
|
auto result_et = element::dynamic;
|
||||||
|
|
||||||
// Copy all inputs without peephole (7th input) and initial_cell_state (2nd input)
|
|
||||||
// information
|
|
||||||
// for further validation
|
|
||||||
for (size_t i = 0; i < get_input_size() - 1; i++) {
|
|
||||||
// exclude initial_cell_state input
|
|
||||||
if (i != 2) {
|
|
||||||
input_param.push_back(get_input_partial_shape(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get input partial shape for all inputs
|
// Get input partial shape for all inputs
|
||||||
const auto& x_pshape = get_input_partial_shape(0);
|
const auto& x_pshape = get_input_partial_shape(0);
|
||||||
const auto& ht_pshape = get_input_partial_shape(1);
|
const auto& ht_pshape = get_input_partial_shape(1);
|
||||||
@ -172,24 +150,6 @@ void op::v0::LSTMCell::validate_and_infer_types() {
|
|||||||
const auto& b_pshape = get_input_partial_shape(5);
|
const auto& b_pshape = get_input_partial_shape(5);
|
||||||
const auto& p_pshape = get_input_partial_shape(6);
|
const auto& p_pshape = get_input_partial_shape(6);
|
||||||
|
|
||||||
validate_input_rank_dimension(input_param);
|
|
||||||
|
|
||||||
// Validate rank and dimension for initial_cell_state input
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
(ct_pshape.rank().is_static()),
|
|
||||||
"LSTMCell input tensor initial_cell_state shall have static rank.");
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
(ct_pshape.rank().get_length() == 2),
|
|
||||||
"LSTMCell input tensor initial_cell_state shall have dimension 2D.");
|
|
||||||
|
|
||||||
// Validate rank and dimension for P input
|
|
||||||
NODE_VALIDATION_CHECK(this, (p_pshape.rank().is_static()), "LSTMCell input tensor P shall have static rank.");
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
(p_pshape.rank().get_length() == 1),
|
|
||||||
"LSTMCell input tensor P shall have dimension 1D.");
|
|
||||||
|
|
||||||
// Validate input element types and save result for output type
|
// Validate input element types and save result for output type
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
||||||
@ -201,65 +161,10 @@ void op::v0::LSTMCell::validate_and_infer_types() {
|
|||||||
"Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not "
|
"Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not "
|
||||||
"match.");
|
"match.");
|
||||||
|
|
||||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}};
|
||||||
NODE_VALIDATION_CHECK(this,
|
std::vector<ov::PartialShape> input_shapes =
|
||||||
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
|
{x_pshape, ht_pshape, ct_pshape, w_pshape, r_pshape, b_pshape, p_pshape};
|
||||||
Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) &&
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
|
|
||||||
"Parameter batch_size not matched for X, initial_hidden_state or initial_cell_state "
|
|
||||||
"inputs.");
|
|
||||||
|
|
||||||
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
|
|
||||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[1]) &&
|
|
||||||
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
|
|
||||||
"Parameter hidden_size not matched for R, initial_hidden_state and initial_cell_state "
|
|
||||||
"inputs.");
|
|
||||||
|
|
||||||
// Validate hidden_size value for W, R and P inputs
|
|
||||||
if (merged_hidden_size.is_static()) {
|
|
||||||
if (w_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
|
||||||
"Parameter hidden_size mistmatched in W input. Current value is: ",
|
|
||||||
w_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_gates_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (r_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
|
||||||
"Parameter hidden_size mistmatched in R input. Current value is: ",
|
|
||||||
r_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_gates_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (b_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
b_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
|
||||||
"Parameter hidden_size mistmatched in B input. Current value is: ",
|
|
||||||
b_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_gates_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (p_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
p_pshape[0].compatible(merged_hidden_size * s_peepholes_count),
|
|
||||||
"Parameter hidden_size mistmatched in P input. Current value is: ",
|
|
||||||
p_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_peepholes_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark inputs which are relevant to output parameters
|
// Mark inputs which are relevant to output parameters
|
||||||
set_input_is_relevant_to_shape(0);
|
set_input_is_relevant_to_shape(0);
|
||||||
set_input_is_relevant_to_shape(1);
|
set_input_is_relevant_to_shape(1);
|
||||||
@ -268,8 +173,8 @@ void op::v0::LSTMCell::validate_and_infer_types() {
|
|||||||
|
|
||||||
// Set output size, type and shape
|
// Set output size, type and shape
|
||||||
set_output_size(2);
|
set_output_size(2);
|
||||||
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
|
set_output_type(0, result_et, output_shapes[0]);
|
||||||
set_output_type(1, result_et, {merged_batch_size, merged_hidden_size});
|
set_output_type(1, result_et, output_shapes[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> op::v0::LSTMCell::get_default_bias_input() const {
|
Output<Node> op::v0::LSTMCell::get_default_bias_input() const {
|
||||||
@ -414,15 +319,7 @@ bool ngraph::op::v4::LSTMCell::visit_attributes(AttributeVisitor& visitor) {
|
|||||||
|
|
||||||
void op::v4::LSTMCell::validate_and_infer_types() {
|
void op::v4::LSTMCell::validate_and_infer_types() {
|
||||||
NGRAPH_OP_SCOPE(v4_LSTMCell_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v4_LSTMCell_validate_and_infer_types);
|
||||||
for (const auto& input : inputs()) {
|
|
||||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
|
||||||
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic());
|
|
||||||
set_output_type(1, get_input_element_type(0), ov::PartialShape::dynamic());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
auto merged_batch_size = Dimension::dynamic();
|
|
||||||
auto merged_hidden_size = Dimension::dynamic();
|
|
||||||
auto result_et = element::dynamic;
|
auto result_et = element::dynamic;
|
||||||
|
|
||||||
// Get input partial shape for all inputs
|
// Get input partial shape for all inputs
|
||||||
@ -433,12 +330,6 @@ void op::v4::LSTMCell::validate_and_infer_types() {
|
|||||||
const auto& r_pshape = get_input_partial_shape(4);
|
const auto& r_pshape = get_input_partial_shape(4);
|
||||||
const auto& b_pshape = get_input_partial_shape(5);
|
const auto& b_pshape = get_input_partial_shape(5);
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
(ct_pshape.rank().get_length() == 2),
|
|
||||||
"LSTMCell input tensor initial_cell_state shall have dimension 2D.");
|
|
||||||
|
|
||||||
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
|
|
||||||
|
|
||||||
// Validate input element types and save result for output type
|
// Validate input element types and save result for output type
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
||||||
@ -450,54 +341,9 @@ void op::v4::LSTMCell::validate_and_infer_types() {
|
|||||||
"Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not "
|
"Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not "
|
||||||
"match.");
|
"match.");
|
||||||
|
|
||||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}, ov::PartialShape{}};
|
||||||
NODE_VALIDATION_CHECK(this,
|
std::vector<ov::PartialShape> input_shapes = {x_pshape, ht_pshape, ct_pshape, w_pshape, r_pshape, b_pshape};
|
||||||
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) &&
|
|
||||||
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
|
|
||||||
"Parameter batch_size not matched for X, initial_hidden_state or initial_cell_state "
|
|
||||||
"inputs.");
|
|
||||||
|
|
||||||
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
|
|
||||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[1]) &&
|
|
||||||
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
|
|
||||||
"Parameter hidden_size not matched for R, initial_hidden_state and initial_cell_state "
|
|
||||||
"inputs.");
|
|
||||||
|
|
||||||
// Validate hidden_size value for W, R and P inputs
|
|
||||||
if (merged_hidden_size.is_static()) {
|
|
||||||
if (w_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
|
||||||
"Parameter hidden_size mistmatched in W input. Current value is: ",
|
|
||||||
w_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_gates_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (r_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
|
||||||
"Parameter hidden_size mistmatched in R input. Current value is: ",
|
|
||||||
r_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_gates_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (b_pshape[0].is_static()) {
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
b_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
|
||||||
"Parameter hidden_size mistmatched in B input. Current value is: ",
|
|
||||||
b_pshape[0].get_length(),
|
|
||||||
", expected: ",
|
|
||||||
merged_hidden_size.get_length() * s_gates_count,
|
|
||||||
".");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mark inputs which are relevant to output parameters
|
// Mark inputs which are relevant to output parameters
|
||||||
set_input_is_relevant_to_shape(0);
|
set_input_is_relevant_to_shape(0);
|
||||||
@ -507,8 +353,8 @@ void op::v4::LSTMCell::validate_and_infer_types() {
|
|||||||
|
|
||||||
// Set output size, type and shape
|
// Set output size, type and shape
|
||||||
set_output_size(2);
|
set_output_size(2);
|
||||||
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
|
set_output_type(0, result_et, output_shapes[0]);
|
||||||
set_output_type(1, result_et, {merged_batch_size, merged_hidden_size});
|
set_output_type(1, result_et, output_shapes[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
Output<Node> op::v4::LSTMCell::get_default_bias_input() const {
|
Output<Node> op::v4::LSTMCell::get_default_bias_input() const {
|
||||||
|
@ -4,6 +4,8 @@
|
|||||||
|
|
||||||
#include "ngraph/op/read_value.hpp"
|
#include "ngraph/op/read_value.hpp"
|
||||||
|
|
||||||
|
#include <read_value_shape_inference.hpp>
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/op/util/variable_context.hpp"
|
#include "ngraph/op/util/variable_context.hpp"
|
||||||
#include "ngraph/ops.hpp"
|
#include "ngraph/ops.hpp"
|
||||||
@ -23,8 +25,13 @@ op::v3::ReadValue::ReadValue(const Output<Node>& init_value, const std::string&
|
|||||||
void op::v3::ReadValue::validate_and_infer_types() {
|
void op::v3::ReadValue::validate_and_infer_types() {
|
||||||
NGRAPH_OP_SCOPE(v3_ReadValue_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v3_ReadValue_validate_and_infer_types);
|
||||||
auto arg_t = get_input_element_type(0);
|
auto arg_t = get_input_element_type(0);
|
||||||
auto output_shape = get_input_partial_shape(0);
|
auto input_shape = get_input_partial_shape(0);
|
||||||
|
|
||||||
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
|
std::vector<ov::PartialShape> input_shapes = {input_shape};
|
||||||
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
|
|
||||||
|
const auto& output_shape = output_shapes[0];
|
||||||
VariableInfo info = {output_shape, arg_t, m_variable_id};
|
VariableInfo info = {output_shape, arg_t, m_variable_id};
|
||||||
if (m_variable == nullptr)
|
if (m_variable == nullptr)
|
||||||
m_variable = std::make_shared<Variable>(info);
|
m_variable = std::make_shared<Variable>(info);
|
||||||
@ -54,7 +61,11 @@ op::v6::ReadValue::ReadValue(const Output<Node>& init_value, const shared_ptr<Va
|
|||||||
void op::v6::ReadValue::validate_and_infer_types() {
|
void op::v6::ReadValue::validate_and_infer_types() {
|
||||||
NGRAPH_OP_SCOPE(v6_ReadValue_validate_and_infer_types);
|
NGRAPH_OP_SCOPE(v6_ReadValue_validate_and_infer_types);
|
||||||
const auto arg_t = get_input_element_type(0);
|
const auto arg_t = get_input_element_type(0);
|
||||||
auto output_shape = get_input_partial_shape(0);
|
auto input_shape = get_input_partial_shape(0);
|
||||||
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
|
std::vector<ov::PartialShape> input_shapes = {input_shape};
|
||||||
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
|
const auto& output_shape = output_shapes[0];
|
||||||
NGRAPH_CHECK(m_variable, "Variable is not initialized.");
|
NGRAPH_CHECK(m_variable, "Variable is not initialized.");
|
||||||
VariableInfo var_info = {output_shape, element::dynamic, m_variable->get_info().variable_id};
|
VariableInfo var_info = {output_shape, element::dynamic, m_variable->get_info().variable_id};
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
|
@ -5,6 +5,7 @@
|
|||||||
#include "ngraph/op/tile.hpp"
|
#include "ngraph/op/tile.hpp"
|
||||||
|
|
||||||
#include <ngraph/validation_util.hpp>
|
#include <ngraph/validation_util.hpp>
|
||||||
|
#include <tile_shape_inference.hpp>
|
||||||
|
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
#include "ngraph/op/constant.hpp"
|
#include "ngraph/op/constant.hpp"
|
||||||
@ -37,37 +38,10 @@ void op::v0::Tile::validate_and_infer_types() {
|
|||||||
"Tile repeats must have any integer element type, but has ",
|
"Tile repeats must have any integer element type, but has ",
|
||||||
repeats_et);
|
repeats_et);
|
||||||
|
|
||||||
auto arg_shape = get_input_partial_shape(0);
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
auto repeats_shape = get_input_partial_shape(1);
|
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0), get_input_partial_shape(1)};
|
||||||
NODE_VALIDATION_CHECK(this, repeats_shape.rank().compatible(1), "PartialShape of repeats must be of rank 1");
|
shape_infer(this, input_shapes, output_shapes);
|
||||||
ov::PartialShape repeats_as_pshape;
|
set_output_type(0, arg_et, output_shapes[0]);
|
||||||
bool repeats_are_known = evaluate_as_partial_shape(get_input_source_output(1), repeats_as_pshape);
|
|
||||||
std::vector<Dimension> repeats_value(repeats_as_pshape);
|
|
||||||
if (repeats_are_known && !repeats_value.empty() && arg_shape.rank().is_static()) {
|
|
||||||
std::vector<Dimension> data_shape(arg_shape);
|
|
||||||
auto data_rank = data_shape.size();
|
|
||||||
auto repeats_rank = repeats_value.size();
|
|
||||||
auto output_rank = std::max(data_rank, repeats_rank);
|
|
||||||
|
|
||||||
// expand data shape and repeats to output rank
|
|
||||||
data_shape.insert(data_shape.begin(), output_rank - data_rank, 1);
|
|
||||||
repeats_value.insert(repeats_value.begin(), output_rank - repeats_rank, 1);
|
|
||||||
|
|
||||||
auto output_shape = ov::PartialShape::dynamic(output_rank);
|
|
||||||
for (size_t i = 0; i < output_rank; i++)
|
|
||||||
output_shape[i] = data_shape[i] * repeats_value[i];
|
|
||||||
set_output_type(0, arg_et, output_shape);
|
|
||||||
} else {
|
|
||||||
Rank outRank = Rank::dynamic();
|
|
||||||
if (arg_shape.rank().is_static() && repeats_shape.is_static()) {
|
|
||||||
std::vector<Dimension> data_shape(arg_shape);
|
|
||||||
auto data_rank = data_shape.size();
|
|
||||||
auto repeats_rank = repeats_value.size();
|
|
||||||
auto output_rank = std::max(data_rank, repeats_rank);
|
|
||||||
outRank = Rank(output_rank);
|
|
||||||
}
|
|
||||||
set_output_type(0, arg_et, ov::PartialShape::dynamic(outRank));
|
|
||||||
}
|
|
||||||
|
|
||||||
set_input_is_relevant_to_shape(0);
|
set_input_is_relevant_to_shape(0);
|
||||||
set_input_is_relevant_to_shape(1);
|
set_input_is_relevant_to_shape(1);
|
||||||
@ -84,24 +58,16 @@ bool op::v0::Tile::evaluate_tile(const HostTensorVector& outputs, const HostTens
|
|||||||
const auto& axis = inputs[1];
|
const auto& axis = inputs[1];
|
||||||
auto& output = outputs[0];
|
auto& output = outputs[0];
|
||||||
auto repeats_val = read_index_vector(axis);
|
auto repeats_val = read_index_vector(axis);
|
||||||
auto repeats_rank = repeats_val.size();
|
const auto repeats_rank = repeats_val.size();
|
||||||
ov::Shape data_shape = data->get_shape();
|
|
||||||
auto data_rank = data_shape.size();
|
|
||||||
auto output_rank = std::max(data_rank, repeats_rank);
|
|
||||||
|
|
||||||
// expand data shape and repeats to output rank
|
|
||||||
data_shape.insert(data_shape.begin(), output_rank - data_rank, 1);
|
|
||||||
repeats_val.insert(repeats_val.begin(), output_rank - repeats_rank, 1);
|
|
||||||
|
|
||||||
ov::Shape output_shape(output_rank);
|
|
||||||
for (size_t i = 0; i < output_rank; i++) {
|
|
||||||
output_shape[i] = data_shape[i] * repeats_val[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
|
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{}};
|
||||||
|
std::vector<ov::PartialShape> input_shapes = {data->get_shape(), axis->get_shape()};
|
||||||
|
shape_infer(this, input_shapes, output_shapes, {{1, axis}});
|
||||||
|
const auto& output_shape = output_shapes[0].to_shape();
|
||||||
if (!output->get_is_allocated()) {
|
if (!output->get_is_allocated()) {
|
||||||
output->set_shape(output_shape);
|
output->set_shape(output_shape);
|
||||||
}
|
}
|
||||||
|
repeats_val.insert(repeats_val.begin(), output_shape.size() - repeats_rank, 1);
|
||||||
ngraph::runtime::reference::tile(data->get_data_ptr<const char>(),
|
ngraph::runtime::reference::tile(data->get_data_ptr<const char>(),
|
||||||
output->get_data_ptr<char>(),
|
output->get_data_ptr<char>(),
|
||||||
data->get_shape(),
|
data->get_shape(),
|
||||||
|
@ -53,7 +53,9 @@ TEST(type_prop, lstm_cell_invalid_input) {
|
|||||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||||
FAIL() << "LSTMCell node was created with invalid data.";
|
FAIL() << "LSTMCell node was created with invalid data.";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input."));
|
EXPECT_HAS_SUBSTRING(
|
||||||
|
error.what(),
|
||||||
|
std::string("Parameter hidden_size not matched for W, R, B, initial_hidden_state and initial_cell_state"));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invalid R tensor shape.
|
// Invalid R tensor shape.
|
||||||
@ -64,7 +66,7 @@ TEST(type_prop, lstm_cell_invalid_input) {
|
|||||||
FAIL() << "LSTMCell node was created with invalid data.";
|
FAIL() << "LSTMCell node was created with invalid data.";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
EXPECT_HAS_SUBSTRING(error.what(),
|
EXPECT_HAS_SUBSTRING(error.what(),
|
||||||
std::string("Parameter hidden_size not matched for R, "
|
std::string("Parameter hidden_size not matched for W, R, B, "
|
||||||
"initial_hidden_state and initial_cell_state inputs."));
|
"initial_hidden_state and initial_cell_state inputs."));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,7 +102,7 @@ TEST(type_prop, lstm_cell_invalid_input) {
|
|||||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
|
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
|
||||||
FAIL() << "LSTMCell node was created with invalid data.";
|
FAIL() << "LSTMCell node was created with invalid data.";
|
||||||
} catch (const NodeValidationFailure& error) {
|
} catch (const NodeValidationFailure& error) {
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in B input."));
|
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size not matched for W, R, B"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -138,8 +140,8 @@ TEST(type_prop, lstm_cell_dynamic_hidden_size) {
|
|||||||
|
|
||||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
|
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
|
||||||
|
|
||||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
|
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, 3}));
|
||||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
|
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, 3}));
|
||||||
EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32);
|
EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32);
|
||||||
EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32);
|
EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32);
|
||||||
}
|
}
|
||||||
@ -158,8 +160,8 @@ TEST(type_prop, lstm_cell_dynamic_inputs) {
|
|||||||
|
|
||||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
|
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
|
||||||
|
|
||||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
|
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, 3}));
|
||||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
|
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, 3}));
|
||||||
EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32);
|
EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32);
|
||||||
EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32);
|
EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32);
|
||||||
}
|
}
|
||||||
@ -224,9 +226,11 @@ TEST(type_prop, lstm_cell_invalid_input_dynamic_rank) {
|
|||||||
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||||
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||||
|
|
||||||
auto check_dynamic_lstm = [](const shared_ptr<opset4::LSTMCell>& lstm) -> bool {
|
auto check_dynamic_lstm = [=](const shared_ptr<opset4::LSTMCell>& lstm) -> bool {
|
||||||
return lstm->output(0).get_partial_shape() == PartialShape::dynamic() &&
|
const int64_t target_batch_size = batch_size;
|
||||||
lstm->output(1).get_partial_shape() == PartialShape::dynamic() &&
|
const int64_t target_hidden_size = hidden_size;
|
||||||
|
return lstm->output(0).get_partial_shape() == PartialShape{target_batch_size, target_hidden_size} &&
|
||||||
|
lstm->output(1).get_partial_shape() == PartialShape{target_batch_size, target_hidden_size} &&
|
||||||
lstm->output(0).get_element_type() == lstm->input(0).get_element_type();
|
lstm->output(0).get_element_type() == lstm->input(0).get_element_type();
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -265,3 +269,61 @@ TEST(type_prop, lstm_cell_invalid_input_dynamic_rank) {
|
|||||||
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
|
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
|
||||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, lstm_cell_shape_from_partial) {
|
||||||
|
const size_t batch_size = 2;
|
||||||
|
const size_t input_size = 3;
|
||||||
|
const size_t hidden_size = 3;
|
||||||
|
const size_t gates_count = 4;
|
||||||
|
|
||||||
|
auto check_dynamic_lstm = [=](const shared_ptr<opset4::LSTMCell>& lstm) -> bool {
|
||||||
|
const int64_t target_batch_size = batch_size;
|
||||||
|
const int64_t target_hidden_size = hidden_size;
|
||||||
|
return lstm->output(0).get_partial_shape() == PartialShape{target_batch_size, target_hidden_size} &&
|
||||||
|
lstm->output(1).get_partial_shape() == PartialShape{target_batch_size, target_hidden_size} &&
|
||||||
|
lstm->output(0).get_element_type() == lstm->input(0).get_element_type();
|
||||||
|
};
|
||||||
|
{
|
||||||
|
// from h & w
|
||||||
|
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
|
||||||
|
auto R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, -1});
|
||||||
|
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||||
|
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// from x & w
|
||||||
|
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||||
|
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
|
||||||
|
auto R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||||
|
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// only valid rank for H_t tensor.
|
||||||
|
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||||
|
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
|
||||||
|
auto R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||||
|
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||||
|
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// batch from x, hidden from h_t
|
||||||
|
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||||
|
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{-1, hidden_size});
|
||||||
|
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||||
|
auto lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||||
|
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -40,3 +40,11 @@ TEST(type_prop, tile_few_repeats_dyn_input) {
|
|||||||
ASSERT_EQ(top->get_element_type(), element::f32);
|
ASSERT_EQ(top->get_element_type(), element::f32);
|
||||||
ASSERT_EQ(top->get_output_partial_shape(0), (PartialShape{6, Dimension(32, 40), 10}));
|
ASSERT_EQ(top->get_output_partial_shape(0), (PartialShape{6, Dimension(32, 40), 10}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, tile_out_rank_from_repeats) {
|
||||||
|
auto param0 = make_shared<op::Parameter>(element::f32, Shape{6, 8, 10});
|
||||||
|
auto param1 = make_shared<op::Parameter>(element::i32, Shape{5});
|
||||||
|
auto top = make_shared<op::v0::Tile>(param0, param1);
|
||||||
|
ASSERT_EQ(top->get_element_type(), element::f32);
|
||||||
|
ASSERT_EQ(top->get_output_partial_shape(0).size(), 5);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user