[core]Migrate Concat operator to new API (#20600)
* Migrate Concat op to new API
* Move shape validation to shape_infer
* Fix getting concat axis in shape inference

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
commit 7cfeb413d4 (parent 4512141111)
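For context, below is a minimal, self-contained sketch (not part of this commit) of how the migrated op can be exercised through the ov::Tensor-based evaluate() overload that this change keeps as the single evaluation path; the parameter shapes, axis, and fill values are arbitrary illustration, and the snippet assumes the public OpenVINO C++ API of this release.

    // Illustrative sketch only -- not part of the commit.
    #include <algorithm>
    #include <memory>

    #include "openvino/op/concat.hpp"
    #include "openvino/op/parameter.hpp"
    #include "openvino/runtime/tensor.hpp"

    int main() {
        // Two 2x2 f32 inputs concatenated along axis 0 -> expected 4x2 output.
        auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2});
        auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2});
        auto concat = std::make_shared<ov::op::v0::Concat>(ov::NodeVector{a, b}, 0);

        ov::TensorVector inputs{ov::Tensor(ov::element::f32, ov::Shape{2, 2}),
                                ov::Tensor(ov::element::f32, ov::Shape{2, 2})};
        std::fill_n(inputs[0].data<float>(), 4, 1.0f);
        std::fill_n(inputs[1].data<float>(), 4, 2.0f);

        // The Tensor-based evaluate() sets the output shape itself, so a
        // pre-allocated tensor of the matching element type is enough here.
        ov::TensorVector outputs{ov::Tensor(ov::element::f32, ov::Shape{4, 2})};

        const bool ok = concat->evaluate(outputs, inputs);
        return ok && outputs[0].get_shape() == ov::Shape{4, 2} ? 0 : 1;
    }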
@@ -50,9 +50,6 @@ public:
     void set_axis(int64_t axis) {
         m_axis = axis;
     }
-    OPENVINO_SUPPRESS_DEPRECATED_START
-    bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
-    OPENVINO_SUPPRESS_DEPRECATED_END
     bool has_evaluate() const override;
     bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
     bool evaluate_lower(TensorVector& output_values) const override;
@@ -14,9 +14,10 @@ namespace v0 {
 
 template <class T, class TRShape = result_shape_t<T>>
 std::vector<TRShape> shape_infer(const Concat* op, const std::vector<T>& input_shapes) {
+    NODE_VALIDATION_CHECK(op, !input_shapes.empty());
     using DimType = typename T::value_type;
 
-    const auto concat_axis = op->get_concatenation_axis();
+    auto concat_axis = op->get_concatenation_axis() < 0 ? op->get_axis() : op->get_concatenation_axis();
     const auto empty_dim = DimType{};
 
     auto concat_dim = DimType{0};
@@ -27,21 +28,29 @@ std::vector<TRShape> shape_infer(const Concat* op, const std::vector<T>& input_s
         output_shape = PartialShape::dynamic();
     } else {
         output_shape = input_shapes.front();
+        OPENVINO_SUPPRESS_DEPRECATED_START
+        concat_axis = ov::normalize_axis(op, concat_axis, output_shape.rank());
+        OPENVINO_SUPPRESS_DEPRECATED_END
         output_shape[concat_axis] = empty_dim;
     }
 
     for (auto& input : input_shapes) {
-        if (input.rank().is_static()) {
+        const auto& input_rank = input.rank();
+        if (input_rank.is_static()) {
+            OPENVINO_SUPPRESS_DEPRECATED_START
+            concat_axis = ov::normalize_axis(op, concat_axis, input_rank);
+            OPENVINO_SUPPRESS_DEPRECATED_END
             auto in_copy = TRShape(input);
             concat_dim += in_copy[concat_axis];
             in_copy[concat_axis] = empty_dim;
 
-            NODE_VALIDATION_CHECK(op,
-                                  TRShape::merge_into(output_shape, in_copy),
-                                  "Argument shapes are inconsistent; they must have the same rank, and must "
-                                  "have equal dimension everywhere except on the concatenation axis (axis ",
-                                  concat_axis,
-                                  ").");
+            NODE_SHAPE_INFER_CHECK(op,
+                                   input_shapes,
+                                   TRShape::merge_into(output_shape, in_copy),
+                                   "Argument shapes are inconsistent; they must have the same rank, and must "
+                                   "have equal dimension everywhere except on the concatenation axis (axis ",
+                                   concat_axis,
+                                   ").");
         } else {
             concat_dim += empty_dim;
         }
@@ -2,38 +2,33 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-#include "ngraph/op/concat.hpp"
-
-#include <memory>
+#include "openvino/op/concat.hpp"
 
 #include "bound_evaluate.hpp"
 #include "concat_shape_inference.hpp"
 #include "itt.hpp"
-#include "ngraph/attribute_visitor.hpp"
-#include "ngraph/validation_util.hpp"
 #include "openvino/core/dimension_tracker.hpp"
-#include "openvino/core/validation_util.hpp"
 #include "openvino/reference/concat.hpp"
+#include "validation_util.hpp"
 
-using namespace std;
-using namespace ngraph;
+namespace ov {
+namespace op {
+namespace v0 {
 
-op::Concat::Concat(const OutputVector& args, int64_t axis) : Op(args), m_axis(axis) {
+Concat::Concat(const OutputVector& args, int64_t axis) : Op(args), m_axis(axis) {
     constructor_validate_and_infer_types();
 }
 
-op::Concat::Concat(const NodeVector& args, int64_t axis) : Concat(as_output_vector(args), axis) {}
+Concat::Concat(const NodeVector& args, int64_t axis) : Concat(as_output_vector(args), axis) {}
 
-bool op::Concat::visit_attributes(AttributeVisitor& visitor) {
+bool Concat::visit_attributes(AttributeVisitor& visitor) {
     OV_OP_SCOPE(v0_Concat_visit_attributes);
     visitor.on_attribute("axis", m_axis);
     return true;
 }
 
-void op::Concat::validate_and_infer_types() {
+void Concat::validate_and_infer_types() {
     OV_OP_SCOPE(v0_Concat_validate_and_infer_types);
-    NODE_VALIDATION_CHECK(this, get_input_size() >= 1, "At least one argument required.");
-
     element::Type inputs_et{element::dynamic};
     auto input_shapes = std::vector<PartialShape>();
 
@@ -41,118 +36,68 @@ void op::Concat::validate_and_infer_types() {
         NODE_VALIDATION_CHECK(this,
                               element::Type::merge(inputs_et, inputs_et, get_input_element_type(i)),
                               "Argument element types are inconsistent.");
-        const auto& input_shape = get_input_partial_shape(i);
-        const auto& input_rank = input_shape.rank();
-
-        if (input_rank.is_static() && (get_concatenation_axis() < 0)) {
-            set_concatenation_axis(get_axis() < 0 ? get_axis() + input_rank.get_length() : get_axis());
-        }
-
-        const auto concat_axis = get_concatenation_axis();
-
-        NODE_VALIDATION_CHECK(this,
-                              input_shape.is_dynamic() || (0 <= concat_axis && concat_axis < input_rank.get_length()),
-                              "Concatenation axis (",
-                              concat_axis,
-                              ") is out of bounds [",
-                              -input_rank.get_length(),
-                              ", ",
-                              input_rank.get_length() - 1,
-                              "] for ",
-                              "argument ",
-                              i,
-                              ", which has shape ",
-                              input_shape,
-                              ".");
-
-        input_shapes.push_back(input_shape);
+        input_shapes.push_back(get_input_partial_shape(i));
     }
 
-    const auto output_shapes = shape_infer(this, input_shapes);
-    set_output_type(0, inputs_et, output_shapes.front());
+    const auto output_shape = shape_infer(this, input_shapes).front();
+    if (output_shape.rank().is_static() && (get_concatenation_axis() < 0)) {
+        set_concatenation_axis(ov::util::normalize(get_axis(), output_shape.size()));
+    }
+
+    set_output_type(0, inputs_et, output_shape);
 }
 
-shared_ptr<Node> op::Concat::clone_with_new_inputs(const OutputVector& new_args) const {
+std::shared_ptr<Node> Concat::clone_with_new_inputs(const OutputVector& new_args) const {
     OV_OP_SCOPE(v0_Concat_clone_with_new_inputs);
-    return make_shared<Concat>(new_args, m_axis);
+    return std::make_shared<Concat>(new_args, m_axis);
 }
 
-OPENVINO_SUPPRESS_DEPRECATED_START
-namespace {
-bool evaluate_concat(const HostTensorVector& args, const HostTensorPtr& out, int64_t concatenation_axis) {
-    std::vector<const char*> arg_bufs;
-    std::vector<ov::Shape> arg_shapes;
-    ov::Shape out_shape(args[0]->get_shape());
-    out_shape[concatenation_axis] = 0;
-    for (auto& input : args) {
-        arg_bufs.push_back(input->get_data_ptr<char>());
-        arg_shapes.push_back(input->get_shape());
-        out_shape[concatenation_axis] += arg_shapes.back()[concatenation_axis];
-    }
-    out->set_shape(out_shape);
-    ov::reference::concat(arg_bufs,
-                          out->get_data_ptr<char>(),
-                          arg_shapes,
-                          out_shape,
-                          concatenation_axis,
-                          out->get_element_type().size());
-
-    return true;
-}
-}  // namespace
-
-bool op::Concat::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
+bool Concat::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
     OV_OP_SCOPE(v0_Concat_evaluate);
     OPENVINO_ASSERT(!inputs.empty());
-    OPENVINO_ASSERT(validate_host_tensor_vector(inputs, inputs.size()));
-    OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1));
-    auto concat_axis = get_axis() < 0 ? get_axis() + inputs[0]->get_shape().size() : get_axis();
-    return evaluate_concat(inputs, outputs[0], concat_axis);
-}
-OPENVINO_SUPPRESS_DEPRECATED_END
-
-bool op::Concat::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
-    OV_OP_SCOPE(v0_Concat_evaluate);
-    OPENVINO_ASSERT(!inputs.empty());
     OPENVINO_ASSERT(outputs.size() == 1);
 
-    auto concat_axis = ov::util::normalize(get_axis(), inputs.front().get_shape().size());
+    const auto inputs_count = inputs.size();
+    std::vector<const char*> arg_bufs(inputs_count);
+    std::vector<Shape> arg_shapes;
+    std::vector<PartialShape> input_shapes;
+    arg_shapes.reserve(inputs_count);
+    input_shapes.reserve(inputs_count);
 
-    std::vector<const char*> arg_bufs;
-    std::vector<ov::Shape> arg_shapes;
-
-    ov::Shape out_shape(inputs.front().get_shape());
-    out_shape[concat_axis] = 0;
+    auto arg_buf = arg_bufs.begin();
     for (auto& input : inputs) {
-        arg_bufs.push_back(static_cast<const char*>(input.data()));
-        arg_shapes.push_back(input.get_shape());
-        out_shape[concat_axis] += arg_shapes.back()[concat_axis];
+        *arg_buf = static_cast<const char*>(input.data());
+        ++arg_buf;
+        const auto& input_shape = input.get_shape();
+        arg_shapes.emplace_back(input_shape);
+        input_shapes.emplace_back(input_shape);
     }
 
+    const auto& out_shape = shape_infer(this, input_shapes).front().to_shape();
     outputs.front().set_shape(out_shape);
-    ov::reference::concat(arg_bufs,
-                          static_cast<char*>(outputs.front().data()),
-                          arg_shapes,
-                          out_shape,
-                          concat_axis,
-                          outputs.front().get_element_type().size());
+    reference::concat(arg_bufs,
+                      static_cast<char*>(outputs.front().data()),
+                      arg_shapes,
+                      out_shape,
+                      ov::util::normalize(get_axis(), out_shape.size()),
+                      outputs.front().get_element_type().size());
 
     return true;
 }
 
-bool op::Concat::has_evaluate() const {
+bool Concat::has_evaluate() const {
     OV_OP_SCOPE(v0_Concat_has_evaluate);
     return true;
 }
 
-bool op::Concat::evaluate_lower(ov::TensorVector& output_values) const {
+bool Concat::evaluate_lower(TensorVector& output_values) const {
     return default_lower_bound_evaluator(this, output_values);
 }
 
-bool op::Concat::evaluate_upper(ov::TensorVector& output_values) const {
+bool Concat::evaluate_upper(TensorVector& output_values) const {
     return default_upper_bound_evaluator(this, output_values);
 }
 
-bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const {
+bool Concat::evaluate_label(TensorLabelVector& output_labels) const {
     const auto& inputs = input_values();
     if (std::all_of(inputs.cbegin(), inputs.cend(), [](const Output<Node>& out) {
             const auto& labels = out.get_tensor().get_value_label();
@@ -187,3 +132,6 @@ bool op::Concat::evaluate_label(TensorLabelVector& output_labels) const {
         return false;
     }
 }
+}  // namespace v0
+}  // namespace op
+}  // namespace ov
@@ -6,6 +6,7 @@
 
 #include <gmock/gmock.h>
 
+#include "common_test_utils/test_assertions.hpp"
 #include "common_test_utils/type_prop.hpp"
 #include "openvino/core/dimension_tracker.hpp"
 #include "openvino/op/broadcast.hpp"
@@ -68,15 +69,10 @@ TEST(type_prop, concat_deduce_axis_oob) {
     auto param0 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 3, 4});
     auto param1 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 7, 4});
     auto param2 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2, 5});
-    try {
-        auto c = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, 3);
-        // Should have thrown, so fail if it didn't
-        FAIL() << "Deduced type should disagree with specified type";
-    } catch (const ov::NodeValidationFailure& error) {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (3) is out of bounds"));
-    } catch (...) {
-        FAIL() << "Deduced type check failed for unexpected reason";
-    }
+
+    OV_EXPECT_THROW(ignore = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, 3),
+                    ov::AssertFailure,
+                    HasSubstr("Concat Parameter axis 3 out of the tensor rank range"));
 }
 
 TEST(type_prop, concat_deduce_axis_barely_in_bounds) {
@@ -259,15 +255,9 @@ TEST(type_prop, concat_partial_negative_axis_incorrect) {
     auto param1 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 7, 4});
     auto param2 = make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2, 4});
 
-    try {
-        auto c = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, -4);
-        // Should have thrown, so fail if it didn't
-        FAIL() << "Incorrect negative axis value not detected (out of bounds)";
-    } catch (const ov::NodeValidationFailure& error) {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Concatenation axis (-1) is out of bounds"));
-    } catch (...) {
-        FAIL() << "Deduced type check failed for unexpected reason";
-    }
+    OV_EXPECT_THROW(ignore = make_shared<ov::op::v0::Concat>(ov::NodeVector{param0, param1, param2}, -4),
+                    ov::AssertFailure,
+                    HasSubstr("Concat Parameter axis -4 out of the tensor rank range"));
 }
 
 /** \brief Test uses evaluate lower/upper and label of concat op. */