Handle Split axes as i64 (#2079)

This commit is contained in:
Tomasz Dołbniak 2020-09-07 04:48:32 +02:00 committed by GitHub
parent 33c3aeb867
commit 0cd0c1a551
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions

View File

@ -34,7 +34,7 @@ namespace ngraph
NGRAPH_DEPRECATED("This builder was deprecated.")
OutputVector split(const Output<Node>& value,
const std::vector<size_t>& length_parts,
size_t axis = 0);
int64_t axis = 0);
/// \brief Split node on specified axis into multiple parts.
///

View File

@ -30,7 +30,7 @@ namespace
}
std::shared_ptr<op::Slice> make_ng_slice(const Output<Node>& output,
const std::vector<size_t>& axes,
const std::vector<int64_t>& axes,
const std::vector<size_t>& starts,
const std::vector<size_t>& ends)
{
@ -38,7 +38,7 @@ namespace
std::vector<size_t> lower_bounds(upper_bounds.size());
for (size_t index{0}; index < axes.size(); ++index)
{
size_t axis{axes.at(index)};
int64_t axis{axes.at(index)};
lower_bounds.at(axis) =
get_valid_array_index(starts.at(index), output.get_shape().at(axis));
upper_bounds.at(axis) =
@ -51,7 +51,7 @@ namespace
}
OutputVector
builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, size_t axis)
builder::split(const Output<Node>& value, const std::vector<size_t>& length_parts, int64_t axis)
{
size_t start_index{0};
OutputVector outputs;
@ -81,7 +81,7 @@ OutputVector builder::opset1::split(const Output<Node>& value,
const std::vector<size_t>& split_lengths,
int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
const auto split_lengths_node =
ngraph::opset1::Constant::create(element::u64, Shape{split_lengths.size()}, split_lengths);
const auto variadic_split =
@ -92,7 +92,7 @@ OutputVector builder::opset1::split(const Output<Node>& value,
OutputVector builder::opset1::split(const Output<Node>& value, size_t num_splits, int64_t axis)
{
const auto axis_node = ngraph::opset1::Constant::create(element::u64, Shape{}, {axis});
const auto axis_node = ngraph::opset1::Constant::create(element::i64, Shape{}, {axis});
const auto split = std::make_shared<ngraph::opset1::Split>(value, axis_node, num_splits);
return split->outputs();