fix unit tests execution

This commit is contained in:
Evgeny Kotov 2023-02-20 18:50:03 +01:00
parent 48f20927af
commit 4417a13bad
2 changed files with 4 additions and 11 deletions

View File

@ -87,13 +87,6 @@ NodePtr InsertUnsqueeze(Output<Node> node, size_t n_dims) {
return unsqueeze;
}
Output<Node> FixInputNodeRank(Output<Node> input_node, Rank::value_type required_rank) {
const Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length();
if (output_rank >= required_rank)
return input_node;
return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
}
/*
Converts gather indices to positive form
*/
@ -234,7 +227,7 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
Shape{},
gather_positive_axis);
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
@ -263,7 +256,7 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
main_node->output(i).get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
Shape{},
gather_positive_axis);
auto new_gather = std::make_shared<Gather>(main_node->output(i), new_indices_const, new_axis_const);
@ -317,7 +310,7 @@ NodeVector InsertGatherBeforeNode(NodePtr main_node,
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
input_node.get_partial_shape().rank().get_length());
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
Shape{1},
Shape{},
gather_positive_axis);
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);

View File

@ -81,7 +81,7 @@ std::shared_ptr<Gather> MakeGather(NodePtr input_node, CreateIndicesF create_ind
const std::vector<size_t> indexes = create_indices_func(input_shape[axis], 0);
auto gather_indexes_node = Constant::create(ngraph::element::i64, ov::Shape{indexes.size()}, indexes);
auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis});
auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis});
return std::make_shared<Gather>(input_node, gather_indexes_node, gather_axis_node);
}