diff --git a/inference-engine/src/transformations/include/transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp b/inference-engine/src/transformations/include/transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp new file mode 100644 index 00000000000..757d20390d3 --- /dev/null +++ b/inference-engine/src/transformations/include/transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp @@ -0,0 +1,27 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include + +namespace ngraph { +namespace pass { + +class TRANSFORMATIONS_API ConvertGather7ToGather1; + +} // namespace pass +} // namespace ngraph + +/** + * @ingroup ie_transformation_common_api + * @brief ConvertGather7ToGather1 covert v7::Gather into v1::Gather. + */ +class ngraph::pass::ConvertGather7ToGather1 : public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + ConvertGather7ToGather1(); +}; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp index 9e36bb8c8f0..4800121b0dd 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/algebraic_simplification.cpp @@ -37,12 +37,12 @@ static bool simplify_gather(std::shared_ptr node) { return false; } - auto axis = gather->get_axis(); - if (axis == opset3::Gather::AXIS_NOT_SET_VALUE) { + if (!gather->is_axis_set()) { NGRAPH_DEBUG << "axis value not set"; return false; } + auto axis = gather->get_axis(); // case_1 : if the input tensor is of shape (4, 1, 4) // and axis = 1, then the gather would be simply @@ -85,12 +85,12 @@ static bool simplify_gather_shapeof(shared_ptr node) { } auto gather_in_rank = gather->get_input_partial_shape(0).rank(); auto indices_rank = gather->get_input_partial_shape(1).rank(); - auto axis = gather->get_axis(); if (gather_in_rank.is_dynamic() || indices_rank.is_dynamic() || - axis == opset3::Gather::AXIS_NOT_SET_VALUE) { + !gather->is_axis_set()) { NGRAPH_DEBUG << gather << " cannot simplify gather->shapeof"; return false; } + auto axis = gather->get_axis(); auto zero_axis = opset3::Constant::create(element::i64, Shape{}, {0}); NodeVector new_ops; diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index 244cc6e0847..47e9f789ea6 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -42,6 +42,7 @@ #include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp" #include "transformations/op_conversions/convert_pad_to_group_conv.hpp" #include "transformations/op_conversions/convert_divide.hpp" +#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp" #include "transformations/op_conversions/convert_mod.hpp" #include "transformations/op_conversions/convert_minimum_to_power_and_max.hpp" #include "transformations/op_conversions/convert_negative.hpp" @@ -156,6 +157,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptrset_name("ngraph::pass::ConvFusions"); manager.register_pass(); + // need to convert to Gather-1 until plugins do not support Gather-7 + manager.register_pass(); auto fq_fusions = manager.register_pass(); fq_fusions->add_matcher(); diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp index f5b1062e2b2..fae3b71ac1e 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/eliminate_unsqueeze_gather.cpp @@ -19,7 +19,7 @@ ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() { const auto unsqueeze = ngraph::pattern::wrap_type({unsqueezeInput, unsqueezeAxis}, pattern::consumers_count(1)); const auto gatherIndices = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0}); const auto gatherAxis = ngraph::pattern::any_input(); - const auto gather = ngraph::pattern::wrap_type({unsqueeze, gatherIndices, gatherAxis}); + const auto gather = ngraph::pattern::wrap_type({unsqueeze, gatherIndices, gatherAxis}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { auto& patternValue = m.get_pattern_value_map(); diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_gather_v7_to_gather_v1.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_gather_v7_to_gather_v1.cpp new file mode 100644 index 00000000000..67d5c0ebc1c --- /dev/null +++ b/inference-engine/src/transformations/src/transformations/op_conversions/convert_gather_v7_to_gather_v1.cpp @@ -0,0 +1,43 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp" +#include +#include +#include +#include + +#include "itt.hpp" + +NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0); + +ngraph::pass::ConvertGather7ToGather1::ConvertGather7ToGather1() { + MATCHER_SCOPE(ConvertGather7ToGather1); + + auto gather_v7 = pattern::wrap_type(); + + ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) { + auto gather_v7_node = std::dynamic_pointer_cast(m.get_match_root()); + if (!gather_v7_node) + return false; + + int64_t batch_dims = 0; + if (gather_v7_node->is_axis_set()) + batch_dims = gather_v7_node->get_batch_dims(); + if (batch_dims != 0) + return false; + auto data_input = gather_v7_node->input_value(0); + auto indices_input = gather_v7_node->input_value(1); + auto axis_input = gather_v7_node->input_value(2); + + auto gather_v1 = std::make_shared(data_input, indices_input, axis_input); + gather_v1->set_friendly_name(gather_v7_node->get_friendly_name()); + ngraph::copy_runtime_info(gather_v7_node, gather_v1); + ngraph::replace_node(gather_v7_node, gather_v1); + return true; + }; + + auto m = std::make_shared(gather_v7, matcher_name); + register_matcher(m, callback); +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_gather_v7_to_gather_v1_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_gather_v7_to_gather_v1_test.cpp new file mode 100644 index 00000000000..b409ac16738 --- /dev/null +++ b/inference-engine/tests/functional/inference_engine/transformations/convert_gather_v7_to_gather_v1_test.cpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" + +using namespace testing; + +TEST(TransformationTests, ConvertGather7toGather1) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{2, 3}); + auto indices = std::make_shared(ngraph::element::i32, ngraph::Shape{2, 2}); + auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0}); + + auto gather_v7 = std::make_shared(data, indices, axis, 0); + + f = std::make_shared(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices}); + + ngraph::pass::Manager manager; + manager.register_pass(); + manager.register_pass(); + manager.run_passes(f); + ASSERT_NO_THROW(check_rt_info(f)); + } + + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{2, 3}); + auto indices = std::make_shared(ngraph::element::i32, ngraph::Shape{2, 2}); + auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0}); + + auto gather_v1 = std::make_shared(data, indices, axis); + + f_ref = std::make_shared(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices}); + } + + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; +} diff --git a/inference-engine/tests/functional/inference_engine/transformations/eliminate_unsqueeze_gather.cpp b/inference-engine/tests/functional/inference_engine/transformations/eliminate_unsqueeze_gather.cpp index 1733d7291df..526695879d8 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/eliminate_unsqueeze_gather.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/eliminate_unsqueeze_gather.cpp @@ -6,7 +6,6 @@ #include #include - #include #include diff --git a/ngraph/core/include/ngraph/op/gather.hpp b/ngraph/core/include/ngraph/op/gather.hpp index 6a1c096f04f..a165e5ddac4 100644 --- a/ngraph/core/include/ngraph/op/gather.hpp +++ b/ngraph/core/include/ngraph/op/gather.hpp @@ -4,7 +4,7 @@ #pragma once -#include "ngraph/op/op.hpp" +#include "ngraph/op/util/gather_base.hpp" namespace ngraph { @@ -13,12 +13,10 @@ namespace ngraph namespace v1 { /// \brief Gather slices from axis of params according to indices - class NGRAPH_API Gather : public Op + class NGRAPH_API Gather : public op::util::GatherBase { public: - static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits::max(); - static constexpr NodeTypeInfo type_info{"Gather", 1}; - const NodeTypeInfo& get_type_info() const override { return type_info; } + NGRAPH_RTTI_DECLARATION; Gather() = default; /// \param params The tensor from which slices are gathered /// \param indices Tensor with indexes to gather @@ -28,35 +26,16 @@ namespace ngraph const Output& axis); bool visit_attributes(AttributeVisitor& visitor) override; - int64_t get_axis() const; - void validate_and_infer_types() override; - - virtual std::shared_ptr + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - - bool evaluate(const HostTensorVector& outputs, - const HostTensorVector& inputs) const override; - bool evaluate_lower(const HostTensorVector& outputs) const override; - bool evaluate_upper(const HostTensorVector& outputs) const override; - - bool constant_fold(OutputVector& output_values, - const OutputVector& inputs_values) override; - - private: - static const int PARAMS; - static const int INDICES; - static const int AXIS; - - bool evaluate_gather(const HostTensorVector& outputs, - const HostTensorVector& inputs) const; }; } // namespace v1 namespace v7 { /// \brief Gather slices from axis of params according to indices - class NGRAPH_API Gather : public Op + class NGRAPH_API Gather : public op::util::GatherBase { public: NGRAPH_RTTI_DECLARATION; @@ -76,23 +55,6 @@ namespace ngraph std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; - - int64_t get_batch_dims() const; - int64_t get_axis() const; - bool is_axis_set() const; - - bool evaluate_gather(const HostTensorVector& outputs, - const HostTensorVector& inputs) const; - bool evaluate(const HostTensorVector& outputs, - const HostTensorVector& inputs) const override; - bool evaluate_lower(const HostTensorVector& outputs) const override; - bool evaluate_upper(const HostTensorVector& outputs) const override; - - bool constant_fold(OutputVector& output_values, - const OutputVector& inputs_values) override; - - private: - int64_t m_batch_dims = 0; }; } // namespace v7 } // namespace op diff --git a/ngraph/core/include/ngraph/op/util/gather_base.hpp b/ngraph/core/include/ngraph/op/util/gather_base.hpp new file mode 100644 index 00000000000..e6bd909731d --- /dev/null +++ b/ngraph/core/include/ngraph/op/util/gather_base.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/op/op.hpp" + +namespace ngraph +{ + namespace op + { + namespace util + { + /// \brief GatherBase basic class for Gather v1 and v7 + class NGRAPH_API GatherBase : public Op + { + public: + NGRAPH_RTTI_DECLARATION; + GatherBase() = default; + + /// \param data The tensor from which slices are gathered + /// \param indices Tensor with indexes to gather + /// \param axis The tensor is a dimension index to gather data from + /// \param batch_dims The number of batch dimension in data and indices tensors + GatherBase(const Output& data, + const Output& indices, + const Output& axis, + const int64_t batch_dims = 0); + + void validate_and_infer_types() override; + int64_t get_batch_dims() const; + int64_t get_axis() const; + bool is_axis_set() const; + + bool evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const override; + + bool evaluate_lower(const HostTensorVector& outputs) const override; + bool evaluate_upper(const HostTensorVector& outputs) const override; + + bool constant_fold(OutputVector& output_values, + const OutputVector& inputs_values) override; + + protected: + int64_t m_batch_dims = 0; + }; + } // namespace utils + } // namespace op +} // namespace ngraph diff --git a/ngraph/core/src/op/gather.cpp b/ngraph/core/src/op/gather.cpp index 522ce299e35..f68a7a9491a 100644 --- a/ngraph/core/src/op/gather.cpp +++ b/ngraph/core/src/op/gather.cpp @@ -3,30 +3,19 @@ // #include "ngraph/op/gather.hpp" -#include "itt.hpp" -#include "ngraph/op/concat.hpp" -#include "ngraph/op/constant.hpp" -#include "ngraph/op/squeeze.hpp" -#include "ngraph/runtime/host_tensor.hpp" -#include "ngraph/runtime/reference/gather.hpp" -#include "ngraph/shape.hpp" - #include +#include "itt.hpp" +#include "ngraph/shape.hpp" using namespace std; using namespace ngraph; -constexpr NodeTypeInfo op::v1::Gather::type_info; -const int64_t op::v1::Gather::AXIS_NOT_SET_VALUE; - -const int op::v1::Gather::PARAMS = 0; -const int op::v1::Gather::INDICES = 1; -const int op::v1::Gather::AXIS = 2; +NGRAPH_RTTI_DEFINITION(op::v1::Gather, "Gather", 1, op::util::GatherBase); op::v1::Gather::Gather(const Output& params, const Output& indices, const Output& axes) - : Op({params, indices, axes}) + : GatherBase(params, indices, axes) { constructor_validate_and_infer_types(); } @@ -37,107 +26,38 @@ bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor) return true; } -void op::v1::Gather::validate_and_infer_types() -{ - NGRAPH_OP_SCOPE(v1_Gather_validate_and_infer_types); - const auto& input_rank = get_input_partial_shape(PARAMS).rank(); - const auto& axis_shape = get_input_partial_shape(AXIS); - const auto& axis_rank = axis_shape.rank(); - - if (axis_rank.is_static() && axis_shape.is_static()) - { - const auto axis_is_scalar = axis_rank.get_length() == 0; - const auto axis_has_one_elem = - axis_rank.get_length() == 1 && axis_shape[0].get_length() == 1; - NODE_VALIDATION_CHECK(this, - axis_is_scalar || axis_has_one_elem, - "Axes input must be scalar or have 1 element (shape: ", - axis_shape, - ")."); - } - - int64_t axis = get_axis(); - if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE) - { - NODE_VALIDATION_CHECK(this, - axis < input_rank.get_length(), - "The axis must => 0 and <= input_rank (axis: ", - axis, - ")."); - } - - element::Type result_et = get_input_element_type(PARAMS); - - const PartialShape& params_shape = get_input_partial_shape(PARAMS); - const PartialShape& indices_shape = get_input_partial_shape(INDICES); - - PartialShape result_shape; - if (params_shape.rank().is_static() && indices_shape.rank().is_static() && - axis != AXIS_NOT_SET_VALUE) - { - std::vector result_dims(params_shape.rank().get_length() + - indices_shape.rank().get_length() - 1); - int64_t i = 0; - for (; i < axis; i++) - { - result_dims[i] = params_shape[i]; - } - for (int64_t j = 0; j < indices_shape.rank().get_length(); i++, j++) - { - result_dims[i] = indices_shape[j]; - } - for (int64_t j = axis + 1; j < params_shape.rank().get_length(); i++, j++) - { - result_dims[i] = params_shape[j]; - } - - result_shape = PartialShape(result_dims); - } - else - { - result_shape = PartialShape::dynamic(); - } - - set_output_type(0, result_et, result_shape); -} - -int64_t op::v1::Gather::get_axis() const -{ - int64_t axis = AXIS_NOT_SET_VALUE; - if (const auto& const_op = get_constant_from_source(input_value(AXIS))) - { - axis = const_op->cast_vector()[0]; - } - if (axis < 0) - { - const auto& input_rank = get_input_partial_shape(PARAMS).rank(); - if (input_rank.is_static()) - { - axis += input_rank.get_length(); - } - } - return axis; -} - shared_ptr op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const { NGRAPH_OP_SCOPE(v1_Gather_clone_with_new_inputs); check_new_args_count(this, new_args); - return make_shared(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS)); + return make_shared(new_args.at(0), new_args.at(1), new_args.at(2)); } -NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7); +NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7, op::util::GatherBase); op::v7::Gather::Gather(const Output& data, const Output& indices, const Output& axis, const int64_t batch_dims) - : Op({data, indices, axis}) - , m_batch_dims(batch_dims) + : GatherBase(data, indices, axis, batch_dims) { constructor_validate_and_infer_types(); } +void op::v7::Gather::validate_and_infer_types() +{ + NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types); + NODE_VALIDATION_CHECK(this, + get_input_element_type(1).is_integral_number(), + "Indices element type must be of an integral number type."); + + NODE_VALIDATION_CHECK(this, + get_input_element_type(2).is_integral_number(), + "Axis element type must be of an integral number type."); + + op::util::GatherBase::validate_and_infer_types(); +} + bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor) { NGRAPH_OP_SCOPE(v7_Gather_visit_attributes); @@ -145,440 +65,9 @@ bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor) return true; } -void op::v7::Gather::validate_and_infer_types() -{ - NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types); - const auto& data_type = get_input_element_type(0); - const auto& indices_type = get_input_element_type(1); - - NODE_VALIDATION_CHECK(this, - indices_type == element::Type_t::i32 || - indices_type == element::Type_t::i64, - "indices must be of int32 or int64 type. But instead got: ", - indices_type); - - const auto& data_pshape = get_input_partial_shape(0); - const auto& indices_pshape = get_input_partial_shape(1); - const auto& axis_pshape = get_input_partial_shape(2); - auto data_rank = data_pshape.rank(); - auto indices_rank = indices_pshape.rank(); - auto axis_rank = axis_pshape.rank(); - - if (axis_rank.is_static() && axis_pshape.is_static()) - { - const auto axis_is_scalar = axis_rank.get_length() == 0; - const auto axis_has_one_elem = - axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1; - NODE_VALIDATION_CHECK( - this, - axis_is_scalar || axis_has_one_elem, - "Axes input must be scalar or have 1 element. But instead got axis_shape = ", - axis_pshape); - } - - int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set - if (is_axis_set()) - { - int64_t axis = get_axis(); - NODE_VALIDATION_CHECK(this, - batch_dims <= axis, - "The batch_dims <= axis. But instead got: batch_dims = ", - batch_dims, - ", axis = ", - axis); - - if (data_rank.is_static()) - { - NODE_VALIDATION_CHECK(this, - axis >= 0 && axis < data_rank.get_length(), - "The axis must be => 0 and < data_rank. But instead got axis = ", - axis, - " data_rank = ", - data_rank.get_length()); - } - } - - if (indices_rank.is_static() && batch_dims >= 0) - { - NODE_VALIDATION_CHECK( - this, - batch_dims <= indices_rank.get_length(), - "The batch_dims must be <= indices_rank. But instead got: batch_dims = ", - batch_dims, - ", indices_rank = ", - indices_rank.get_length()); - } - - if (data_rank.is_static() && indices_rank.is_static()) - { - if (batch_dims >= 0) - { - auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims; - PartialShape output_pshape = PartialShape::dynamic(out_rank); - - // implementation of out_shape formula - // data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] + - // data.shape[axis + 1:] - int i = 0; - for (; i < batch_dims; i++) - { - NODE_VALIDATION_CHECK(this, - data_pshape[i].compatible(indices_pshape[i]), - "Shapes ", - data_pshape, - " and ", - indices_pshape, - " are not consistent. data and indices must have equal or " - "intersecting sizes until batch_dims"); - - output_pshape[i] = data_pshape[i] & indices_pshape[i]; - } - - if (is_axis_set()) - { - int64_t axis = get_axis(); - for (; i < axis; i++) - { - output_pshape[i] = data_pshape[i]; - } - for (; i < axis + indices_rank.get_length() - batch_dims; i++) - { - output_pshape[i] = indices_pshape[batch_dims - axis + i]; - } - for (; i < out_rank; i++) - { - output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i]; - } - } - - set_output_type(0, data_type, output_pshape); - } - else if (batch_dims < 0) - { - // batch_dims < 0 could be only if axis is not set - // as soon as axis value will arrive negative batch_dims should be resolved - // batch_dims value will be within [0, data_rank] && [0, indices_rank] - int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1; - int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length()); - - set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank))); - } - } - else - { - set_output_type(0, data_type, PartialShape::dynamic()); - } -} - -int64_t op::v7::Gather::get_axis() const -{ - const auto& const_op = get_constant_from_source(input_value(2)); - int64_t axis = const_op->cast_vector()[0]; - if (axis < 0) - { - const auto& data_rank = get_input_partial_shape(0).rank(); - if (data_rank.is_static()) - { - axis += data_rank.get_length(); - } - } - return axis; -} - -int64_t op::v7::Gather::get_batch_dims() const -{ - if (m_batch_dims < 0 && is_axis_set()) - return get_axis() + m_batch_dims; - else - return m_batch_dims; -} - -bool op::v7::Gather::is_axis_set() const -{ - const auto& axes_constant = get_constant_from_source(input_value(2)); - if (axes_constant) - return true; - else - return false; -} - shared_ptr op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const { NGRAPH_OP_SCOPE(v7_Gather_clone_with_new_inputs); check_new_args_count(this, new_args); return make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims); } - -namespace gather -{ - template - bool evaluate(const HostTensorPtr& arg0, - const HostTensorPtr& arg1, - const HostTensorPtr& out, - size_t axis, - size_t batch_dims) - { - using T = typename element_type_traits::value_type; - Shape params_shape = arg0->get_shape(); - Shape indices_shape = arg1->get_shape(); - Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims); - uint64_t i = 0; - for (; i < axis; i++) - { - out_shape[i] = params_shape[i]; - } - for (uint64_t j = batch_dims; j < indices_shape.size(); i++, j++) - { - out_shape[i] = indices_shape[j]; - } - for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++) - { - out_shape[i] = params_shape[j]; - } - - out->set_shape(out_shape); - - if (arg1->get_element_type() == element::i64) - { - runtime::reference::gather(arg0->get_data_ptr(), - arg1->get_data_ptr(), - out->get_data_ptr(), - arg0->get_shape(), - arg1->get_shape(), - out->get_shape(), - axis, - batch_dims); - } - else if (arg1->get_element_type() == element::i32) - { - runtime::reference::gather(arg0->get_data_ptr(), - arg1->get_data_ptr(), - out->get_data_ptr(), - arg0->get_shape(), - arg1->get_shape(), - out->get_shape(), - axis, - batch_dims); - } - else - { - throw ngraph_error("Unexpected type"); - } - - return true; - } - - bool evaluate_gather(const HostTensorPtr& arg0, - const HostTensorPtr& arg1, - const HostTensorPtr& out, - size_t axis, - size_t batch_dims = 0) - { - bool rc = true; - - switch (out->get_element_type()) - { - NGRAPH_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims); - NGRAPH_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims); - NGRAPH_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims); - NGRAPH_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims); - NGRAPH_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims); - NGRAPH_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims); - NGRAPH_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims); - default: rc = false; break; - } - return rc; - } - - bool cf_gather_with_subgraph(OutputVector& output_values, - const OutputVector& input_values, - const PartialShape& gather_ps) - { - if (gather_ps.is_dynamic() || input_values.size() != 3) - { - return false; - } - - const auto concat = - std::dynamic_pointer_cast(input_values[0].get_node_shared_ptr()); - const auto indices = - std::dynamic_pointer_cast(input_values[1].get_node_shared_ptr()); - const auto axis = - std::dynamic_pointer_cast(input_values[2].get_node_shared_ptr()); - - if (!concat || !indices || !axis) - { - return false; - } - - // only along axis=0 - if (axis->cast_vector()[0] != 0 || concat->get_axis() != 0) - { - return false; - } - // only single indices are accepted - const auto indices_shape = indices->get_shape(); - if (indices_shape.size() > 1 || (indices_shape.size() == 1 && indices_shape[0] > 1)) - { - return false; - } - // concat inputs are 1D and their count is equal to Concat output shape - if (concat->get_output_partial_shape(0).is_dynamic()) - { - return false; - } - const auto concat_inputs = concat->inputs(); - // concat inputs must be single elements - if (concat_inputs.size() != shape_size(concat->get_shape())) - { - return false; - } - - const int64_t rank = concat->get_shape()[0]; - const int64_t raw_index = indices->cast_vector()[0]; - const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index; - NGRAPH_CHECK(positive_index >= 0 && positive_index < rank); - - // gather takes exactly one element out of the Concat output - const auto gathered_concat_input = - concat_inputs[positive_index].get_source_output().get_node_shared_ptr(); - // Concat inputs are 1D, resulting tensor shape depends on Gather indices - auto gathered = gathered_concat_input; - if (indices_shape.empty()) - { - // gathering a scalar - const auto axes = op::Constant::create(element::i64, Shape{1}, {0}); - gathered = make_shared(gathered_concat_input, axes); - } - - output_values[0] = gathered; - - return true; - } -} // namespace gather - -bool op::v1::Gather::evaluate_gather(const HostTensorVector& outputs, - const HostTensorVector& inputs) const -{ - int64_t axis = 0; - switch (inputs[2]->get_element_type()) - { - case element::Type_t::i8: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::i16: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::i32: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::i64: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::u8: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::u16: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::u32: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::u64: axis = inputs[2]->get_data_ptr()[0]; break; - default: throw ngraph_error("axis element type is not integral data type"); - } - - if (axis < 0) - { - const auto& input_rank = get_input_partial_shape(PARAMS).rank(); - if (input_rank.is_static()) - { - axis += input_rank.get_length(); - } - } - return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis); -} - -bool op::v1::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const -{ - NGRAPH_OP_SCOPE(v1_Gather_evaluate); - NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3)); - NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1)); - return evaluate_gather(outputs, inputs); -} - -bool op::v1::Gather::evaluate_lower(const HostTensorVector& output_values) const -{ - if (!input_value(INDICES).get_tensor().has_and_set_bound() || - !input_value(AXIS).get_tensor().has_and_set_bound()) - return false; - return default_lower_bound_evaluator(this, output_values); -} - -bool op::v1::Gather::evaluate_upper(const HostTensorVector& output_values) const -{ - if (!input_value(INDICES).get_tensor().has_and_set_bound() || - !input_value(AXIS).get_tensor().has_and_set_bound()) - return false; - return default_upper_bound_evaluator(this, output_values); -} - -bool op::v1::Gather::constant_fold(OutputVector& output_values, const OutputVector& input_values) -{ - // try the regular constant folding just for the Gather node - if (Node::constant_fold(output_values, input_values)) - { - return true; - } - else - { - return gather::cf_gather_with_subgraph( - output_values, input_values, get_output_partial_shape(0)); - } -} - -bool op::v7::Gather::evaluate_gather(const HostTensorVector& outputs, - const HostTensorVector& inputs) const -{ - int64_t axis = 0; - switch (inputs[2]->get_element_type()) - { - case element::Type_t::i32: axis = inputs[2]->get_data_ptr()[0]; break; - case element::Type_t::i64: axis = inputs[2]->get_data_ptr()[0]; break; - default: throw ngraph_error("axis must be of int32 or int64 type."); - } - - if (axis < 0) - { - const auto& input_rank = get_input_partial_shape(0).rank(); - if (input_rank.is_static()) - { - axis += input_rank.get_length(); - } - } - return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, get_batch_dims()); -} - -bool op::v7::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const -{ - NGRAPH_OP_SCOPE(v7_Gather_evaluate); - NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3)); - NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1)); - return evaluate_gather(outputs, inputs); -} - -bool op::v7::Gather::evaluate_lower(const HostTensorVector& output_values) const -{ - if (!input_value(1).get_tensor().has_and_set_bound() || - !input_value(2).get_tensor().has_and_set_bound()) - return false; - return default_lower_bound_evaluator(this, output_values); -} - -bool op::v7::Gather::evaluate_upper(const HostTensorVector& output_values) const -{ - if (!input_value(1).get_tensor().has_and_set_bound() || - !input_value(2).get_tensor().has_and_set_bound()) - return false; - return default_upper_bound_evaluator(this, output_values); -} - -bool op::v7::Gather::constant_fold(OutputVector& output_values, const OutputVector& input_values) -{ - // try the regular constant folding just for the Gather node - if (Node::constant_fold(output_values, input_values)) - { - return true; - } - else - { - return gather::cf_gather_with_subgraph( - output_values, input_values, get_output_partial_shape(0)); - } -} diff --git a/ngraph/core/src/op/util/gather_base.cpp b/ngraph/core/src/op/util/gather_base.cpp new file mode 100644 index 00000000000..efc5e476c53 --- /dev/null +++ b/ngraph/core/src/op/util/gather_base.cpp @@ -0,0 +1,393 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph/op/util/gather_base.hpp" +#include "itt.hpp" +#include "ngraph/op/concat.hpp" +#include "ngraph/op/constant.hpp" +#include "ngraph/op/squeeze.hpp" +#include "ngraph/runtime/host_tensor.hpp" +#include "ngraph/runtime/reference/gather.hpp" +#include "ngraph/shape.hpp" + +#include + +using namespace std; +using namespace ngraph; + +NGRAPH_RTTI_DEFINITION(op::util::GatherBase, "GatherBase", 7); + +op::util::GatherBase::GatherBase(const Output& data, + const Output& indices, + const Output& axis, + const int64_t batch_dims) + : Op({data, indices, axis}) + , m_batch_dims(batch_dims) +{ + constructor_validate_and_infer_types(); +} + +void op::util::GatherBase::validate_and_infer_types() +{ + NGRAPH_OP_SCOPE(util_GatherBase_validate_and_infer_types); + const auto& data_type = get_input_element_type(0); + + const auto& data_pshape = get_input_partial_shape(0); + const auto& indices_pshape = get_input_partial_shape(1); + const auto& axis_pshape = get_input_partial_shape(2); + auto data_rank = data_pshape.rank(); + auto indices_rank = indices_pshape.rank(); + auto axis_rank = axis_pshape.rank(); + + if (axis_rank.is_static() && axis_pshape.is_static()) + { + const auto axis_is_scalar = axis_rank.get_length() == 0; + const auto axis_has_one_elem = + axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1; + NODE_VALIDATION_CHECK( + this, + axis_is_scalar || axis_has_one_elem, + "Axis input must be scalar or have 1 element. But instead got axis_shape = ", + axis_pshape); + } + + int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set + if (is_axis_set()) + { + int64_t axis = get_axis(); + NODE_VALIDATION_CHECK(this, + batch_dims <= axis, + "The batch_dims <= axis. But instead got: batch_dims = ", + batch_dims, + ", axis = ", + axis); + + if (data_rank.is_static()) + { + NODE_VALIDATION_CHECK(this, + axis >= 0 && axis < data_rank.get_length(), + "The axis must be >= 0 and < data_rank. But instead got axis = ", + axis, + " data_rank = ", + data_rank.get_length()); + } + } + + if (indices_rank.is_static() && batch_dims >= 0) + { + NODE_VALIDATION_CHECK( + this, + batch_dims <= indices_rank.get_length(), + "The batch_dims must be <= indices_rank. But instead got: batch_dims = ", + batch_dims, + ", indices_rank = ", + indices_rank.get_length()); + } + + if (data_rank.is_static() && indices_rank.is_static()) + { + if (batch_dims >= 0) + { + auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims; + PartialShape output_pshape = PartialShape::dynamic(out_rank); + + // implementation of out_shape formula + // data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] + + // data.shape[axis + 1:] + int i = 0; + for (; i < batch_dims; i++) + { + NODE_VALIDATION_CHECK(this, + data_pshape[i].compatible(indices_pshape[i]), + "Shapes ", + data_pshape, + " and ", + indices_pshape, + " are not consistent. data and indices must have equal or " + "intersecting sizes until batch_dims"); + + output_pshape[i] = data_pshape[i] & indices_pshape[i]; + } + + if (is_axis_set()) + { + int64_t axis = get_axis(); + for (; i < axis; i++) + { + output_pshape[i] = data_pshape[i]; + } + for (; i < axis + indices_rank.get_length() - batch_dims; i++) + { + output_pshape[i] = indices_pshape[batch_dims - axis + i]; + } + for (; i < out_rank; i++) + { + output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i]; + } + } + + set_output_type(0, data_type, output_pshape); + } + else if (batch_dims < 0) + { + // batch_dims < 0 could be only if axis is not set + // as soon as axis value will arrive negative batch_dims should be resolved + // batch_dims value will be within [0, data_rank] && [0, indices_rank] + int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1; + int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length()); + + set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank))); + } + } + else + { + set_output_type(0, data_type, PartialShape::dynamic()); + } +} + +int64_t op::util::GatherBase::get_axis() const +{ + const auto& const_op = get_constant_from_source(input_value(2)); + if (!const_op) + throw ngraph_error("axis value is not set"); + + int64_t axis = const_op->cast_vector()[0]; + if (axis < 0) + { + const auto& data_rank = get_input_partial_shape(0).rank(); + if (data_rank.is_static()) + { + axis += data_rank.get_length(); + } + } + return axis; +} + +int64_t op::util::GatherBase::get_batch_dims() const +{ + if (m_batch_dims < 0 && is_axis_set()) + return get_axis() + m_batch_dims; + else + return m_batch_dims; +} + +bool op::util::GatherBase::is_axis_set() const +{ + const auto& axis_constant = get_constant_from_source(input_value(2)); + if (axis_constant) + return true; + else + return false; +} + +namespace gather +{ + template + bool evaluate(const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const HostTensorPtr& out, + size_t axis, + size_t batch_dims) + { + using T = typename element_type_traits::value_type; + Shape params_shape = arg0->get_shape(); + Shape indices_shape = arg1->get_shape(); + Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims); + uint64_t i = 0; + for (; i < axis; i++) + { + out_shape[i] = params_shape[i]; + } + for (uint64_t j = batch_dims; j < indices_shape.size(); i++, j++) + { + out_shape[i] = indices_shape[j]; + } + for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++) + { + out_shape[i] = params_shape[j]; + } + + out->set_shape(out_shape); + + if (arg1->get_element_type() == element::i64) + { + runtime::reference::gather(arg0->get_data_ptr(), + arg1->get_data_ptr(), + out->get_data_ptr(), + arg0->get_shape(), + arg1->get_shape(), + out->get_shape(), + axis, + batch_dims); + } + else if (arg1->get_element_type() == element::i32) + { + runtime::reference::gather(arg0->get_data_ptr(), + arg1->get_data_ptr(), + out->get_data_ptr(), + arg0->get_shape(), + arg1->get_shape(), + out->get_shape(), + axis, + batch_dims); + } + else + { + throw ngraph_error("Unexpected type"); + } + + return true; + } + + bool evaluate_gather(const HostTensorPtr& arg0, + const HostTensorPtr& arg1, + const HostTensorPtr& out, + size_t axis, + size_t batch_dims = 0) + { + bool rc = true; + + switch (out->get_element_type()) + { + NGRAPH_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims); + NGRAPH_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims); + NGRAPH_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims); + NGRAPH_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims); + NGRAPH_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims); + NGRAPH_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims); + NGRAPH_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims); + default: rc = false; break; + } + return rc; + } + + bool cf_gather_with_subgraph(OutputVector& output_values, + const OutputVector& input_values, + const PartialShape& gather_ps) + { + if (gather_ps.is_dynamic() || input_values.size() != 3) + { + return false; + } + + const auto concat = + std::dynamic_pointer_cast(input_values[0].get_node_shared_ptr()); + const auto indices = + std::dynamic_pointer_cast(input_values[1].get_node_shared_ptr()); + const auto axis = + std::dynamic_pointer_cast(input_values[2].get_node_shared_ptr()); + + if (!concat || !indices || !axis) + { + return false; + } + + // only along axis=0 + if (axis->cast_vector()[0] != 0 || concat->get_axis() != 0) + { + return false; + } + // only single indices are accepted + const auto indices_shape = indices->get_shape(); + if (indices_shape.size() > 1 || (indices_shape.size() == 1 && indices_shape[0] > 1)) + { + return false; + } + // concat inputs are 1D and their count is equal to Concat output shape + if (concat->get_output_partial_shape(0).is_dynamic()) + { + return false; + } + const auto concat_inputs = concat->inputs(); + // concat inputs must be single elements + if (concat_inputs.size() != shape_size(concat->get_shape())) + { + return false; + } + + const int64_t rank = concat->get_shape()[0]; + const int64_t raw_index = indices->cast_vector()[0]; + const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index; + NGRAPH_CHECK(positive_index >= 0 && positive_index < rank); + + // gather takes exactly one element out of the Concat output + const auto gathered_concat_input = + concat_inputs[positive_index].get_source_output().get_node_shared_ptr(); + // Concat inputs are 1D, resulting tensor shape depends on Gather indices + auto gathered = gathered_concat_input; + if (indices_shape.empty()) + { + // gathering a scalar + const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0}); + gathered = make_shared(gathered_concat_input, axis_const); + } + + output_values[0] = gathered; + + return true; + } +} // namespace gather + +bool op::util::GatherBase::evaluate(const HostTensorVector& outputs, + const HostTensorVector& inputs) const +{ + NGRAPH_OP_SCOPE(util_GatherBase_evaluate); + NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3)); + NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1)); + + int64_t axis = 0; + switch (inputs[2]->get_element_type()) + { + case element::Type_t::i32: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::i64: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::i8: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::i16: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::u8: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::u16: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::u32: axis = inputs[2]->get_data_ptr()[0]; break; + case element::Type_t::u64: axis = inputs[2]->get_data_ptr()[0]; break; + default: throw ngraph_error("axis must be of integral data type."); + } + + if (axis < 0) + { + const auto& input_rank = get_input_partial_shape(0).rank(); + if (input_rank.is_static()) + { + axis += input_rank.get_length(); + } + } + return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, get_batch_dims()); +} + +bool op::util::GatherBase::evaluate_lower(const HostTensorVector& output_values) const +{ + if (!input_value(1).get_tensor().has_and_set_bound() || + !input_value(2).get_tensor().has_and_set_bound()) + return false; + return default_lower_bound_evaluator(this, output_values); +} + +bool op::util::GatherBase::evaluate_upper(const HostTensorVector& output_values) const +{ + if (!input_value(1).get_tensor().has_and_set_bound() || + !input_value(2).get_tensor().has_and_set_bound()) + return false; + return default_upper_bound_evaluator(this, output_values); +} + +bool op::util::GatherBase::constant_fold(OutputVector& output_values, + const OutputVector& input_values) +{ + // try the regular constant folding just for the Gather node + if (Node::constant_fold(output_values, input_values)) + { + return true; + } + else + { + return gather::cf_gather_with_subgraph( + output_values, input_values, get_output_partial_shape(0)); + } +} diff --git a/ngraph/test/type_prop/gather.cpp b/ngraph/test/type_prop/gather.cpp index e74b809f416..35bd4215e7d 100644 --- a/ngraph/test/type_prop/gather.cpp +++ b/ngraph/test/type_prop/gather.cpp @@ -27,6 +27,40 @@ TEST(type_prop, gather_axis_0) ASSERT_EQ(G->get_axis(), 0); } +TEST(type_prop, gather_7_uint8) +{ + // Gather_1 must allow even if indices is not int32/int64 + PartialShape data_shape{3, 2}; + PartialShape indices_shape{2, 2}; + PartialShape out_shape{2, 2, 2}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::u8, indices_shape); + auto A = op::Constant::create(element::i64, Shape{}, {0}); + auto G = make_shared(D, I, A); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); + ASSERT_EQ(G->get_axis(), 0); +} + +TEST(type_prop, gather_7_float32) +{ + // Gather_1 should allow non int32/int64 indices + PartialShape data_shape{3, 2}; + PartialShape indices_shape{2, 2}; + PartialShape out_shape{2, 2, 2}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::f32, indices_shape); + auto A = op::Constant::create(element::i64, Shape{}, {0}); + auto G = make_shared(D, I, A); + + ASSERT_EQ(G->get_element_type(), element::f32); + ASSERT_EQ(G->get_output_partial_shape(0), out_shape); + ASSERT_EQ(G->get_axis(), 0); +} + TEST(type_prop, gather_axis_1) { Shape params_shape{3, 3}; @@ -55,7 +89,7 @@ TEST(type_prop, gather_v1_incorrect_axis_shape) catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), - std::string("Axes input must be scalar or have 1 element (shape:")); + std::string("Axis input must be scalar or have 1 element")); } catch (...) { @@ -77,7 +111,7 @@ TEST(type_prop, gather_v1_axis_out_of_input_rank) catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), - std::string("The axis must => 0 and <= input_rank (axis:")); + std::string("The axis must be >= 0 and < data_rank. But instead got axis")); } catch (...) { @@ -241,7 +275,7 @@ TEST(type_prop, gather_7_axis_not_set) auto D = make_shared(element::f32, data_shape); auto I = make_shared(element::i64, indices_shape); - auto A = make_shared(element::f32, Shape{1}); + auto A = make_shared(element::i32, Shape{1}); auto G = make_shared(D, I, A); ASSERT_EQ(G->get_element_type(), element::f32); @@ -260,7 +294,7 @@ TEST(type_prop, gather_7_axis_not_set_positive_batch_dims) auto D = make_shared(element::f32, data_shape); auto I = make_shared(element::i64, indices_shape); - auto A = make_shared(element::f32, Shape{1}); + auto A = make_shared(element::i32, Shape{1}); auto G = make_shared(D, I, A, batch_dims); ASSERT_EQ(G->get_element_type(), element::f32); @@ -279,7 +313,7 @@ TEST(type_prop, gather_7_axis_not_set_negative_batch) auto D = make_shared(element::f32, data_shape); auto I = make_shared(element::i64, indices_shape); - auto A = make_shared(element::f32, Shape{1}); + auto A = make_shared(element::i32, Shape{1}); auto G = make_shared(D, I, A, batch_dims); ASSERT_EQ(G->get_element_type(), element::f32); @@ -303,7 +337,7 @@ TEST(type_prop, gather_7_incorrect_axis_shape) catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING(error.what(), - std::string("Axes input must be scalar or have 1 element")); + std::string("Axis input must be scalar or have 1 element")); } catch (...) { @@ -326,7 +360,7 @@ TEST(type_prop, gather_7_axis_out_of_input_rank) catch (const NodeValidationFailure& error) { EXPECT_HAS_SUBSTRING( - error.what(), std::string("The axis must be => 0 and < data_rank. But instead got")); + error.what(), std::string("The axis must be >= 0 and < data_rank. But instead got")); } catch (...) { @@ -420,3 +454,63 @@ TEST(type_prop, gather_7_batch_dims_less_indices_rank_check) FAIL() << "Deduced type check failed for unexpected reason"; } } + +// disabled until decision of type constrains for gather +TEST(type_prop, DISABLED_gather_7_indices_type_check) +{ + PartialShape data_shape{1, 20, 20, 22, 22}; + PartialShape indices_shape{1, 3}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::f32, indices_shape); + int64_t axis = 4; + auto A = make_shared(element::i64, Shape{1}, vector{axis}); + int64_t batch_dims = 0; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "indices element_type check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Indices element type must be of an integral number type")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +} + +// disabled until decision of type constrains for gather +TEST(type_prop, DISABLED_gather_7_axis_type_check) +{ + PartialShape data_shape{1, 20, 20, 22, 22}; + PartialShape indices_shape{1, 3}; + + auto D = make_shared(element::f32, data_shape); + auto I = make_shared(element::i32, indices_shape); + int64_t axis = 4; + auto A = make_shared(element::f32, Shape{1}, vector{axis}); + int64_t batch_dims = 0; + + try + { + auto G = make_shared(D, I, A, batch_dims); + // Should have thrown, so fail if it didn't + FAIL() << "axis element_type check failed"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Axis element type must be of an integral number type")); + } + catch (...) + { + FAIL() << "Deduced type check failed for unexpected reason"; + } +}