Handle Split axes as i64 (#2079)
This commit is contained in:
parent
33c3aeb867
commit
0cd0c1a551
@ -34,7 +34,7 @@ namespace ngraph
|
||||
NGRAPH_DEPRECATED("This builder was deprecated.")
|
||||
OutputVector split(const Output<Node>& value,
|
||||
const std::vector<size_t>& length_parts,
|
||||
size_t axis = 0);
|
||||
int64_t axis = 0);
|
||||
|
||||
/// \brief Split node on specified axis into multiple parts.
|
||||
///
|
||||
|
@ -30,7 +30,7 @@ namespace
|
||||
}
|
||||
|
||||
std::shared_ptr<op::Slice> make_ng_slice(const Output<Node>& output,
|
||||
const std::vector<size_t>& axes,
|
||||
const std::vector<int64_t>& axes,
|
||||
const std::vector<size_t>& starts,
|
||||
const std::vector<size_t>& ends)
|
||||
{
|
||||
@ -38,7 +38,7 @@ namespace
|
||||
std::vector<size_t> lower_bounds(upper_bounds.size());
|
||||
for (size_t index{0}; index < axes.size(); ++index)
|
||||
{
|
||||
size_t axis{axes.at(index)};
|
||||
int64_t axis{axes.at(index)};
|
||||
lower_bounds.at(axis) =
|
||||
get_valid_array_index(starts.at(index), output.get_shape().at(axis));
|
||||
upper_bounds.at(axis) =
|
||||
@ -51,7 +51,7 @@ namespace
|
||||
}
|
||||
|
||||
OutputVector
|
||||
builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
|
||||
builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, int64_t axis)
|
||||
{
|
||||
size_t start_index{0};
|
||||
OutputVector outputs;
|
||||
@ -81,7 +81,7 @@ OutputVector builder::opset1::split(const Output<Node>& value,
|
||||
const std::vector<size_t>& split_lengths,
|
||||
int64_t axis)
|
||||
{
|
||||
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
|
||||
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
|
||||
const auto split_lengths_node =
|
||||
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
|
||||
const auto variadic_split =
|
||||
@ -92,7 +92,7 @@ OutputVector builder::opset1::split(const Output<Node>& value,
|
||||
|
||||
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
|
||||
{
|
||||
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
|
||||
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
|
||||
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
|
||||
|
||||
return split->outputs();
|
||||
|
Loading…
Reference in New Issue
Block a user