fix unit tests execution
This commit is contained in:
parent
48f20927af
commit
4417a13bad
@ -87,13 +87,6 @@ NodePtr InsertUnsqueeze(Output<Node> node, size_t n_dims) {
|
||||
return unsqueeze;
|
||||
}
|
||||
|
||||
Output<Node> FixInputNodeRank(Output<Node> input_node, Rank::value_type required_rank) {
|
||||
const Rank::value_type output_rank = input_node.get_partial_shape().rank().get_length();
|
||||
if (output_rank >= required_rank)
|
||||
return input_node;
|
||||
return InsertUnsqueeze(input_node, required_rank - output_rank)->output(0);
|
||||
}
|
||||
|
||||
/*
|
||||
Converts gather indices to positive form
|
||||
*/
|
||||
@ -234,7 +227,7 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i
|
||||
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
|
||||
input_node.get_partial_shape().rank().get_length());
|
||||
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
|
||||
Shape{1},
|
||||
Shape{},
|
||||
gather_positive_axis);
|
||||
|
||||
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
|
||||
@ -263,7 +256,7 @@ NodeVector InsertOutputGather(NodePtr main_node, const GatherInputsInfo& gather_
|
||||
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
|
||||
main_node->output(i).get_partial_shape().rank().get_length());
|
||||
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
|
||||
Shape{1},
|
||||
Shape{},
|
||||
gather_positive_axis);
|
||||
auto new_gather = std::make_shared<Gather>(main_node->output(i), new_indices_const, new_axis_const);
|
||||
|
||||
@ -317,7 +310,7 @@ NodeVector InsertGatherBeforeNode(NodePtr main_node,
|
||||
const int64_t gather_positive_axis = ConvertAxisToPositive(gather_negative_axis,
|
||||
input_node.get_partial_shape().rank().get_length());
|
||||
auto new_axis_const = std::make_shared<Constant>(axis_element_type,
|
||||
Shape{1},
|
||||
Shape{},
|
||||
gather_positive_axis);
|
||||
|
||||
auto new_gather = std::make_shared<Gather>(input_node, new_indices_const, new_axis_const);
|
||||
|
@ -81,7 +81,7 @@ std::shared_ptr<Gather> MakeGather(NodePtr input_node, CreateIndicesF create_ind
|
||||
const std::vector<size_t> indexes = create_indices_func(input_shape[axis], 0);
|
||||
auto gather_indexes_node = Constant::create(ngraph::element::i64, ov::Shape{indexes.size()}, indexes);
|
||||
|
||||
auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{1}, {axis});
|
||||
auto gather_axis_node = Constant::create(ngraph::element::i64, ngraph::Shape{}, {axis});
|
||||
|
||||
return std::make_shared<Gather>(input_node, gather_indexes_node, gather_axis_node);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user