Files
openvino/src/core/src/op/gather_tree.cpp
Katarzyna Mitrus 407590cfc2 [ShapeInference] GatherTree shape infer improvements (#15399)
* Shape infer improvements

* Add type_prop label and interval dims tests

* Update shape_infer tests

* Use new shape_infer

* Revert headers changes

* Rename test file
2023-01-31 14:04:19 +01:00

65 lines
2.5 KiB
C++

// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/op/gather_tree.hpp"
#include "gather_tree_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/shape.hpp"
#include "openvino/core/validation_util.hpp"
using namespace std;
using namespace ngraph;
/// @brief Constructs a GatherTree-1 node from its four inputs.
///
/// The inputs are handed to the base Op in fixed order (step_ids, parent_idx,
/// max_seq_len, end_token); validate_and_infer_types() reads them back by
/// index 0..3 in that same order.
///
/// @param step_ids     Input 0.
/// @param parent_idx   Input 1.
/// @param max_seq_len  Input 2.
/// @param end_token    Input 3.
op::v1::GatherTree::GatherTree(const Output<Node>& step_ids,
                               const Output<Node>& parent_idx,
                               const Output<Node>& max_seq_len,
                               const Output<Node>& end_token)
    : Op(OutputVector{step_ids, parent_idx, max_seq_len, end_token}) {
    // Run validation/type inference right away so a freshly built node
    // already carries its output element type and shape.
    constructor_validate_and_infer_types();
}
/// @brief Creates a new GatherTree node of the same kind, wired to @p new_args.
///
/// @param new_args Replacement inputs in constructor order: step_ids,
///                 parent_idx, max_seq_len, end_token.
/// @return A freshly constructed v1::GatherTree.
/// @throws Whatever check_new_args_count raises on a wrong argument count
///         (presumably a node-validation failure; confirm in its definition);
///         std::out_of_range from vector::at if fewer than four are supplied.
shared_ptr<Node> op::v1::GatherTree::clone_with_new_inputs(const OutputVector& new_args) const {
    OV_OP_SCOPE(v1_GatherTree_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    // Name each positional input so the constructor call reads like its signature.
    const auto& cloned_step_ids = new_args.at(0);
    const auto& cloned_parent_idx = new_args.at(1);
    const auto& cloned_max_seq_len = new_args.at(2);
    const auto& cloned_end_token = new_args.at(3);

    return std::make_shared<v1::GatherTree>(cloned_step_ids,
                                            cloned_parent_idx,
                                            cloned_max_seq_len,
                                            cloned_end_token);
}
bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v1_GatherTree_visit_attributes);
return true;
}
/// @brief Validates GatherTree-1 inputs and sets the single output's type and shape.
///
/// Checks performed, in order:
///  1. All four inputs (step_ids, parent_idx, max_seq_len, end_token) must have
///     element types that merge into one common type.
///  2. The merged type must be a real or integral number type.
/// The output element type is the merged input type. The output shape is the
/// first result of shape_infer() applied to the inputs' partial shapes.
///
/// @throws Node validation failure (via NODE_VALIDATION_CHECK) if either check fails.
void op::v1::GatherTree::validate_and_infer_types() {
    OV_OP_SCOPE(v1_GatherTree_validate_and_infer_types);

    const auto& step_ids_et = get_input_element_type(0);
    const auto& parent_idx_et = get_input_element_type(1);
    const auto& max_seq_len_et = get_input_element_type(2);
    const auto& end_token_et = get_input_element_type(3);

    // Fold the four input types into result_et one pair at a time; `&&`
    // short-circuits, so the remaining merges are skipped after the first failure.
    element::Type result_et;
    NODE_VALIDATION_CHECK(this,
                          element::Type::merge(result_et, step_ids_et, parent_idx_et) &&
                              element::Type::merge(result_et, result_et, max_seq_len_et) &&
                              element::Type::merge(result_et, result_et, end_token_et),
                          "Inputs must have the same element type. Got: step_ids (",
                          step_ids_et,
                          // Label by input name, consistent with the other three entries
                          // (previously leaked the local variable name "parent_idx_et").
                          "), parent_idx (",
                          parent_idx_et,
                          "), max_seq_len (",
                          max_seq_len_et,
                          "), end_token (",
                          end_token_et,
                          ")");

    // NOTE(review): if every input is element::dynamic, result_et stays dynamic.
    // As far as I know, dynamic satisfies neither is_real() nor is_integral_number(),
    // so such graphs are rejected here. Confirm whether that is intended before
    // relaxing the check.
    NODE_VALIDATION_CHECK(this,
                          result_et.is_real() || result_et.is_integral_number(),
                          "Element type of inputs must be numeric. Got: ",
                          result_et);

    // shape_infer returns one shape per output; GatherTree has exactly one.
    const auto output_shape = shape_infer(this, ov::get_node_input_partial_shapes(*this)).front();
    set_output_type(0, result_et, output_shape);
}