add validate_and_infer_types method

This commit is contained in:
pszmel 2021-06-09 13:35:40 +02:00
parent 2eb78027a5
commit 075427a0fe
2 changed files with 10 additions and 6 deletions

View File

@ -24,6 +24,7 @@ namespace ngraph
/// \param arg Node that produces the input tensor.
Negative(const Output<Node>& arg);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;

View File

@ -18,6 +18,15 @@ op::Negative::Negative(const Output<Node>& arg)
constructor_validate_and_infer_types();
}
void op::Negative::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v0_Negative_validate_and_infer_types);
auto input_et = get_input_element_type(0);
NODE_VALIDATION_CHECK(this, input_et != element::boolean, "Input type cannot be a boolean");
NODE_VALIDATION_CHECK(this, input_et.is_signed(), "Input type has to be signed");
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
bool ngraph::op::v0::Negative::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v0_Negative_visit_attributes);
@ -48,11 +57,8 @@ namespace negativeop
switch (arg0->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_negative, boolean, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_negative, i32, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_negative, i64, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_negative, u32, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_negative, u64, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_negative, f16, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_negative, f32, arg0, out, count);
default: rc = false; break;
@ -72,11 +78,8 @@ bool op::Negative::has_evaluate() const
NGRAPH_OP_SCOPE(v0_Negative_has_evaluate);
switch (get_input_element_type(0))
{
case ngraph::element::boolean:
case ngraph::element::i32:
case ngraph::element::i64:
case ngraph::element::u32:
case ngraph::element::u64:
case ngraph::element::f16:
case ngraph::element::f32: return true;
default: break;