* Shape infer improvements * Add type_prop label and interval dims tests * Update shape_infer tests * Use new shape_infer * Revert headers changes * Rename test file
65 lines
2.5 KiB
C++
65 lines
2.5 KiB
C++
// Copyright (C) 2018-2023 Intel Corporation
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
|
|
#include "ngraph/op/gather_tree.hpp"
|
|
|
|
#include "gather_tree_shape_inference.hpp"
|
|
#include "itt.hpp"
|
|
#include "ngraph/shape.hpp"
|
|
#include "openvino/core/validation_util.hpp"
|
|
|
|
using namespace std;
|
|
using namespace ngraph;
|
|
|
|
// Constructs a GatherTree-1 node.
// Inputs are registered in this order: 0 = step_ids, 1 = parent_idx,
// 2 = max_seq_len, 3 = end_token. The constructor-time validate/infer step
// is triggered before returning.
op::v1::GatherTree::GatherTree(const Output<Node>& step_ids,
                               const Output<Node>& parent_idx,
                               const Output<Node>& max_seq_len,
                               const Output<Node>& end_token)
    : Op(OutputVector{step_ids, parent_idx, max_seq_len, end_token}) {
    // Run type/shape validation as part of construction.
    constructor_validate_and_infer_types();
}
|
|
|
|
// Creates a new GatherTree-1 node wired to `new_args`.
// check_new_args_count validates the argument count first (presumably against
// this node's input count — see its definition); the four entries are then
// forwarded positionally as step_ids, parent_idx, max_seq_len, end_token.
shared_ptr<Node> op::v1::GatherTree::clone_with_new_inputs(const OutputVector& new_args) const {
    OV_OP_SCOPE(v1_GatherTree_clone_with_new_inputs);
    check_new_args_count(this, new_args);

    const auto& step_ids = new_args.at(0);
    const auto& parent_idx = new_args.at(1);
    const auto& max_seq_len = new_args.at(2);
    const auto& end_token = new_args.at(3);
    return std::make_shared<v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
}
|
|
|
|
// Attribute visitation for GatherTree-1.
// This implementation reports no attributes to the visitor and always
// returns true.
bool ngraph::op::v1::GatherTree::visit_attributes(AttributeVisitor& visitor) {
    OV_OP_SCOPE(v1_GatherTree_visit_attributes);
    // Nothing to visit; mark the parameter as intentionally unused.
    (void)visitor;
    return true;
}
|
|
|
|
void op::v1::GatherTree::validate_and_infer_types() {
|
|
OV_OP_SCOPE(v1_GatherTree_validate_and_infer_types);
|
|
|
|
const auto& step_ids_et = get_input_element_type(0);
|
|
const auto& parent_idx_et = get_input_element_type(1);
|
|
const auto& max_seq_len_et = get_input_element_type(2);
|
|
const auto& end_token_et = get_input_element_type(3);
|
|
|
|
element::Type result_et;
|
|
NODE_VALIDATION_CHECK(this,
|
|
element::Type::merge(result_et, step_ids_et, parent_idx_et) &&
|
|
element::Type::merge(result_et, result_et, max_seq_len_et) &&
|
|
element::Type::merge(result_et, result_et, end_token_et),
|
|
"Inputs must have the same element type. Got: step_ids (",
|
|
step_ids_et,
|
|
"), parent_idx_et (",
|
|
parent_idx_et,
|
|
"), max_seq_len (",
|
|
max_seq_len_et,
|
|
"), end_token (",
|
|
end_token_et,
|
|
")");
|
|
|
|
NODE_VALIDATION_CHECK(this,
|
|
result_et.is_real() || result_et.is_integral_number(),
|
|
"Element type of inputs must be numeric. Got: ",
|
|
result_et);
|
|
|
|
const auto output_shape = shape_infer(this, ov::get_node_input_partial_shapes(*this)).front();
|
|
set_output_type(0, result_et, output_shape);
|
|
}
|