[nGraph] Gather v7 v1 up down transformations (#5118)
* v7->v1 works fine
* gather v1->v7 works fine
* changed op::v7 -> opset7, op:v1 -> opset1
* bumped to gather7 in all transformations
* applied review comments
* fixed pre-commit failures
* disabled incorrect unit-test, fixed f32->i32
* added comments why AddConvertToReorderTest was disabled
* Revert "bumped to gather7 in all transformations"
This reverts commit 965dc295
* fixed typos in v1->v7, v7->v1
* added GatherBase, redefined pattern in eliminate_unsqueeze gather, turned on Gather downgrading transformation
* resolved conflicts and build errors
* 🚀 finally EliminateUnsqueezeGather works: added inheritance of RTTI info from GatherBase
* fixed pre-commit failurer
* removed redundant debug code
* reverted f32 for Gather-1 indices in unit-tests and transformations
* relaxed restrictions for indices and axis type for v1::Gather
* corrected op scope
* moved type_check to validation_util.cpp
* removed type checks from Gather, removed upgrading transformation
* applied review coments
* fixed minor typos
This commit is contained in:
parent
f2306adf98
commit
41bdec12df
@ -0,0 +1,27 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertGather7ToGather1;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief ConvertGather7ToGather1 covert v7::Gather into v1::Gather.
|
||||
*/
|
||||
class ngraph::pass::ConvertGather7ToGather1 : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertGather7ToGather1();
|
||||
};
|
@ -37,12 +37,12 @@ static bool simplify_gather(std::shared_ptr<Node> node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis = gather->get_axis();
|
||||
|
||||
if (axis == opset3::Gather::AXIS_NOT_SET_VALUE) {
|
||||
if (!gather->is_axis_set()) {
|
||||
NGRAPH_DEBUG << "axis value not set";
|
||||
return false;
|
||||
}
|
||||
auto axis = gather->get_axis();
|
||||
|
||||
// case_1 : if the input tensor is of shape (4, 1, 4)
|
||||
// and axis = 1, then the gather would be simply
|
||||
@ -85,12 +85,12 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node) {
|
||||
}
|
||||
auto gather_in_rank = gather->get_input_partial_shape(0).rank();
|
||||
auto indices_rank = gather->get_input_partial_shape(1).rank();
|
||||
auto axis = gather->get_axis();
|
||||
if (gather_in_rank.is_dynamic() || indices_rank.is_dynamic() ||
|
||||
axis == opset3::Gather::AXIS_NOT_SET_VALUE) {
|
||||
!gather->is_axis_set()) {
|
||||
NGRAPH_DEBUG << gather << " cannot simplify gather->shapeof";
|
||||
return false;
|
||||
}
|
||||
auto axis = gather->get_axis();
|
||||
|
||||
auto zero_axis = opset3::Constant::create<int64_t>(element::i64, Shape{}, {0});
|
||||
NodeVector new_ops;
|
||||
|
@ -42,6 +42,7 @@
|
||||
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
|
||||
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
|
||||
#include "transformations/op_conversions/convert_divide.hpp"
|
||||
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
|
||||
#include "transformations/op_conversions/convert_mod.hpp"
|
||||
#include "transformations/op_conversions/convert_minimum_to_power_and_max.hpp"
|
||||
#include "transformations/op_conversions/convert_negative.hpp"
|
||||
@ -156,6 +157,8 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
|
||||
conv_fusions->set_name("ngraph::pass::ConvFusions");
|
||||
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
// need to convert to Gather-1 until plugins do not support Gather-7
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fq_fusions->add_matcher<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
|
@ -19,7 +19,7 @@ ngraph::pass::EliminateUnsqueezeGather::EliminateUnsqueezeGather() {
|
||||
const auto unsqueeze = ngraph::pattern::wrap_type<ngraph::opset6::Unsqueeze>({unsqueezeInput, unsqueezeAxis}, pattern::consumers_count(1));
|
||||
const auto gatherIndices = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {0});
|
||||
const auto gatherAxis = ngraph::pattern::any_input();
|
||||
const auto gather = ngraph::pattern::wrap_type<ngraph::opset6::Gather>({unsqueeze, gatherIndices, gatherAxis});
|
||||
const auto gather = ngraph::pattern::wrap_type<ngraph::op::util::GatherBase>({unsqueeze, gatherIndices, gatherAxis});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& patternValue = m.get_pattern_value_map();
|
||||
|
@ -0,0 +1,43 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp"
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertGather7ToGather1, "ConvertGather7ToGather1", 0);
|
||||
|
||||
ngraph::pass::ConvertGather7ToGather1::ConvertGather7ToGather1() {
|
||||
MATCHER_SCOPE(ConvertGather7ToGather1);
|
||||
|
||||
auto gather_v7 = pattern::wrap_type<ngraph::opset7::Gather>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto gather_v7_node = std::dynamic_pointer_cast<ngraph::opset7::Gather>(m.get_match_root());
|
||||
if (!gather_v7_node)
|
||||
return false;
|
||||
|
||||
int64_t batch_dims = 0;
|
||||
if (gather_v7_node->is_axis_set())
|
||||
batch_dims = gather_v7_node->get_batch_dims();
|
||||
if (batch_dims != 0)
|
||||
return false;
|
||||
auto data_input = gather_v7_node->input_value(0);
|
||||
auto indices_input = gather_v7_node->input_value(1);
|
||||
auto axis_input = gather_v7_node->input_value(2);
|
||||
|
||||
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data_input, indices_input, axis_input);
|
||||
gather_v1->set_friendly_name(gather_v7_node->get_friendly_name());
|
||||
ngraph::copy_runtime_info(gather_v7_node, gather_v1);
|
||||
ngraph::replace_node(gather_v7_node, gather_v1);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pattern::Matcher>(gather_v7, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/op_conversions/convert_gather_v7_to_gather_v1.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, ConvertGather7toGather1) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v7 = std::make_shared<ngraph::opset7::Gather>(data, indices, axis, 0);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v7}, ngraph::ParameterVector{data, indices});
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGather7ToGather1>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2, 3});
|
||||
auto indices = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i32, ngraph::Shape{2, 2});
|
||||
auto axis = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{1}, {0});
|
||||
|
||||
auto gather_v1 = std::make_shared<ngraph::opset1::Gather>(data, indices, axis);
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather_v1}, ngraph::ParameterVector{data, indices});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
@ -6,7 +6,6 @@
|
||||
|
||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
||||
#include <common_test_utils/test_common.hpp>
|
||||
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/gather_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -13,12 +13,10 @@ namespace ngraph
|
||||
namespace v1
|
||||
{
|
||||
/// \brief Gather slices from axis of params according to indices
|
||||
class NGRAPH_API Gather : public Op
|
||||
class NGRAPH_API Gather : public op::util::GatherBase
|
||||
{
|
||||
public:
|
||||
static const int64_t AXIS_NOT_SET_VALUE = std::numeric_limits<int64_t>::max();
|
||||
static constexpr NodeTypeInfo type_info{"Gather", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Gather() = default;
|
||||
/// \param params The tensor from which slices are gathered
|
||||
/// \param indices Tensor with indexes to gather
|
||||
@ -28,35 +26,16 @@ namespace ngraph
|
||||
const Output<Node>& axis);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t get_axis() const;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
private:
|
||||
static const int PARAMS;
|
||||
static const int INDICES;
|
||||
static const int AXIS;
|
||||
|
||||
bool evaluate_gather(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const;
|
||||
};
|
||||
} // namespace v1
|
||||
|
||||
namespace v7
|
||||
{
|
||||
/// \brief Gather slices from axis of params according to indices
|
||||
class NGRAPH_API Gather : public Op
|
||||
class NGRAPH_API Gather : public op::util::GatherBase
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
@ -76,23 +55,6 @@ namespace ngraph
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
int64_t get_batch_dims() const;
|
||||
int64_t get_axis() const;
|
||||
bool is_axis_set() const;
|
||||
|
||||
bool evaluate_gather(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const;
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
private:
|
||||
int64_t m_batch_dims = 0;
|
||||
};
|
||||
} // namespace v7
|
||||
} // namespace op
|
||||
|
50
ngraph/core/include/ngraph/op/util/gather_base.hpp
Normal file
50
ngraph/core/include/ngraph/op/util/gather_base.hpp
Normal file
@ -0,0 +1,50 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace util
|
||||
{
|
||||
/// \brief GatherBase basic class for Gather v1 and v7
|
||||
class NGRAPH_API GatherBase : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
GatherBase() = default;
|
||||
|
||||
/// \param data The tensor from which slices are gathered
|
||||
/// \param indices Tensor with indexes to gather
|
||||
/// \param axis The tensor is a dimension index to gather data from
|
||||
/// \param batch_dims The number of batch dimension in data and indices tensors
|
||||
GatherBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims = 0);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
int64_t get_batch_dims() const;
|
||||
int64_t get_axis() const;
|
||||
bool is_axis_set() const;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
|
||||
bool evaluate_lower(const HostTensorVector& outputs) const override;
|
||||
bool evaluate_upper(const HostTensorVector& outputs) const override;
|
||||
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
protected:
|
||||
int64_t m_batch_dims = 0;
|
||||
};
|
||||
} // namespace utils
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
@ -3,30 +3,19 @@
|
||||
//
|
||||
|
||||
#include "ngraph/op/gather.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/squeeze.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/gather.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::Gather::type_info;
|
||||
const int64_t op::v1::Gather::AXIS_NOT_SET_VALUE;
|
||||
|
||||
const int op::v1::Gather::PARAMS = 0;
|
||||
const int op::v1::Gather::INDICES = 1;
|
||||
const int op::v1::Gather::AXIS = 2;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::Gather, "Gather", 1, op::util::GatherBase);
|
||||
|
||||
op::v1::Gather::Gather(const Output<Node>& params,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axes)
|
||||
: Op({params, indices, axes})
|
||||
: GatherBase(params, indices, axes)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
@ -37,107 +26,38 @@ bool ngraph::op::v1::Gather::visit_attributes(AttributeVisitor& visitor)
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v1::Gather::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_Gather_validate_and_infer_types);
|
||||
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
|
||||
const auto& axis_shape = get_input_partial_shape(AXIS);
|
||||
const auto& axis_rank = axis_shape.rank();
|
||||
|
||||
if (axis_rank.is_static() && axis_shape.is_static())
|
||||
{
|
||||
const auto axis_is_scalar = axis_rank.get_length() == 0;
|
||||
const auto axis_has_one_elem =
|
||||
axis_rank.get_length() == 1 && axis_shape[0].get_length() == 1;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis_is_scalar || axis_has_one_elem,
|
||||
"Axes input must be scalar or have 1 element (shape: ",
|
||||
axis_shape,
|
||||
").");
|
||||
}
|
||||
|
||||
int64_t axis = get_axis();
|
||||
if (input_rank.is_static() && axis != AXIS_NOT_SET_VALUE)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis < input_rank.get_length(),
|
||||
"The axis must => 0 and <= input_rank (axis: ",
|
||||
axis,
|
||||
").");
|
||||
}
|
||||
|
||||
element::Type result_et = get_input_element_type(PARAMS);
|
||||
|
||||
const PartialShape& params_shape = get_input_partial_shape(PARAMS);
|
||||
const PartialShape& indices_shape = get_input_partial_shape(INDICES);
|
||||
|
||||
PartialShape result_shape;
|
||||
if (params_shape.rank().is_static() && indices_shape.rank().is_static() &&
|
||||
axis != AXIS_NOT_SET_VALUE)
|
||||
{
|
||||
std::vector<Dimension> result_dims(params_shape.rank().get_length() +
|
||||
indices_shape.rank().get_length() - 1);
|
||||
int64_t i = 0;
|
||||
for (; i < axis; i++)
|
||||
{
|
||||
result_dims[i] = params_shape[i];
|
||||
}
|
||||
for (int64_t j = 0; j < indices_shape.rank().get_length(); i++, j++)
|
||||
{
|
||||
result_dims[i] = indices_shape[j];
|
||||
}
|
||||
for (int64_t j = axis + 1; j < params_shape.rank().get_length(); i++, j++)
|
||||
{
|
||||
result_dims[i] = params_shape[j];
|
||||
}
|
||||
|
||||
result_shape = PartialShape(result_dims);
|
||||
}
|
||||
else
|
||||
{
|
||||
result_shape = PartialShape::dynamic();
|
||||
}
|
||||
|
||||
set_output_type(0, result_et, result_shape);
|
||||
}
|
||||
|
||||
int64_t op::v1::Gather::get_axis() const
|
||||
{
|
||||
int64_t axis = AXIS_NOT_SET_VALUE;
|
||||
if (const auto& const_op = get_constant_from_source(input_value(AXIS)))
|
||||
{
|
||||
axis = const_op->cast_vector<int64_t>()[0];
|
||||
}
|
||||
if (axis < 0)
|
||||
{
|
||||
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
|
||||
if (input_rank.is_static())
|
||||
{
|
||||
axis += input_rank.get_length();
|
||||
}
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Gather::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_Gather_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v1::Gather>(new_args.at(PARAMS), new_args.at(INDICES), new_args.at(AXIS));
|
||||
return make_shared<v1::Gather>(new_args.at(0), new_args.at(1), new_args.at(2));
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7);
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::Gather, "Gather", 7, op::util::GatherBase);
|
||||
|
||||
op::v7::Gather::Gather(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
: Op({data, indices, axis})
|
||||
, m_batch_dims(batch_dims)
|
||||
: GatherBase(data, indices, axis, batch_dims)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v7::Gather::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(1).is_integral_number(),
|
||||
"Indices element type must be of an integral number type.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_input_element_type(2).is_integral_number(),
|
||||
"Axis element type must be of an integral number type.");
|
||||
|
||||
op::util::GatherBase::validate_and_infer_types();
|
||||
}
|
||||
|
||||
bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Gather_visit_attributes);
|
||||
@ -145,440 +65,9 @@ bool ngraph::op::v7::Gather::visit_attributes(AttributeVisitor& visitor)
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v7::Gather::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Gather_validate_and_infer_types);
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
const auto& indices_type = get_input_element_type(1);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
indices_type == element::Type_t::i32 ||
|
||||
indices_type == element::Type_t::i64,
|
||||
"indices must be of int32 or int64 type. But instead got: ",
|
||||
indices_type);
|
||||
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& indices_pshape = get_input_partial_shape(1);
|
||||
const auto& axis_pshape = get_input_partial_shape(2);
|
||||
auto data_rank = data_pshape.rank();
|
||||
auto indices_rank = indices_pshape.rank();
|
||||
auto axis_rank = axis_pshape.rank();
|
||||
|
||||
if (axis_rank.is_static() && axis_pshape.is_static())
|
||||
{
|
||||
const auto axis_is_scalar = axis_rank.get_length() == 0;
|
||||
const auto axis_has_one_elem =
|
||||
axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1;
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
axis_is_scalar || axis_has_one_elem,
|
||||
"Axes input must be scalar or have 1 element. But instead got axis_shape = ",
|
||||
axis_pshape);
|
||||
}
|
||||
|
||||
int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set
|
||||
if (is_axis_set())
|
||||
{
|
||||
int64_t axis = get_axis();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
batch_dims <= axis,
|
||||
"The batch_dims <= axis. But instead got: batch_dims = ",
|
||||
batch_dims,
|
||||
", axis = ",
|
||||
axis);
|
||||
|
||||
if (data_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis >= 0 && axis < data_rank.get_length(),
|
||||
"The axis must be => 0 and < data_rank. But instead got axis = ",
|
||||
axis,
|
||||
" data_rank = ",
|
||||
data_rank.get_length());
|
||||
}
|
||||
}
|
||||
|
||||
if (indices_rank.is_static() && batch_dims >= 0)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
batch_dims <= indices_rank.get_length(),
|
||||
"The batch_dims must be <= indices_rank. But instead got: batch_dims = ",
|
||||
batch_dims,
|
||||
", indices_rank = ",
|
||||
indices_rank.get_length());
|
||||
}
|
||||
|
||||
if (data_rank.is_static() && indices_rank.is_static())
|
||||
{
|
||||
if (batch_dims >= 0)
|
||||
{
|
||||
auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims;
|
||||
PartialShape output_pshape = PartialShape::dynamic(out_rank);
|
||||
|
||||
// implementation of out_shape formula
|
||||
// data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] +
|
||||
// data.shape[axis + 1:]
|
||||
int i = 0;
|
||||
for (; i < batch_dims; i++)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_pshape[i].compatible(indices_pshape[i]),
|
||||
"Shapes ",
|
||||
data_pshape,
|
||||
" and ",
|
||||
indices_pshape,
|
||||
" are not consistent. data and indices must have equal or "
|
||||
"intersecting sizes until batch_dims");
|
||||
|
||||
output_pshape[i] = data_pshape[i] & indices_pshape[i];
|
||||
}
|
||||
|
||||
if (is_axis_set())
|
||||
{
|
||||
int64_t axis = get_axis();
|
||||
for (; i < axis; i++)
|
||||
{
|
||||
output_pshape[i] = data_pshape[i];
|
||||
}
|
||||
for (; i < axis + indices_rank.get_length() - batch_dims; i++)
|
||||
{
|
||||
output_pshape[i] = indices_pshape[batch_dims - axis + i];
|
||||
}
|
||||
for (; i < out_rank; i++)
|
||||
{
|
||||
output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i];
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, data_type, output_pshape);
|
||||
}
|
||||
else if (batch_dims < 0)
|
||||
{
|
||||
// batch_dims < 0 could be only if axis is not set
|
||||
// as soon as axis value will arrive negative batch_dims should be resolved
|
||||
// batch_dims value will be within [0, data_rank] && [0, indices_rank]
|
||||
int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1;
|
||||
int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length());
|
||||
|
||||
set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank)));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
set_output_type(0, data_type, PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
int64_t op::v7::Gather::get_axis() const
|
||||
{
|
||||
const auto& const_op = get_constant_from_source(input_value(2));
|
||||
int64_t axis = const_op->cast_vector<int64_t>()[0];
|
||||
if (axis < 0)
|
||||
{
|
||||
const auto& data_rank = get_input_partial_shape(0).rank();
|
||||
if (data_rank.is_static())
|
||||
{
|
||||
axis += data_rank.get_length();
|
||||
}
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
|
||||
int64_t op::v7::Gather::get_batch_dims() const
|
||||
{
|
||||
if (m_batch_dims < 0 && is_axis_set())
|
||||
return get_axis() + m_batch_dims;
|
||||
else
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
bool op::v7::Gather::is_axis_set() const
|
||||
{
|
||||
const auto& axes_constant = get_constant_from_source(input_value(2));
|
||||
if (axes_constant)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v7::Gather::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Gather_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v7::Gather>(new_args.at(0), new_args.at(1), new_args.at(2), m_batch_dims);
|
||||
}
|
||||
|
||||
namespace gather
|
||||
{
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& out,
|
||||
size_t axis,
|
||||
size_t batch_dims)
|
||||
{
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
Shape params_shape = arg0->get_shape();
|
||||
Shape indices_shape = arg1->get_shape();
|
||||
Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims);
|
||||
uint64_t i = 0;
|
||||
for (; i < axis; i++)
|
||||
{
|
||||
out_shape[i] = params_shape[i];
|
||||
}
|
||||
for (uint64_t j = batch_dims; j < indices_shape.size(); i++, j++)
|
||||
{
|
||||
out_shape[i] = indices_shape[j];
|
||||
}
|
||||
for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++)
|
||||
{
|
||||
out_shape[i] = params_shape[j];
|
||||
}
|
||||
|
||||
out->set_shape(out_shape);
|
||||
|
||||
if (arg1->get_element_type() == element::i64)
|
||||
{
|
||||
runtime::reference::gather<T, int64_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int64_t>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
out->get_shape(),
|
||||
axis,
|
||||
batch_dims);
|
||||
}
|
||||
else if (arg1->get_element_type() == element::i32)
|
||||
{
|
||||
runtime::reference::gather<T, int32_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int32_t>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
out->get_shape(),
|
||||
axis,
|
||||
batch_dims);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Unexpected type");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_gather(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& out,
|
||||
size_t axis,
|
||||
size_t batch_dims = 0)
|
||||
{
|
||||
bool rc = true;
|
||||
|
||||
switch (out->get_element_type())
|
||||
{
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims);
|
||||
default: rc = false; break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool cf_gather_with_subgraph(OutputVector& output_values,
|
||||
const OutputVector& input_values,
|
||||
const PartialShape& gather_ps)
|
||||
{
|
||||
if (gather_ps.is_dynamic() || input_values.size() != 3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto concat =
|
||||
std::dynamic_pointer_cast<op::Concat>(input_values[0].get_node_shared_ptr());
|
||||
const auto indices =
|
||||
std::dynamic_pointer_cast<op::Constant>(input_values[1].get_node_shared_ptr());
|
||||
const auto axis =
|
||||
std::dynamic_pointer_cast<op::Constant>(input_values[2].get_node_shared_ptr());
|
||||
|
||||
if (!concat || !indices || !axis)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// only along axis=0
|
||||
if (axis->cast_vector<int64_t>()[0] != 0 || concat->get_axis() != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// only single indices are accepted
|
||||
const auto indices_shape = indices->get_shape();
|
||||
if (indices_shape.size() > 1 || (indices_shape.size() == 1 && indices_shape[0] > 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// concat inputs are 1D and their count is equal to Concat output shape
|
||||
if (concat->get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
const auto concat_inputs = concat->inputs();
|
||||
// concat inputs must be single elements
|
||||
if (concat_inputs.size() != shape_size(concat->get_shape()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t rank = concat->get_shape()[0];
|
||||
const int64_t raw_index = indices->cast_vector<int64_t>()[0];
|
||||
const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index;
|
||||
NGRAPH_CHECK(positive_index >= 0 && positive_index < rank);
|
||||
|
||||
// gather takes exactly one element out of the Concat output
|
||||
const auto gathered_concat_input =
|
||||
concat_inputs[positive_index].get_source_output().get_node_shared_ptr();
|
||||
// Concat inputs are 1D, resulting tensor shape depends on Gather indices
|
||||
auto gathered = gathered_concat_input;
|
||||
if (indices_shape.empty())
|
||||
{
|
||||
// gathering a scalar
|
||||
const auto axes = op::Constant::create(element::i64, Shape{1}, {0});
|
||||
gathered = make_shared<op::v0::Squeeze>(gathered_concat_input, axes);
|
||||
}
|
||||
|
||||
output_values[0] = gathered;
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace gather
|
||||
|
||||
bool op::v1::Gather::evaluate_gather(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const
|
||||
{
|
||||
int64_t axis = 0;
|
||||
switch (inputs[2]->get_element_type())
|
||||
{
|
||||
case element::Type_t::i8: axis = inputs[2]->get_data_ptr<element::Type_t::i8>()[0]; break;
|
||||
case element::Type_t::i16: axis = inputs[2]->get_data_ptr<element::Type_t::i16>()[0]; break;
|
||||
case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
|
||||
case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
|
||||
case element::Type_t::u8: axis = inputs[2]->get_data_ptr<element::Type_t::u8>()[0]; break;
|
||||
case element::Type_t::u16: axis = inputs[2]->get_data_ptr<element::Type_t::u16>()[0]; break;
|
||||
case element::Type_t::u32: axis = inputs[2]->get_data_ptr<element::Type_t::u32>()[0]; break;
|
||||
case element::Type_t::u64: axis = inputs[2]->get_data_ptr<element::Type_t::u64>()[0]; break;
|
||||
default: throw ngraph_error("axis element type is not integral data type");
|
||||
}
|
||||
|
||||
if (axis < 0)
|
||||
{
|
||||
const auto& input_rank = get_input_partial_shape(PARAMS).rank();
|
||||
if (input_rank.is_static())
|
||||
{
|
||||
axis += input_rank.get_length();
|
||||
}
|
||||
}
|
||||
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis);
|
||||
}
|
||||
|
||||
bool op::v1::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v1_Gather_evaluate);
|
||||
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3));
|
||||
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
|
||||
return evaluate_gather(outputs, inputs);
|
||||
}
|
||||
|
||||
bool op::v1::Gather::evaluate_lower(const HostTensorVector& output_values) const
|
||||
{
|
||||
if (!input_value(INDICES).get_tensor().has_and_set_bound() ||
|
||||
!input_value(AXIS).get_tensor().has_and_set_bound())
|
||||
return false;
|
||||
return default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::Gather::evaluate_upper(const HostTensorVector& output_values) const
|
||||
{
|
||||
if (!input_value(INDICES).get_tensor().has_and_set_bound() ||
|
||||
!input_value(AXIS).get_tensor().has_and_set_bound())
|
||||
return false;
|
||||
return default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v1::Gather::constant_fold(OutputVector& output_values, const OutputVector& input_values)
|
||||
{
|
||||
// try the regular constant folding just for the Gather node
|
||||
if (Node::constant_fold(output_values, input_values))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return gather::cf_gather_with_subgraph(
|
||||
output_values, input_values, get_output_partial_shape(0));
|
||||
}
|
||||
}
|
||||
|
||||
bool op::v7::Gather::evaluate_gather(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const
|
||||
{
|
||||
int64_t axis = 0;
|
||||
switch (inputs[2]->get_element_type())
|
||||
{
|
||||
case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
|
||||
case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
|
||||
default: throw ngraph_error("axis must be of int32 or int64 type.");
|
||||
}
|
||||
|
||||
if (axis < 0)
|
||||
{
|
||||
const auto& input_rank = get_input_partial_shape(0).rank();
|
||||
if (input_rank.is_static())
|
||||
{
|
||||
axis += input_rank.get_length();
|
||||
}
|
||||
}
|
||||
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, get_batch_dims());
|
||||
}
|
||||
|
||||
bool op::v7::Gather::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Gather_evaluate);
|
||||
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3));
|
||||
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
|
||||
return evaluate_gather(outputs, inputs);
|
||||
}
|
||||
|
||||
bool op::v7::Gather::evaluate_lower(const HostTensorVector& output_values) const
|
||||
{
|
||||
if (!input_value(1).get_tensor().has_and_set_bound() ||
|
||||
!input_value(2).get_tensor().has_and_set_bound())
|
||||
return false;
|
||||
return default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v7::Gather::evaluate_upper(const HostTensorVector& output_values) const
|
||||
{
|
||||
if (!input_value(1).get_tensor().has_and_set_bound() ||
|
||||
!input_value(2).get_tensor().has_and_set_bound())
|
||||
return false;
|
||||
return default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::v7::Gather::constant_fold(OutputVector& output_values, const OutputVector& input_values)
|
||||
{
|
||||
// try the regular constant folding just for the Gather node
|
||||
if (Node::constant_fold(output_values, input_values))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return gather::cf_gather_with_subgraph(
|
||||
output_values, input_values, get_output_partial_shape(0));
|
||||
}
|
||||
}
|
||||
|
393
ngraph/core/src/op/util/gather_base.cpp
Normal file
393
ngraph/core/src/op/util/gather_base.cpp
Normal file
@ -0,0 +1,393 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/op/util/gather_base.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/squeeze.hpp"
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/gather.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::util::GatherBase, "GatherBase", 7);
|
||||
|
||||
op::util::GatherBase::GatherBase(const Output<Node>& data,
|
||||
const Output<Node>& indices,
|
||||
const Output<Node>& axis,
|
||||
const int64_t batch_dims)
|
||||
: Op({data, indices, axis})
|
||||
, m_batch_dims(batch_dims)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::util::GatherBase::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(util_GatherBase_validate_and_infer_types);
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
|
||||
const auto& data_pshape = get_input_partial_shape(0);
|
||||
const auto& indices_pshape = get_input_partial_shape(1);
|
||||
const auto& axis_pshape = get_input_partial_shape(2);
|
||||
auto data_rank = data_pshape.rank();
|
||||
auto indices_rank = indices_pshape.rank();
|
||||
auto axis_rank = axis_pshape.rank();
|
||||
|
||||
if (axis_rank.is_static() && axis_pshape.is_static())
|
||||
{
|
||||
const auto axis_is_scalar = axis_rank.get_length() == 0;
|
||||
const auto axis_has_one_elem =
|
||||
axis_rank.get_length() == 1 && axis_pshape[0].get_length() == 1;
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
axis_is_scalar || axis_has_one_elem,
|
||||
"Axis input must be scalar or have 1 element. But instead got axis_shape = ",
|
||||
axis_pshape);
|
||||
}
|
||||
|
||||
int64_t batch_dims = get_batch_dims(); // will not be converted to positive if axis is not set
|
||||
if (is_axis_set())
|
||||
{
|
||||
int64_t axis = get_axis();
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
batch_dims <= axis,
|
||||
"The batch_dims <= axis. But instead got: batch_dims = ",
|
||||
batch_dims,
|
||||
", axis = ",
|
||||
axis);
|
||||
|
||||
if (data_rank.is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axis >= 0 && axis < data_rank.get_length(),
|
||||
"The axis must be >= 0 and < data_rank. But instead got axis = ",
|
||||
axis,
|
||||
" data_rank = ",
|
||||
data_rank.get_length());
|
||||
}
|
||||
}
|
||||
|
||||
if (indices_rank.is_static() && batch_dims >= 0)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
batch_dims <= indices_rank.get_length(),
|
||||
"The batch_dims must be <= indices_rank. But instead got: batch_dims = ",
|
||||
batch_dims,
|
||||
", indices_rank = ",
|
||||
indices_rank.get_length());
|
||||
}
|
||||
|
||||
if (data_rank.is_static() && indices_rank.is_static())
|
||||
{
|
||||
if (batch_dims >= 0)
|
||||
{
|
||||
auto out_rank = data_rank.get_length() + indices_rank.get_length() - 1 - batch_dims;
|
||||
PartialShape output_pshape = PartialShape::dynamic(out_rank);
|
||||
|
||||
// implementation of out_shape formula
|
||||
// data.shape[:batch_dims] + data.shape[batch_dims:axis] + indices.shape[batch_dims:] +
|
||||
// data.shape[axis + 1:]
|
||||
int i = 0;
|
||||
for (; i < batch_dims; i++)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
data_pshape[i].compatible(indices_pshape[i]),
|
||||
"Shapes ",
|
||||
data_pshape,
|
||||
" and ",
|
||||
indices_pshape,
|
||||
" are not consistent. data and indices must have equal or "
|
||||
"intersecting sizes until batch_dims");
|
||||
|
||||
output_pshape[i] = data_pshape[i] & indices_pshape[i];
|
||||
}
|
||||
|
||||
if (is_axis_set())
|
||||
{
|
||||
int64_t axis = get_axis();
|
||||
for (; i < axis; i++)
|
||||
{
|
||||
output_pshape[i] = data_pshape[i];
|
||||
}
|
||||
for (; i < axis + indices_rank.get_length() - batch_dims; i++)
|
||||
{
|
||||
output_pshape[i] = indices_pshape[batch_dims - axis + i];
|
||||
}
|
||||
for (; i < out_rank; i++)
|
||||
{
|
||||
output_pshape[i] = data_pshape[batch_dims + 1 - indices_rank.get_length() + i];
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, data_type, output_pshape);
|
||||
}
|
||||
else if (batch_dims < 0)
|
||||
{
|
||||
// batch_dims < 0 could be only if axis is not set
|
||||
// as soon as axis value will arrive negative batch_dims should be resolved
|
||||
// batch_dims value will be within [0, data_rank] && [0, indices_rank]
|
||||
int64_t max_rank = data_rank.get_length() + indices_rank.get_length() - 1;
|
||||
int64_t min_rank = max_rank - max(data_rank.get_length(), indices_rank.get_length());
|
||||
|
||||
set_output_type(0, data_type, PartialShape::dynamic(Dimension(min_rank, max_rank)));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
set_output_type(0, data_type, PartialShape::dynamic());
|
||||
}
|
||||
}
|
||||
|
||||
int64_t op::util::GatherBase::get_axis() const
|
||||
{
|
||||
const auto& const_op = get_constant_from_source(input_value(2));
|
||||
if (!const_op)
|
||||
throw ngraph_error("axis value is not set");
|
||||
|
||||
int64_t axis = const_op->cast_vector<int64_t>()[0];
|
||||
if (axis < 0)
|
||||
{
|
||||
const auto& data_rank = get_input_partial_shape(0).rank();
|
||||
if (data_rank.is_static())
|
||||
{
|
||||
axis += data_rank.get_length();
|
||||
}
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
|
||||
int64_t op::util::GatherBase::get_batch_dims() const
|
||||
{
|
||||
if (m_batch_dims < 0 && is_axis_set())
|
||||
return get_axis() + m_batch_dims;
|
||||
else
|
||||
return m_batch_dims;
|
||||
}
|
||||
|
||||
bool op::util::GatherBase::is_axis_set() const
|
||||
{
|
||||
const auto& axis_constant = get_constant_from_source(input_value(2));
|
||||
if (axis_constant)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace gather
|
||||
{
|
||||
template <element::Type_t ET>
|
||||
bool evaluate(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& out,
|
||||
size_t axis,
|
||||
size_t batch_dims)
|
||||
{
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
Shape params_shape = arg0->get_shape();
|
||||
Shape indices_shape = arg1->get_shape();
|
||||
Shape out_shape(params_shape.size() + indices_shape.size() - 1 - batch_dims);
|
||||
uint64_t i = 0;
|
||||
for (; i < axis; i++)
|
||||
{
|
||||
out_shape[i] = params_shape[i];
|
||||
}
|
||||
for (uint64_t j = batch_dims; j < indices_shape.size(); i++, j++)
|
||||
{
|
||||
out_shape[i] = indices_shape[j];
|
||||
}
|
||||
for (uint64_t j = axis + 1; j < params_shape.size(); i++, j++)
|
||||
{
|
||||
out_shape[i] = params_shape[j];
|
||||
}
|
||||
|
||||
out->set_shape(out_shape);
|
||||
|
||||
if (arg1->get_element_type() == element::i64)
|
||||
{
|
||||
runtime::reference::gather<T, int64_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int64_t>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
out->get_shape(),
|
||||
axis,
|
||||
batch_dims);
|
||||
}
|
||||
else if (arg1->get_element_type() == element::i32)
|
||||
{
|
||||
runtime::reference::gather<T, int32_t>(arg0->get_data_ptr<ET>(),
|
||||
arg1->get_data_ptr<int32_t>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
arg0->get_shape(),
|
||||
arg1->get_shape(),
|
||||
out->get_shape(),
|
||||
axis,
|
||||
batch_dims);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error("Unexpected type");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool evaluate_gather(const HostTensorPtr& arg0,
|
||||
const HostTensorPtr& arg1,
|
||||
const HostTensorPtr& out,
|
||||
size_t axis,
|
||||
size_t batch_dims = 0)
|
||||
{
|
||||
bool rc = true;
|
||||
|
||||
switch (out->get_element_type())
|
||||
{
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, i32, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, i64, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, u32, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, u64, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, f16, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, f32, arg0, arg1, out, axis, batch_dims);
|
||||
NGRAPH_TYPE_CASE(evaluate_gather, boolean, arg0, arg1, out, axis, batch_dims);
|
||||
default: rc = false; break;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
bool cf_gather_with_subgraph(OutputVector& output_values,
|
||||
const OutputVector& input_values,
|
||||
const PartialShape& gather_ps)
|
||||
{
|
||||
if (gather_ps.is_dynamic() || input_values.size() != 3)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto concat =
|
||||
std::dynamic_pointer_cast<op::Concat>(input_values[0].get_node_shared_ptr());
|
||||
const auto indices =
|
||||
std::dynamic_pointer_cast<op::Constant>(input_values[1].get_node_shared_ptr());
|
||||
const auto axis =
|
||||
std::dynamic_pointer_cast<op::Constant>(input_values[2].get_node_shared_ptr());
|
||||
|
||||
if (!concat || !indices || !axis)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// only along axis=0
|
||||
if (axis->cast_vector<int64_t>()[0] != 0 || concat->get_axis() != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// only single indices are accepted
|
||||
const auto indices_shape = indices->get_shape();
|
||||
if (indices_shape.size() > 1 || (indices_shape.size() == 1 && indices_shape[0] > 1))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// concat inputs are 1D and their count is equal to Concat output shape
|
||||
if (concat->get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
const auto concat_inputs = concat->inputs();
|
||||
// concat inputs must be single elements
|
||||
if (concat_inputs.size() != shape_size(concat->get_shape()))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const int64_t rank = concat->get_shape()[0];
|
||||
const int64_t raw_index = indices->cast_vector<int64_t>()[0];
|
||||
const int64_t positive_index = raw_index < 0 ? rank + raw_index : raw_index;
|
||||
NGRAPH_CHECK(positive_index >= 0 && positive_index < rank);
|
||||
|
||||
// gather takes exactly one element out of the Concat output
|
||||
const auto gathered_concat_input =
|
||||
concat_inputs[positive_index].get_source_output().get_node_shared_ptr();
|
||||
// Concat inputs are 1D, resulting tensor shape depends on Gather indices
|
||||
auto gathered = gathered_concat_input;
|
||||
if (indices_shape.empty())
|
||||
{
|
||||
// gathering a scalar
|
||||
const auto axis_const = op::Constant::create(element::i64, Shape{1}, {0});
|
||||
gathered = make_shared<op::v0::Squeeze>(gathered_concat_input, axis_const);
|
||||
}
|
||||
|
||||
output_values[0] = gathered;
|
||||
|
||||
return true;
|
||||
}
|
||||
} // namespace gather
|
||||
|
||||
bool op::util::GatherBase::evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(util_GatherBase_evaluate);
|
||||
NGRAPH_CHECK(validate_host_tensor_vector(inputs, 3));
|
||||
NGRAPH_CHECK(validate_host_tensor_vector(outputs, 1));
|
||||
|
||||
int64_t axis = 0;
|
||||
switch (inputs[2]->get_element_type())
|
||||
{
|
||||
case element::Type_t::i32: axis = inputs[2]->get_data_ptr<element::Type_t::i32>()[0]; break;
|
||||
case element::Type_t::i64: axis = inputs[2]->get_data_ptr<element::Type_t::i64>()[0]; break;
|
||||
case element::Type_t::i8: axis = inputs[2]->get_data_ptr<element::Type_t::i8>()[0]; break;
|
||||
case element::Type_t::i16: axis = inputs[2]->get_data_ptr<element::Type_t::i16>()[0]; break;
|
||||
case element::Type_t::u8: axis = inputs[2]->get_data_ptr<element::Type_t::u8>()[0]; break;
|
||||
case element::Type_t::u16: axis = inputs[2]->get_data_ptr<element::Type_t::u16>()[0]; break;
|
||||
case element::Type_t::u32: axis = inputs[2]->get_data_ptr<element::Type_t::u32>()[0]; break;
|
||||
case element::Type_t::u64: axis = inputs[2]->get_data_ptr<element::Type_t::u64>()[0]; break;
|
||||
default: throw ngraph_error("axis must be of integral data type.");
|
||||
}
|
||||
|
||||
if (axis < 0)
|
||||
{
|
||||
const auto& input_rank = get_input_partial_shape(0).rank();
|
||||
if (input_rank.is_static())
|
||||
{
|
||||
axis += input_rank.get_length();
|
||||
}
|
||||
}
|
||||
return gather::evaluate_gather(inputs[0], inputs[1], outputs[0], axis, get_batch_dims());
|
||||
}
|
||||
|
||||
bool op::util::GatherBase::evaluate_lower(const HostTensorVector& output_values) const
|
||||
{
|
||||
if (!input_value(1).get_tensor().has_and_set_bound() ||
|
||||
!input_value(2).get_tensor().has_and_set_bound())
|
||||
return false;
|
||||
return default_lower_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::util::GatherBase::evaluate_upper(const HostTensorVector& output_values) const
|
||||
{
|
||||
if (!input_value(1).get_tensor().has_and_set_bound() ||
|
||||
!input_value(2).get_tensor().has_and_set_bound())
|
||||
return false;
|
||||
return default_upper_bound_evaluator(this, output_values);
|
||||
}
|
||||
|
||||
bool op::util::GatherBase::constant_fold(OutputVector& output_values,
|
||||
const OutputVector& input_values)
|
||||
{
|
||||
// try the regular constant folding just for the Gather node
|
||||
if (Node::constant_fold(output_values, input_values))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return gather::cf_gather_with_subgraph(
|
||||
output_values, input_values, get_output_partial_shape(0));
|
||||
}
|
||||
}
|
@ -27,6 +27,40 @@ TEST(type_prop, gather_axis_0)
|
||||
ASSERT_EQ(G->get_axis(), 0);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_7_uint8)
|
||||
{
|
||||
// Gather_1 must allow even if indices is not int32/int64
|
||||
PartialShape data_shape{3, 2};
|
||||
PartialShape indices_shape{2, 2};
|
||||
PartialShape out_shape{2, 2, 2};
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::u8, indices_shape);
|
||||
auto A = op::Constant::create(element::i64, Shape{}, {0});
|
||||
auto G = make_shared<op::v1::Gather>(D, I, A);
|
||||
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
|
||||
ASSERT_EQ(G->get_axis(), 0);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_7_float32)
|
||||
{
|
||||
// Gather_1 should allow non int32/int64 indices
|
||||
PartialShape data_shape{3, 2};
|
||||
PartialShape indices_shape{2, 2};
|
||||
PartialShape out_shape{2, 2, 2};
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
|
||||
auto A = op::Constant::create(element::i64, Shape{}, {0});
|
||||
auto G = make_shared<op::v1::Gather>(D, I, A);
|
||||
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
ASSERT_EQ(G->get_output_partial_shape(0), out_shape);
|
||||
ASSERT_EQ(G->get_axis(), 0);
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_axis_1)
|
||||
{
|
||||
Shape params_shape{3, 3};
|
||||
@ -55,7 +89,7 @@ TEST(type_prop, gather_v1_incorrect_axis_shape)
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Axes input must be scalar or have 1 element (shape:"));
|
||||
std::string("Axis input must be scalar or have 1 element"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
@ -77,7 +111,7 @@ TEST(type_prop, gather_v1_axis_out_of_input_rank)
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("The axis must => 0 and <= input_rank (axis:"));
|
||||
std::string("The axis must be >= 0 and < data_rank. But instead got axis"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
@ -241,7 +275,7 @@ TEST(type_prop, gather_7_axis_not_set)
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
auto A = make_shared<op::Parameter>(element::i32, Shape{1});
|
||||
auto G = make_shared<op::v7::Gather>(D, I, A);
|
||||
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
@ -260,7 +294,7 @@ TEST(type_prop, gather_7_axis_not_set_positive_batch_dims)
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
auto A = make_shared<op::Parameter>(element::i32, Shape{1});
|
||||
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
|
||||
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
@ -279,7 +313,7 @@ TEST(type_prop, gather_7_axis_not_set_negative_batch)
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i64, indices_shape);
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
auto A = make_shared<op::Parameter>(element::i32, Shape{1});
|
||||
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
|
||||
|
||||
ASSERT_EQ(G->get_element_type(), element::f32);
|
||||
@ -303,7 +337,7 @@ TEST(type_prop, gather_7_incorrect_axis_shape)
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Axes input must be scalar or have 1 element"));
|
||||
std::string("Axis input must be scalar or have 1 element"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
@ -326,7 +360,7 @@ TEST(type_prop, gather_7_axis_out_of_input_rank)
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(), std::string("The axis must be => 0 and < data_rank. But instead got"));
|
||||
error.what(), std::string("The axis must be >= 0 and < data_rank. But instead got"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
@ -420,3 +454,63 @@ TEST(type_prop, gather_7_batch_dims_less_indices_rank_check)
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
// disabled until decision of type constrains for gather
|
||||
TEST(type_prop, DISABLED_gather_7_indices_type_check)
|
||||
{
|
||||
PartialShape data_shape{1, 20, 20, 22, 22};
|
||||
PartialShape indices_shape{1, 3};
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::f32, indices_shape);
|
||||
int64_t axis = 4;
|
||||
auto A = make_shared<op::Constant>(element::i64, Shape{1}, vector<int64_t>{axis});
|
||||
int64_t batch_dims = 0;
|
||||
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "indices element_type check failed";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Indices element type must be of an integral number type"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
||||
// disabled until decision of type constrains for gather
|
||||
TEST(type_prop, DISABLED_gather_7_axis_type_check)
|
||||
{
|
||||
PartialShape data_shape{1, 20, 20, 22, 22};
|
||||
PartialShape indices_shape{1, 3};
|
||||
|
||||
auto D = make_shared<op::Parameter>(element::f32, data_shape);
|
||||
auto I = make_shared<op::Parameter>(element::i32, indices_shape);
|
||||
int64_t axis = 4;
|
||||
auto A = make_shared<op::Constant>(element::f32, Shape{1}, vector<int64_t>{axis});
|
||||
int64_t batch_dims = 0;
|
||||
|
||||
try
|
||||
{
|
||||
auto G = make_shared<op::v7::Gather>(D, I, A, batch_dims);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "axis element_type check failed";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Axis element type must be of an integral number type"));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Deduced type check failed for unexpected reason";
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user