[core]Migrate Concat operator to new API (#20600)

* Migrate Concat op to new API

* Move shape validation to shape_infer

* Fix getting concat axis in shape inference

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Pawel Raasz 2023-10-30 12:07:36 +01:00 committed by GitHub
parent 4512141111
commit 7cfeb413d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 127 deletions

View File

@ -50,9 +50,6 @@ public:
void set_axis(int64_t axis) {
m_axis = axis;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override;
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool evaluate_lower(TensorVector& output_values) const override;

View File

@ -14,9 +14,10 @@ namespace v0 {
template <class T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const Concat* op, const std::vector<T>& input_shapes) {
NODE_VALIDATION_CHECK(op, !input_shapes.empty());
using DimType = typename T::value_type;
const auto concat_axis = op->get_concatenation_axis();
auto concat_axis = op->get_concatenation_axis() < 0 ? op->get_axis() : op->get_concatenation_axis();
const auto empty_dim = DimType{};
auto concat_dim = DimType{0};
@ -27,21 +28,29 @@ std::vector<TRShape> shape_infer(const Concat* op, const std::vector<T>& input_s
output_shape = PartialShape::dynamic();
} else {
output_shape = input_shapes.front();
OPENVINO_SUPPRESS_DEPRECATED_START
concat_axis = ov::normalize_axis(op, concat_axis, output_shape.rank());
OPENVINO_SUPPRESS_DEPRECATED_END
output_shape[concat_axis] = empty_dim;
}
for (auto& input : input_shapes) {
if (input.rank().is_static()) {
const auto& input_rank = input.rank();
if (input_rank.is_static()) {
OPENVINO_SUPPRESS_DEPRECATED_START
concat_axis = ov::normalize_axis(op, concat_axis, input_rank);
OPENVINO_SUPPRESS_DEPRECATED_END
auto in_copy = TRShape(input);
concat_dim += in_copy[concat_axis];
in_copy[concat_axis] = empty_dim;
NODE_VALIDATION_CHECK(op,
TRShape::merge_into(output_shape, in_copy),
"Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis (axis ",
concat_axis,
").");
NODE_SHAPE_INFER_CHECK(op,
input_shapes,
TRShape::merge_into(output_shape, in_copy),
"Argument shapes are inconsistent; they must have the same rank, and must "
"have equal dimension everywhere except on the concatenation axis (axis ",
concat_axis,
").");
} else {
concat_dim += empty_dim;
}

View File

@ -2,38 +2,33 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/concat.hpp"
#include <memory>
#include "openvino/op/concat.hpp"
#include "bound_evaluate.hpp"
#include "concat_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/validation_util.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/reference/concat.hpp"
#include "validation_util.hpp"
using namespace std;
using namespace ngraph;
namespace ov {
namespace op {
namespace v0 {
op::Concat::Concat(const OutputVector& args, int64_t axis) : Op(args), m_axis(axis) {
Concat::Concat(const OutputVector& args, int64_t axis) : Op(args), m_axis(axis) {
constructor_validate_and_infer_types();
}
op::Concat::Concat(const NodeVector& args, int64_t axis) : Concat(as_output_vector(args), axis) {}
Concat::Concat(const NodeVector& args, int64_t axis) : Concat(as_output_vector(args), axis) {}
bool op::Concat::visit_attributes(AttributeVisitor& visitor) {
bool Concat::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v0_Concat_visit_attributes);
visitor.on_attribute("axis", m_axis);
return true;
}
void op::Concat::validate_and_infer_types() {
void Concat::validate_and_infer_types() {
OV_OP_SCOPE(v0_Concat_validate_and_infer_types);
NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
element::Type inputs_et{element::dynamic};
auto input_shapes = std::vector<PartialShape>();
@ -41,118 +36,68 @@ void op::Concat::validate_and_infer_types() {
NODE_VALIDATION_CHECK(this,
element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
"Argument element types are inconsistent.");
const auto& input_shape = get_input_partial_shape(i);
const auto& input_rank = input_shape.rank();
if (input_rank.is_static() && (get_concatenation_axis() < 0)) {
set_concatenation_axis(get_axis() < 0 ? get_axis() + input_rank.get_length() : get_axis());
}
const auto concat_axis = get_concatenation_axis();
NODE_VALIDATION_CHECK(this,
input_shape.is_dynamic() || (0 <= concat_axis && concat_axis < input_rank.get_length()),
"Concatenation axis (",
concat_axis,
") is out of bounds [",
-input_rank.get_length(),
", ",
input_rank.get_length() - 1,
"] for ",
"argument ",
i,
", which has shape ",
input_shape,
".");
input_shapes.push_back(input_shape);
input_shapes.push_back(get_input_partial_shape(i));
}
const auto output_shapes = shape_infer(this, input_shapes);
set_output_type(0, inputs_et, output_shapes.front());
const auto output_shape = shape_infer(this, input_shapes).front();
if (output_shape.rank().is_static() && (get_concatenation_axis() < 0)) {
set_concatenation_axis(ov::util::normalize(get_axis(), output_shape.size()));
}
set_output_type(0, inputs_et, output_shape);
}
shared_ptr<Node> op::Concat::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> Concat::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v0_Concat_clone_with_new_inputs);
return make_shared<Concat>(new_args, m_axis);
return std::make_shared<Concat>(new_args, m_axis);
}
OPENVINO_SUPPRESS_DEPRECATED_START
namespace {
bool evaluate_concat(const HostTensorVector& args, const HostTensorPtr& out, int64_t concatenation_axis) {
std::vector<const char*> arg_bufs;
std::vector<ov::Shape> arg_shapes;
ov::Shape out_shape(args[0]->get_shape());
out_shape[concatenation_axis] = 0;
for (auto& input : args) {
arg_bufs.push_back(input->get_data_ptr<char>());
arg_shapes.push_back(input->get_shape());
out_shape[concatenation_axis] += arg_shapes.back()[concatenation_axis];
}
out->set_shape(out_shape);
ov::reference::concat(arg_bufs,
out->get_data_ptr<char>(),
arg_shapes,
out_shape,
concatenation_axis,
out->get_element_type().size());
return true;
}
} // namespace
bool op::Concat::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool Concat::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v0_Concat_evaluate);
OPENVINO_ASSERT(!inputs.empty());
OPENVINO_ASSERT(validate_host_tensor_vector(inputs, inputs.size()));
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1));
auto concat_axis = get_axis() < 0 ? get_axis() + inputs[0]->get_shape().size() : get_axis();
return evaluate_concat(inputs, outputs[0], concat_axis);
}
OPENVINO_SUPPRESS_DEPRECATED_END
bool op::Concat::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
OV_OP_SCOPE(v0_Concat_evaluate);
OPENVINO_ASSERT(!inputs.empty());
OPENVINO_ASSERT(outputs.size() == 1);
auto concat_axis = ov::util::normalize(get_axis(), inputs.front().get_shape().size());
const auto inputs_count = inputs.size();
std::vector<const char*> arg_bufs(inputs_count);
std::vector<Shape> arg_shapes;
std::vector<PartialShape> input_shapes;
arg_shapes.reserve(inputs_count);
input_shapes.reserve(inputs_count);
std::vector<const char*> arg_bufs;
std::vector<ov::Shape> arg_shapes;
ov::Shape out_shape(inputs.front().get_shape());
out_shape[concat_axis] = 0;
auto arg_buf = arg_bufs.begin();
for (auto& input : inputs) {
arg_bufs.push_back(static_cast<const char*>(input.data()));
arg_shapes.push_back(input.get_shape());
out_shape[concat_axis] += arg_shapes.back()[concat_axis];
*arg_buf = static_cast<const char*>(input.data());
++arg_buf;
const auto& input_shape = input.get_shape();
arg_shapes.emplace_back(input_shape);
input_shapes.emplace_back(input_shape);
}
const auto& out_shape = shape_infer(this, input_shapes).front().to_shape();
outputs.front().set_shape(out_shape);
ov::reference::concat(arg_bufs,
static_cast<char*>(outputs.front().data()),
arg_shapes,
out_shape,
concat_axis,
outputs.front().get_element_type().size());
reference::concat(arg_bufs,
static_cast<char*>(outputs.front().data()),
arg_shapes,
out_shape,
ov::util::normalize(get_axis(), out_shape.size()),
outputs.front().get_element_type().size());
return true;
}
bool op::Concat::has_evaluate() const {
bool Concat::has_evaluate() const {
OV_OP_SCOPE(v0_Concat_has_evaluate);
return true;
}
bool op::Concat::evaluate_lower(ov::TensorVector& output_values) const {
bool Concat::evaluate_lower(TensorVector& output_values) const {
return default_lower_bound_evaluator(this, output_values);
}
bool op::Concat::evaluate_upper(ov::TensorVector& output_values) const {
bool Concat::evaluate_upper(TensorVector& output_values) const {
return default_upper_bound_evaluator(this, output_values);
}
bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const {
bool Concat::evaluate_label(TensorLabelVector& output_labels) const {
const auto& inputs = input_values();
if (std::all_of(inputs.cbegin(), inputs.cend(), [](const Output<Node>& out) {
const auto& labels = out.get_tensor().get_value_label();
@ -187,3 +132,6 @@ bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const {
return false;
}
}
} // namespace v0
} // namespace op
} // namespace ov

View File

@ -6,6 +6,7 @@
#include <gmock/gmock.h>
#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/op/broadcast.hpp"
@ -68,15 +69,10 @@ TEST(type_prop, concat_deduce_axis_oob) {
auto param0 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 3, 4});
auto param1 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 7, 4});
auto param2 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2, 5});
try {
auto c = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, 3);
// Should have thrown, so fail if it didn't
FAIL() << "Deduced type should disagree with specified type";
} catch (const ov::NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (3) is out of bounds"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
OV_EXPECT_THROW(ignore = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, 3),
ov::AssertFailure,
HasSubstr("Concat Parameter axis 3 out of the tensor rank range"));
}
TEST(type_prop, concat_deduce_axis_barely_in_bounds) {
@ -259,15 +255,9 @@ TEST(type_prop, concat_partial_negative_axis_incorrect) {
auto param1 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 7, 4});
auto param2 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2, 4});
try {
auto c = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, -4);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect negative axis value not detected (out of bounds)";
} catch (const ov::NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (-1) is out of bounds"));
} catch (...) {
FAIL() << "Deduced type check failed for unexpected reason";
}
OV_EXPECT_THROW(ignore = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, -4),
ov::AssertFailure,
HasSubstr("Concat Parameter axis -4 out of the tensor rank range"));
}
/** \brief Test uses evaluate lower/upper and label of concat op. */