🦒Nonzero-adjustment (#5863)

This commit is contained in:
Tomasz Socha 2021-05-27 16:42:52 +02:00 committed by GitHub
parent aaa632e8f5
commit 87a47aa117
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 3 additions and 12 deletions

View File

@ -29,7 +29,7 @@ The output tensor has shape `[rank(input), num_non_zero]`. For example, for the
**Types** **Types**
* *T*: any numeric type. * *T*: any type.
* *T_IND*: `int64` or `int32`. * *T_IND*: `int64` or `int32`.

View File

@ -48,12 +48,7 @@ void op::v3::NonZero::validate_and_infer_types()
{ {
NGRAPH_OP_SCOPE(v3_NonZero_validate_and_infer_types); NGRAPH_OP_SCOPE(v3_NonZero_validate_and_infer_types);
const PartialShape& input_shape = get_input_partial_shape(0); const PartialShape& input_shape = get_input_partial_shape(0);
const auto input_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_et.is_integral_number() || input_et.is_real(),
"NonZero input data type needs to be a numeric type. Got: ",
input_et);
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
m_output_type == element::i64 || m_output_type == element::i32, m_output_type == element::i64 || m_output_type == element::i32,
"Output type must be i32 or i64"); "Output type must be i32 or i64");
@ -154,6 +149,7 @@ namespace nonzero
switch (input->get_element_type()) switch (input->get_element_type())
{ {
NGRAPH_TYPE_CASE(evaluate_nonzero, boolean, input, output);
NGRAPH_TYPE_CASE(evaluate_nonzero, i8, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i8, input, output);
NGRAPH_TYPE_CASE(evaluate_nonzero, i16, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i16, input, output);
NGRAPH_TYPE_CASE(evaluate_nonzero, i32, input, output); NGRAPH_TYPE_CASE(evaluate_nonzero, i32, input, output);

View File

@ -19,8 +19,7 @@ namespace ngraph
OutputVector compress(const Node& node) OutputVector compress(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
auto condition = std::make_shared<default_opset::Convert>( auto condition = node.get_ng_inputs().at(1);
node.get_ng_inputs().at(1), element::u8);
int64_t axis = 0; int64_t axis = 0;
if (node.has_attribute("axis")) if (node.has_attribute("axis"))

View File

@ -18,10 +18,6 @@ namespace ngraph
OutputVector non_zero(const Node& node) OutputVector non_zero(const Node& node)
{ {
auto data = node.get_ng_inputs().at(0); auto data = node.get_ng_inputs().at(0);
if (data.get_element_type() == element::boolean)
{
data = std::make_shared<default_opset::Convert>(data, element::u8);
}
return {std::make_shared<default_opset::NonZero>(data, element::i64)}; return {std::make_shared<default_opset::NonZero>(data, element::i64)};
} }