Treat 1d single-element tensors as scalars. (#1498)

This commit is contained in:
Adam Osewski 2020-07-28 14:01:13 +02:00 committed by GitHub
parent 4e1f7d2b96
commit 2a96917e2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,10 +75,14 @@ namespace ngraph
const std::shared_ptr<ngraph::Node> input,
const std::set<element::Type> allowed_types)
{
const auto validated_input_rank = input->get_output_partial_shape(0).rank();
const auto validated_input_shape = input->get_output_partial_shape(0);
const auto validated_input_rank = validated_input_shape.rank();
NGRAPH_CHECK(
validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar.");
NGRAPH_CHECK(validated_input_rank.same_scheme({0}) ||
(validated_input_rank.same_scheme({1}) &&
validated_input_shape[0].get_length() == 1),
input_name,
" needs to be a scalar or 1D, single-element tensor.");
if (!allowed_types.empty())
{