Treat 1d single-element tensors as scalars. (#1498)
This commit is contained in:
parent
4e1f7d2b96
commit
2a96917e2a
@ -75,10 +75,14 @@ namespace ngraph
|
|||||||
const std::shared_ptr<ngraph::Node> input,
|
const std::shared_ptr<ngraph::Node> input,
|
||||||
const std::set<element::Type> allowed_types)
|
const std::set<element::Type> allowed_types)
|
||||||
{
|
{
|
||||||
const auto validated_input_rank = input->get_output_partial_shape(0).rank();
|
const auto validated_input_shape = input->get_output_partial_shape(0);
|
||||||
|
const auto validated_input_rank = validated_input_shape.rank();
|
||||||
|
|
||||||
NGRAPH_CHECK(
|
NGRAPH_CHECK(validated_input_rank.same_scheme({0}) ||
|
||||||
validated_input_rank.same_scheme({0}), input_name, " needs to be a scalar.");
|
(validated_input_rank.same_scheme({1}) &&
|
||||||
|
validated_input_shape[0].get_length() == 1),
|
||||||
|
input_name,
|
||||||
|
" needs to be a scalar or 1D, single-element tensor.");
|
||||||
|
|
||||||
if (!allowed_types.empty())
|
if (!allowed_types.empty())
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user