Add access to the methods Node::evaluate to each custom Op which inherit from ov::op::Op (#12976)

* Add access to the hidden Node::evaluate method in each Ops

* Fix the test

* Add new line in EOF

* Add comment about using ov::op::Op

* Use opset9

* Add more detailed comment
This commit is contained in:
Artur Kulikowski 2022-09-12 06:51:16 +02:00 committed by GitHub
parent 49ebb95067
commit 1b620fa8bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 6 deletions

View File

@ -14,12 +14,18 @@
#define _OPENVINO_RTTI_OP_WITH_TYPE_VERSION(TYPE_NAME, VERSION_NAME) \ #define _OPENVINO_RTTI_OP_WITH_TYPE_VERSION(TYPE_NAME, VERSION_NAME) \
_OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT(TYPE_NAME, VERSION_NAME, ::ov::op::Op) _OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT(TYPE_NAME, VERSION_NAME, ::ov::op::Op)
#define OPENVINO_OP(...) \ #define OPENVINO_OP(...) \
_OPENVINO_RTTI_EXPAND(_OPENVINO_RTTI_DEFINITION_SELECTOR(__VA_ARGS__, \ _OPENVINO_RTTI_EXPAND(_OPENVINO_RTTI_DEFINITION_SELECTOR(__VA_ARGS__, \
_OPENVINO_RTTI_WITH_TYPE_VERSIONS_PARENT, \ _OPENVINO_RTTI_WITH_TYPE_VERSIONS_PARENT, \
_OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT, \ _OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT, \
_OPENVINO_RTTI_OP_WITH_TYPE_VERSION, \ _OPENVINO_RTTI_OP_WITH_TYPE_VERSION, \
_OPENVINO_RTTI_OP_WITH_TYPE)(__VA_ARGS__)) _OPENVINO_RTTI_OP_WITH_TYPE)(__VA_ARGS__)) \
/* Add accessibility for Op to the method: evaluate from the Base class \
Usually C++ allows to use virtual methods of Base class from Derived class but if they have \
the same name and not all of them are overrided in Derived class, the only overrided methods \
will be available from Derived class. We need to explicitly cast Derived to Base class to \
have an access to remaining methods or use this using. */ \
using ov::op::Op::evaluate;
namespace ov { namespace ov {
namespace op { namespace op {

View File

@ -12,6 +12,7 @@
#include "ngraph/runtime/host_tensor.hpp" #include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/util.hpp" #include "ngraph/util.hpp"
#include "ngraph/validation_util.hpp" #include "ngraph/validation_util.hpp"
#include "openvino/opsets/opset9.hpp"
#include "util/test_tools.hpp" #include "util/test_tools.hpp"
using namespace std; using namespace std;
@ -75,3 +76,14 @@ TEST(op_eval, swish_without_beta) {
for (size_t i = 0; i < inputs.size(); i++) for (size_t i = 0; i < inputs.size(); i++)
EXPECT_NEAR(result_data[i], expected_result[i], 0.000001); EXPECT_NEAR(result_data[i], expected_result[i], 0.000001);
} }
TEST(op_eval, swish_new_evaluate) {
Shape shape{3};
auto p = make_shared<op::Parameter>(element::f32, shape);
auto beta = make_shared<op::Parameter>(element::f32, Shape{});
auto swish = make_shared<ov::opset9::Swish>(p, beta);
ov::TensorVector inputs = {ov::Tensor(element::f32, shape)};
ov::TensorVector outputs = {ov::Tensor(element::f32, shape)};
ASSERT_TRUE(swish->evaluate(outputs, inputs));
}