[DOCS] Added an evaluate method for custom operation (#2273)
* Added an evaluate method for custom operation * Fixed comments
This commit is contained in:
parent
7535f80bd3
commit
5d291c3c84
@ -16,6 +16,8 @@ To add your custom nGraph operation, create a new class that extends `ngraph::Op
|
||||
|
||||
5. Override the `visit_attributes` method, which allows serialization and deserialization of attributes. An `AttributeVisitor` is passed to the method, and the implementation is expected to walk over all the attributes in the op using the type-aware `on_attribute` helper. Helpers are already implemented for standard C++ types like `int64_t`, `float`, `bool`, `vector` and for existing nGraph defined types.
|
||||
|
||||
6. Override `evaluate`, which is an optional method that enables the application of constant folding if there is a custom operation on the constant branch.
|
||||
|
||||
Based on that, declaration of a operation class can look as follows:
|
||||
|
||||
@snippet op.hpp op:header
|
||||
@ -51,6 +53,12 @@ nGraph operation contains two constructors: a default constructor, which allows
|
||||
|
||||
@snippet op.cpp op:visit_attributes
|
||||
|
||||
### `evaluate()`
|
||||
|
||||
`ngraph::Node::evaluate` method allows to apply constant folding to an operation.
|
||||
|
||||
@snippet op.cpp op:evaluate
|
||||
|
||||
## Register Custom Operations in Extension Class
|
||||
|
||||
To add custom operations to the [Extension](Extension.md) class, create an operation set with custom operations and implement the `InferenceEngine::IExtension::getOpSets` method:
|
||||
|
@ -36,3 +36,52 @@ bool Operation::visit_attributes(ngraph::AttributeVisitor &visitor) {
|
||||
return true;
|
||||
}
|
||||
//! [op:visit_attributes]
|
||||
|
||||
//! [op:evaluate]
|
||||
namespace
|
||||
{
|
||||
|
||||
template <class T>
|
||||
void implementation(const T* input,
|
||||
T* output,
|
||||
int64_t add,
|
||||
size_t size) {
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
output[i] = input[i] + add;
|
||||
}
|
||||
}
|
||||
|
||||
template <ngraph::element::Type_t ET>
|
||||
bool evaluate_op(const ngraph::HostTensorPtr& arg0,
|
||||
const ngraph::HostTensorPtr& out, int64_t add)
|
||||
{
|
||||
size_t size = ngraph::shape_size(arg0->get_shape());
|
||||
implementation(arg0->get_data_ptr<ET>(),
|
||||
out->get_data_ptr<ET>(),
|
||||
add,
|
||||
size);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool Operation::evaluate(const ngraph::HostTensorVector& outputs,
|
||||
const ngraph::HostTensorVector& inputs) const {
|
||||
switch (inputs[0]->get_element_type())
|
||||
{
|
||||
case ngraph::element::Type_t::i8: return evaluate_op<ngraph::element::Type_t::i8>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::i16: return evaluate_op<ngraph::element::Type_t::i16>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::i32: return evaluate_op<ngraph::element::Type_t::i32>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::i64: return evaluate_op<ngraph::element::Type_t::i64>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::u8: return evaluate_op<ngraph::element::Type_t::u8>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::u16: return evaluate_op<ngraph::element::Type_t::u16>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::u32: return evaluate_op<ngraph::element::Type_t::u32>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::u64: return evaluate_op<ngraph::element::Type_t::u8>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::bf16: return evaluate_op<ngraph::element::Type_t::bf16>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::f16: return evaluate_op<ngraph::element::Type_t::f16>(inputs[0], outputs[0], getAddAttr());
|
||||
case ngraph::element::Type_t::f32: return evaluate_op<ngraph::element::Type_t::f32>(inputs[0], outputs[0], getAddAttr());
|
||||
default: break;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
//! [op:evaluate]
|
||||
|
@ -19,7 +19,9 @@ public:
|
||||
void validate_and_infer_types() override;
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
|
||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
|
||||
int64_t getAddAttr() { return add; }
|
||||
int64_t getAddAttr() const { return add; }
|
||||
bool evaluate(const ngraph::HostTensorVector& outputs,
|
||||
const ngraph::HostTensorVector& inputs) const override;
|
||||
|
||||
private:
|
||||
int64_t add;
|
||||
|
Loading…
Reference in New Issue
Block a user