From 87a47aa117a5f91f49d29988cb72f10f50370198 Mon Sep 17 00:00:00 2001 From: Tomasz Socha Date: Thu, 27 May 2021 16:42:52 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=A6=92Nonzero-adjustment=20(#5863)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/ops/condition/NonZero_3.md | 2 +- ngraph/core/src/op/non_zero.cpp | 6 +----- ngraph/frontend/onnx_import/src/op/compress.cpp | 3 +-- ngraph/frontend/onnx_import/src/op/non_zero.cpp | 4 ---- 4 files changed, 3 insertions(+), 12 deletions(-) diff --git a/docs/ops/condition/NonZero_3.md b/docs/ops/condition/NonZero_3.md index acf75ae0886..44bd96690dd 100644 --- a/docs/ops/condition/NonZero_3.md +++ b/docs/ops/condition/NonZero_3.md @@ -29,7 +29,7 @@ The output tensor has shape `[rank(input), num_non_zero]`. For example, for the **Types** -* *T*: any numeric type. +* *T*: any type. * *T_IND*: `int64` or `int32`. diff --git a/ngraph/core/src/op/non_zero.cpp b/ngraph/core/src/op/non_zero.cpp index 19e52f77fe9..1e11aad1b2a 100644 --- a/ngraph/core/src/op/non_zero.cpp +++ b/ngraph/core/src/op/non_zero.cpp @@ -48,12 +48,7 @@ void op::v3::NonZero::validate_and_infer_types() { NGRAPH_OP_SCOPE(v3_NonZero_validate_and_infer_types); const PartialShape& input_shape = get_input_partial_shape(0); - const auto input_et = get_input_element_type(0); - NODE_VALIDATION_CHECK(this, - input_et.is_integral_number() || input_et.is_real(), - "NonZero input data type needs to be a numeric type. Got: ", - input_et); NODE_VALIDATION_CHECK(this, m_output_type == element::i64 || m_output_type == element::i32, "Output type must be i32 or i64"); @@ -154,6 +149,7 @@ namespace nonzero switch (input->get_element_type()) { + NGRAPH_TYPE_CASE(evaluate_nonzero, boolean, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i8, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i16, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i32, input, output); diff --git a/ngraph/frontend/onnx_import/src/op/compress.cpp b/ngraph/frontend/onnx_import/src/op/compress.cpp index f7658a5e7aa..0cb7b77f442 100644 --- a/ngraph/frontend/onnx_import/src/op/compress.cpp +++ b/ngraph/frontend/onnx_import/src/op/compress.cpp @@ -19,8 +19,7 @@ namespace ngraph OutputVector compress(const Node& node) { auto data = node.get_ng_inputs().at(0); - auto condition = std::make_shared( - node.get_ng_inputs().at(1), element::u8); + auto condition = node.get_ng_inputs().at(1); int64_t axis = 0; if (node.has_attribute("axis")) diff --git a/ngraph/frontend/onnx_import/src/op/non_zero.cpp b/ngraph/frontend/onnx_import/src/op/non_zero.cpp index 5f580111e0d..1798d12ba1f 100644 --- a/ngraph/frontend/onnx_import/src/op/non_zero.cpp +++ b/ngraph/frontend/onnx_import/src/op/non_zero.cpp @@ -18,10 +18,6 @@ namespace ngraph OutputVector non_zero(const Node& node) { auto data = node.get_ng_inputs().at(0); - if (data.get_element_type() == element::boolean) - { - data = std::make_shared(data, element::u8); - } return {std::make_shared(data, element::i64)}; }