Treat 1d single-element tensors as scalars. (#1498)
This commit is contained in:
parent
4e1f7d2b96
commit
2a96917e2a
@ -75,10 +75,14 @@ namespace ngraph
|
||||
const std::shared_ptr<ngraph::Node> input,
|
||||
const std::set<element::Type> allowed_types)
|
||||
{
|
||||
const auto validated_input_rank = input->get_output_partial_shape(0).rank();
|
||||
const auto validated_input_shape = input->get_output_partial_shape(0);
|
||||
const auto validated_input_rank = validated_input_shape.rank();
|
||||
|
||||
NGRAPH_CHECK(
|
||||
validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar.");
|
||||
NGRAPH_CHECK(validated_input_rank.same_scheme({0}) ||
|
||||
(validated_input_rank.same_scheme({1}) &&
|
||||
validated_input_shape[0].get_length() == 1),
|
||||
input_name,
|
||||
" needs to be a scalar or 1D, single-element tensor.");
|
||||
|
||||
if (!allowed_types.empty())
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user