Files
openvino/ngraph/test/runtime/interpreter/int_executable.hpp
Maksim Kutakov 2c7f06e08f [IE TESTS] GatherTree single layer test has been created. (#2006)
* [IE TESTS] GatherTree op ref function has been created.

* [IE TESTS] Added GatherTree single layer test

* [IE TESTS] Fixed code styles.

* [IE TESTS] GatherTree test FP32 precision was enabled.

* [IE TESTS] Refactoring of Builder::makeConstant procedure

The refactoring is aimed at managing the range of random data for the constants initialization procedure.

* [IE TESTS] GatherTree test was extended with constants

* [IE TESTS] GatherTree ref rewritten to non-templated function.

* [IE TESTS] GatherTree test inp shape indx enum removed.

* Revert "[IE TESTS] Refactoring of Builder::makeConstant procedure"

This reverts commit 2648172e00ccca266d39e8775b890b8a8395f57c.

* [IE TESTS] makeConstant was augmented with random data range parameters.

* [IE TESTS] GatherTree test was rewritten using makeConstant function.

* [IE TESTS] GatherTree test calls templated makeConstant

* [IE TESTS] GatherTree test code style fix
2020-09-07 17:17:14 +03:00

1351 lines
61 KiB
C++

//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <initializer_list>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include <ngraph/runtime/host_tensor.hpp>
#include "backend.hpp"
#include "int_backend_visibility.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/runtime/aligned_buffer.hpp"
#include "ngraph/runtime/reference/abs.hpp"
#include "ngraph/runtime/reference/acos.hpp"
#include "ngraph/runtime/reference/any.hpp"
#include "ngraph/runtime/reference/asin.hpp"
#include "ngraph/runtime/reference/atan.hpp"
#include "ngraph/runtime/reference/atan2.hpp"
#include "ngraph/runtime/reference/avg_pool.hpp"
#include "ngraph/runtime/reference/batch_norm.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
#include "ngraph/runtime/reference/ceiling.hpp"
#include "ngraph/runtime/reference/concat.hpp"
#include "ngraph/runtime/reference/constant.hpp"
#include "ngraph/runtime/reference/convert.hpp"
#include "ngraph/runtime/reference/convolution.hpp"
#include "ngraph/runtime/reference/cos.hpp"
#include "ngraph/runtime/reference/cosh.hpp"
#include "ngraph/runtime/reference/ctc_loss.hpp"
#include "ngraph/runtime/reference/cum_sum.hpp"
#include "ngraph/runtime/reference/dequantize.hpp"
#include "ngraph/runtime/reference/detection_output.hpp"
#include "ngraph/runtime/reference/dot.hpp"
#include "ngraph/runtime/reference/elu.hpp"
#include "ngraph/runtime/reference/embedding_bag_offsets_sum.hpp"
#include "ngraph/runtime/reference/embedding_bag_packed_sum.hpp"
#include "ngraph/runtime/reference/embedding_segments_sum.hpp"
#include "ngraph/runtime/reference/erf.hpp"
#include "ngraph/runtime/reference/exp.hpp"
#include "ngraph/runtime/reference/extract_image_patches.hpp"
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/gather_tree.hpp"
#include "ngraph/runtime/reference/gather_tree.hpp"
#include "ngraph/runtime/reference/gru_cell.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/lstm_cell.hpp"
#include "ngraph/runtime/reference/matmul.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
#include "ngraph/runtime/reference/min.hpp"
#include "ngraph/runtime/reference/negate.hpp"
#include "ngraph/runtime/reference/not.hpp"
#include "ngraph/runtime/reference/one_hot.hpp"
#include "ngraph/runtime/reference/pad.hpp"
#include "ngraph/runtime/reference/product.hpp"
#include "ngraph/runtime/reference/quantize.hpp"
#include "ngraph/runtime/reference/relu.hpp"
#include "ngraph/runtime/reference/replace_slice.hpp"
#include "ngraph/runtime/reference/reshape.hpp"
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/rnn_cell.hpp"
#include "ngraph/runtime/reference/round.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include "ngraph/runtime/reference/select.hpp"
#include "ngraph/runtime/reference/sigmoid.hpp"
#include "ngraph/runtime/reference/sign.hpp"
#include "ngraph/runtime/reference/sin.hpp"
#include "ngraph/runtime/reference/sinh.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/runtime/reference/sqrt.hpp"
#include "ngraph/runtime/reference/sum.hpp"
#include "ngraph/runtime/reference/tan.hpp"
#include "ngraph/runtime/reference/tanh.hpp"
#include "ngraph/runtime/reference/topk.hpp"
#include "ngraph/runtime/tensor.hpp"
#include "op/avg_pool.hpp"
#include "op/convolution.hpp"
#include "op/group_conv.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
namespace ngraph
{
namespace runtime
{
namespace interpreter
{
// Forward declarations; INTExecutable is defined further down in this header.
class INTBackend;
class INTExecutable;
// Identifiers for the ops the interpreter dispatches on.
//
// The enumerator list is generated by including opset_int_tbl.hpp while NGRAPH_OP
// is temporarily defined, so every table entry expands to one enumerator, giving a
// list that looks like this:
// Abs,
// Acos,
// ...
// The exact spelling of each enumerator is produced by ID_SUFFIX(NAME). ID_SUFFIX is
// not defined in this header -- presumably the included table provides it, which
// would account for versioned names such as CTCLoss_v4 (TODO confirm in
// opset_int_tbl.hpp).
// UnknownOp is a fixed enumerator placed after all generated ones.
enum class OP_TYPEID
{
// Each table entry contributes `ID_SUFFIX(NAME),`; the NAMESPACE argument is unused.
#define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
#include "opset_int_tbl.hpp"
#undef NGRAPH_OP
UnknownOp
};
} // namespace interpreter
} // namespace runtime
} // namespace ngraph
// Executable implementation belonging to the interpreter backend.
// INTBackend is declared a friend, so it can reach the protected members below.
class INTERPRETER_BACKEND_API ngraph::runtime::interpreter::INTExecutable : public Executable
{
friend class INTBackend;
public:
// Constructs an executable from `function`. Performance collection is disabled
// unless `enable_performance_collection` is passed as true.
INTExecutable(const std::shared_ptr<Function>& function,
bool enable_performance_collection = false);
// Overrides Executable::call. Note the argument order: output tensors first,
// then input tensors.
bool call(const std::vector<std::shared_ptr<Tensor>>& outputs,
const std::vector<std::shared_ptr<Tensor>>& inputs) override;
// NOTE(review): presumably stores `enable` in m_nan_check_enabled; the definition
// is not in this header -- confirm in the implementation file.
void set_nan_check(bool enable);
// Overrides Executable::get_performance_data.
std::vector<PerformanceCounter> get_performance_data() const override;
// Overrides of the Executable tensor-creation API. The single-argument forms
// return one tensor for the given input/output index; the forms taking
// `pipeline_depth` return a vector of tensors.
std::shared_ptr<runtime::Tensor> create_input_tensor(size_t input_index) override;
std::shared_ptr<runtime::Tensor> create_output_tensor(size_t output_index) override;
std::vector<std::shared_ptr<runtime::Tensor>>
create_input_tensor(size_t input_index, size_t pipeline_depth) override;
std::vector<std::shared_ptr<runtime::Tensor>>
create_output_tensor(size_t output_index, size_t pipeline_depth) override;
protected:
// Accessors returning the Parameter / Result op for `index`
// (presumably looked up in m_function -- TODO confirm).
std::shared_ptr<ngraph::op::Parameter> get_parameter(size_t index) const;
std::shared_ptr<ngraph::op::Result> get_result(size_t index) const;
// Alignment used by this executable; always 64.
int get_alignment() const { return 64; }
// State flags, all false by default.
bool m_is_compiled = false;
bool m_nan_check_enabled = false;
bool m_performance_counters_enabled = false;
// The function passed to the constructor is of this type; presumably stored here.
std::shared_ptr<Function> m_function;
// One stopwatch per node, keyed by the node's shared pointer.
std::unordered_map<std::shared_ptr<const Node>, stopwatch> m_timer_map;
std::vector<std::shared_ptr<Node>> m_nodes;
// Names of ops recorded as unsupported (populated outside this header).
std::set<std::string> m_unsupported_op_name_list;
// Maps a node to its OP_TYPEID; op_engine switches on this value.
static OP_TYPEID get_typeid(const Node& node);
// NaN check over a list of host tensors. `op` is optional and defaults to nullptr.
static void perform_nan_check(const std::vector<std::shared_ptr<HostTensor>>&,
const Node* op = nullptr);
// Virtual so derived executables can replace it.
// NOTE(review): presumably selects the op_engine<T> instantiation that matches
// `type` -- confirm in the implementation file.
virtual void generate_calls(const element::Type& type,
const Node& op,
const std::vector<std::shared_ptr<HostTensor>>& outputs,
const std::vector<std::shared_ptr<HostTensor>>& inputs);
template <typename T>
void op_engine(const Node& node,
const std::vector<std::shared_ptr<HostTensor>>& out,
const std::vector<std::shared_ptr<HostTensor>>& args)
{
// We want to check that every OP_TYPEID enumeration is included in the list.
// These GCC flags enable compile-time checking so that if an enumeration
// is not in the list an error is generated.
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic push
#pragma GCC diagnostic error "-Wswitch"
#pragma GCC diagnostic error "-Wswitch-enum"
#endif
switch (get_typeid(node))
{
case OP_TYPEID::Abs:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::abs<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Acos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::acos<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Any:
{
const op::Any* any = static_cast<const op::Any*>(&node);
reference::any(args[0]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
any->get_reduction_axes(),
false);
break;
}
case OP_TYPEID::Asin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::asin<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Atan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::atan<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Elu:
{
const op::Elu* elu_node = static_cast<const op::Elu*>(&node);
size_t element_count = shape_size(node.get_output_shape(0));
reference::elu<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count,
elu_node->get_alpha());
break;
}
case OP_TYPEID::AvgPool:
{
const op::v0::AvgPool* avg_pool = static_cast<const op::v0::AvgPool*>(&node);
reference::avg_pool<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
avg_pool->get_window_shape(),
avg_pool->get_window_movement_strides(),
avg_pool->get_padding_below(),
avg_pool->get_padding_above(),
avg_pool->get_include_padding_in_avg_computation());
break;
}
case OP_TYPEID::BatchNormInference:
{
const ngraph::op::BatchNormInference* bn =
static_cast<const ngraph::op::BatchNormInference*>(&node);
reference::batch_norm_inference<T>(bn->get_eps_value(),
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const T>(),
args[4]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(2));
break;
}
case OP_TYPEID::BroadcastLike: break;
case OP_TYPEID::Ceiling:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::ceiling<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Convert:
{
// const op::Convert* c = static_cast<const op::Convert*>(&node);
element::Type type = node.get_element_type();
std::stringstream ss;
size_t element_count = shape_size(node.get_output_shape(0));
switch (type)
{
case element::Type_t::boolean:
reference::convert_to_bool<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<char>(), element_count);
break;
case element::Type_t::f32:
reference::convert<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<float>(), element_count);
break;
case element::Type_t::f64:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<double>(),
element_count);
break;
case element::Type_t::i8:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int8_t>(),
element_count);
break;
case element::Type_t::i16:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int16_t>(),
element_count);
break;
case element::Type_t::i32:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
element_count);
break;
case element::Type_t::i64:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
element_count);
break;
case element::Type_t::u8:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint8_t>(),
element_count);
break;
case element::Type_t::u16:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint16_t>(),
element_count);
break;
case element::Type_t::u32:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint32_t>(),
element_count);
break;
case element::Type_t::u64:
reference::convert<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<uint64_t>(),
element_count);
break;
case element::Type_t::undefined:
case element::Type_t::dynamic:
case element::Type_t::u1:
case element::Type_t::bf16:
case element::Type_t::f16:
ss << "unsupported element type " << type << " op Convert";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Convolution:
{
const op::v0::Convolution* c = static_cast<const op::v0::Convolution*>(&node);
reference::convolution<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
c->get_window_movement_strides(),
c->get_window_dilation_strides(),
c->get_padding_below(),
c->get_padding_above(),
c->get_data_dilation_strides());
break;
}
case OP_TYPEID::ConvolutionBackpropData:
{
// Note that args[1] and args[0] are switched here from the usual order.
const op::v0::ConvolutionBackpropData* c =
static_cast<const op::v0::ConvolutionBackpropData*>(&node);
reference::convolution_backprop_in<T>(args[1]->get_data_ptr<const T>(),
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
c->get_input_shape(1),
c->get_input_shape(0),
c->get_data_batch_shape(),
c->get_data_dilation_strides_forward(),
c->get_window_dilation_strides_forward(),
c->compute_backward_delta_out_pad_below(),
c->compute_backward_delta_out_pad_above(),
c->get_window_movement_strides_forward());
break;
}
case OP_TYPEID::Cos:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cos<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Cosh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::cosh<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::CTCLoss_v4:
{
const op::v4::CTCLoss* ctc_loss = static_cast<const op::v4::CTCLoss*>(&node);
auto t_int = node.get_input_element_type(1);
if (t_int == element::i32)
{
reference::CTCLoss<T, int32_t>(
args[0]->get_data_ptr<const T>(),
ctc_loss->get_input_shape(0),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const int32_t>(),
args[3]->get_data_ptr<const int32_t>(),
args.size() > 4 ? args[4]->get_data_ptr<const int32_t>() : nullptr,
ctc_loss->get_preprocess_collapse_repeated(),
ctc_loss->get_ctc_merge_repeated(),
ctc_loss->get_unique(),
out[0]->get_data_ptr<T>());
}
else if (t_int == element::i64)
{
reference::CTCLoss<T, int64_t>(
args[0]->get_data_ptr<const T>(),
ctc_loss->get_input_shape(0),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const int64_t>(),
args[3]->get_data_ptr<const int64_t>(),
args.size() > 4 ? args[4]->get_data_ptr<const int64_t>() : nullptr,
ctc_loss->get_preprocess_collapse_repeated(),
ctc_loss->get_ctc_merge_repeated(),
ctc_loss->get_unique(),
out[0]->get_data_ptr<T>());
}
break;
}
case OP_TYPEID::CumSum:
{
const op::CumSum* cumsum = static_cast<const op::CumSum*>(&node);
auto axis_et = node.get_input_element_type(1);
if (axis_et == element::i32)
{
reference::cumsum<T, int32_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
cumsum->is_exclusive(),
cumsum->is_reverse());
}
else if (axis_et == element::i64)
{
reference::cumsum<T, int64_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
cumsum->is_exclusive(),
cumsum->is_reverse());
}
break;
}
case OP_TYPEID::Dequantize:
{
const op::Dequantize* dequantize = static_cast<const op::Dequantize*>(&node);
auto type = dequantize->get_element_type();
if (type == element::f32)
{
reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const float>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<float>(),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else if (type == element::f64)
{
reference::dequantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const double>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<double>(),
node.get_input_shape(0),
node.get_input_shape(1),
dequantize->get_axes());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op Dequantize";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Dot:
{
const op::Dot* dot = static_cast<const op::Dot*>(&node);
reference::dot(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
dot->get_reduction_axes_count());
break;
}
case OP_TYPEID::EmbeddingBagOffsetsSum_v3:
{
const op::EmbeddingBagOffsetsSum* embed =
static_cast<const op::EmbeddingBagOffsetsSum*>(&node);
auto indicesType = embed->input(1).get_element_type();
size_t indices_num = shape_size(embed->get_input_shape(1));
if (indicesType == element::u64 || indicesType == element::i64)
{
reference::embeddingBagOffsetsSum<T, size_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const size_t>(),
args[2]->get_data_ptr<const size_t>(),
args.size() > 3 ? args[3]->get_data_ptr<const size_t>() : nullptr,
args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
out[0]->get_data_ptr<T>(),
indices_num,
embed->get_shape());
}
else if (indicesType == element::u32 || indicesType == element::i32)
{
reference::embeddingBagOffsetsSum<T, unsigned>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const unsigned>(),
args[2]->get_data_ptr<const unsigned>(),
args.size() > 3 ? args[3]->get_data_ptr<const unsigned>() : nullptr,
args.size() > 4 ? args[4]->get_data_ptr<const T>() : nullptr,
out[0]->get_data_ptr<T>(),
indices_num,
embed->get_shape());
}
else
{
throw ngraph_error(std::string("Unsupported index type ") +
indicesType.c_type_string() +
std::string(" in EmbeddingBagOffsetsSum"));
}
break;
}
case OP_TYPEID::EmbeddingBagPackedSum_v3:
{
const op::EmbeddingBagPackedSum* embed =
static_cast<const op::EmbeddingBagPackedSum*>(&node);
auto indicesType = embed->input(1).get_element_type();
if (indicesType == element::u64 || indicesType == element::i64)
{
reference::embeddingBagPackedSum<T, size_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const size_t>(),
args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
out[0]->get_data_ptr<T>(),
embed->get_input_shape(1),
embed->get_shape());
}
else if (indicesType == element::u32 || indicesType == element::i32)
{
reference::embeddingBagPackedSum<T, unsigned>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const unsigned>(),
args.size() > 2 ? args[2]->get_data_ptr<const T>() : nullptr,
out[0]->get_data_ptr<T>(),
embed->get_input_shape(1),
embed->get_shape());
}
else
{
throw ngraph_error(std::string("Unsupported index type ") +
indicesType.c_type_string() +
std::string(" in EmbeddingBagPackedSum"));
}
break;
}
case OP_TYPEID::EmbeddingSegmentsSum_v3:
{
const op::EmbeddingSegmentsSum* embed =
static_cast<const op::EmbeddingSegmentsSum*>(&node);
auto indicesType = embed->input(1).get_element_type();
size_t indices_num = shape_size(embed->get_input_shape(1));
if (indicesType == element::u64 || indicesType == element::i64)
{
reference::embeddingSegmentsSum<T, size_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const size_t>(),
args[2]->get_data_ptr<const size_t>(),
args.size() > 4 ? args[4]->get_data_ptr<const size_t>() : nullptr,
args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
out[0]->get_data_ptr<T>(),
embed->get_input_shape(0),
embed->get_input_shape(1),
embed->get_shape());
}
else if (indicesType == element::u32 || indicesType == element::i32)
{
reference::embeddingSegmentsSum<T, unsigned>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const unsigned>(),
args[2]->get_data_ptr<const unsigned>(),
args.size() > 4 ? args[4]->get_data_ptr<const unsigned>() : nullptr,
args.size() > 5 ? args[5]->get_data_ptr<const T>() : nullptr,
out[0]->get_data_ptr<T>(),
embed->get_input_shape(0),
embed->get_input_shape(1),
embed->get_shape());
}
else
{
throw ngraph_error(std::string("Unsupported index type ") +
indicesType.c_type_string() +
std::string(" in EmbeddingSegmentsSum"));
}
break;
}
case OP_TYPEID::Erf:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::erf<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ExtractImagePatches_v3:
{
const op::ExtractImagePatches* extImgPatches =
static_cast<const op::ExtractImagePatches*>(&node);
reference::extractImagePatches<T, size_t>(extImgPatches,
args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
extImgPatches->get_input_shape(0),
extImgPatches->get_shape());
break;
}
case OP_TYPEID::Exp:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::exp<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
#ifdef INTERPRETER_USE_HYBRID
case OP_TYPEID::FunctionCall:
{
auto f = static_cast<const runtime::hybrid::op::FunctionCall*>(&node);
auto backend = f->get_backend();
auto executable = f->get_executable();
std::vector<std::shared_ptr<Tensor>> outputs;
std::vector<std::shared_ptr<Tensor>> inputs;
for (const std::shared_ptr<HostTensor>& t : out)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
outputs.push_back(backend_tensor);
}
for (const std::shared_ptr<HostTensor>& t : args)
{
auto backend_tensor = backend->create_tensor(
t->get_element_type(), t->get_shape(), t->get_data_ptr());
inputs.push_back(backend_tensor);
}
executable->call(outputs, inputs);
break;
}
#endif
case OP_TYPEID::Floor:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::floor<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::GatherND:
{
if (node.get_input_element_type(1) == element::i64)
{
reference::gather_nd<T, int64_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
}
else if (node.get_input_element_type(1) == element::i32)
{
reference::gather_nd<T, int32_t>(args[0]->get_data_ptr<T>(),
args[1]->get_data_ptr<int32_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0));
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::GRUCell_v3:
{
const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
out[0]->get_data_ptr<T>(),
gru_cell->get_activations()[0],
gru_cell->get_activations()[1],
gru_cell->get_clip(),
gru_cell->get_linear_before_reset());
break;
}
case OP_TYPEID::LSTMCell_v4:
{
const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
args[5]->get_data_ptr<T>(),
args[5]->get_shape(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
lstm_cell->get_activations()[0],
lstm_cell->get_activations()[1],
lstm_cell->get_activations()[2],
lstm_cell->get_clip());
break;
}
case OP_TYPEID::RNNCell_v0:
{
const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
out[0]->get_data_ptr<T>(),
rnn_cell->get_activations()[0],
rnn_cell->get_clip());
break;
}
case OP_TYPEID::Log:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::log<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LRN:
{
const op::LRN* lrn = static_cast<const op::LRN*>(&node);
reference::lrn<T>(args[0]->get_data_ptr<const T>(),
lrn->get_reduction_axes(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
lrn->get_alpha(),
lrn->get_beta(),
lrn->get_bias(),
lrn->get_nsize());
break;
}
case OP_TYPEID::Negative:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::negate<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::LogicalNot_v1:
case OP_TYPEID::Not:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::logical_not(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::OneHot:
{
const op::OneHot* oh = static_cast<const op::OneHot*>(&node);
reference::one_hot<T>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
oh->get_one_hot_axis());
break;
}
case OP_TYPEID::Parameter: break;
case OP_TYPEID::Quantize:
{
const op::Quantize* quantize = static_cast<const op::Quantize*>(&node);
auto type = quantize->get_element_type();
if (type == element::u8)
{
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i8)
{
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else if (type == element::i32)
{
reference::quantize<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const int32_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
quantize->get_axes(),
quantize->get_round_mode());
}
else
{
std::stringstream ss;
ss << "unsupported element type " << type << " op Quantize";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::QuantizedConvolution:
{
const op::QuantizedConvolution* qc =
static_cast<const op::QuantizedConvolution*>(&node);
auto input_element_type = qc->get_input_element_type(0);
auto filter_element_type = qc->get_input_element_type(1);
auto output_element_type = qc->get_output_element_type(0);
if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i8)
{
reference::convolution<uint8_t, int8_t, int8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::u8)
{
reference::convolution<uint8_t, uint8_t, uint8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const uint8_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::i8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, int8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else if (input_element_type == element::u8 && filter_element_type == element::u8 &&
output_element_type == element::i32)
{
reference::convolution<uint8_t, uint8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
qc->get_window_movement_strides(),
qc->get_window_dilation_strides(),
qc->get_padding_below(),
qc->get_padding_above(),
qc->get_data_dilation_strides(),
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else
{
std::stringstream ss;
ss << "unsupported element type";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::QuantizedDot:
{
const op::QuantizedDot* qd = static_cast<const op::QuantizedDot*>(&node);
auto input0_element_type = qd->get_input_element_type(0);
auto input1_element_type = qd->get_input_element_type(1);
auto output_element_type = qd->get_output_element_type(0);
if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
output_element_type == element::i8)
{
reference::dot<uint8_t, int8_t, int8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
1,
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int8_t>());
}
else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
output_element_type == element::u8)
{
reference::dot<uint8_t, uint8_t, uint8_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<uint8_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
1,
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const uint8_t>());
}
else if (input0_element_type == element::u8 && input1_element_type == element::u8 &&
output_element_type == element::i32)
{
reference::dot<uint8_t, uint8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const uint8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
1,
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const uint8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else if (input0_element_type == element::u8 && input1_element_type == element::i8 &&
output_element_type == element::i32)
{
reference::dot<uint8_t, int8_t, int32_t, int32_t>(
args[0]->get_data_ptr<const uint8_t>(),
args[1]->get_data_ptr<const int8_t>(),
out[0]->get_data_ptr<int32_t>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_output_shape(0),
1,
args[2]->get_data_ptr<const float>(),
args[3]->get_data_ptr<const uint8_t>(),
args[4]->get_data_ptr<const float>(),
args[5]->get_data_ptr<const int8_t>(),
args[6]->get_data_ptr<const float>(),
args[7]->get_data_ptr<const int32_t>());
}
else
{
std::stringstream ss;
ss << "unsupported element type";
throw std::runtime_error(ss.str());
}
break;
}
case OP_TYPEID::Relu:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::relu<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::ReplaceSlice:
{
const op::ReplaceSlice* slice = static_cast<const op::ReplaceSlice*>(&node);
reference::replace_slice<T>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(1),
slice->get_lower_bounds(),
slice->get_upper_bounds(),
slice->get_strides(),
node.get_output_shape(0));
break;
}
case OP_TYPEID::Reverse:
{
const op::Reverse* reverse = static_cast<const op::Reverse*>(&node);
reference::reverse(args[0]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_output_shape(0),
reverse->get_reversed_axes(),
args[0]->get_element_type().size());
break;
}
case OP_TYPEID::ReverseSequence:
{
const op::ReverseSequence* reverse = static_cast<const op::ReverseSequence*>(&node);
if (node.get_input_element_type(1) == element::i32)
{
reference::reverse_sequence<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
reverse->get_batch_axis(),
reverse->get_sequence_axis(),
args[1]->get_data_ptr<const int32_t>());
}
else
{
throw ngraph_error("only int32 indices are supported");
}
break;
}
case OP_TYPEID::Round:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::round<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Select:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::select<T>(args[0]->get_data_ptr<const char>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
element_count);
break;
}
case OP_TYPEID::Sigmoid:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sigmoid<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sign:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sign<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sin:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sin<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sinh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sinh<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Sqrt:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::sqrt<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Tan:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tan<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::Tanh:
{
size_t element_count = shape_size(node.get_output_shape(0));
reference::tanh<T>(
args[0]->get_data_ptr<const T>(), out[0]->get_data_ptr<T>(), element_count);
break;
}
case OP_TYPEID::TopK:
{
const op::TopK* topk = static_cast<const op::TopK*>(&node);
if (node.get_output_element_type(0) == element::i64)
{
reference::topk<T, int64_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int64_t>(),
out[1]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max(),
topk->get_sort());
}
else if (node.get_output_element_type(0) == element::i32)
{
reference::topk<T, int32_t>(args[0]->get_data_ptr<const T>(),
out[0]->get_data_ptr<int32_t>(),
out[1]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_output_shape(0),
topk->get_top_k_axis(),
topk->get_k(),
topk->get_compute_max(),
topk->get_sort());
}
else
{
throw ngraph_error("Unexpected type");
}
break;
}
case OP_TYPEID::DetectionOutput_v0:
{
const op::DetectionOutput* detOut = static_cast<const op::DetectionOutput*>(&node);
reference::referenceDetectionOutput<T> refDetOut(
detOut->get_attrs(), node.get_input_shape(0), node.get_input_shape(2));
if (node.get_input_size() == 3)
{
refDetOut.run(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
nullptr,
nullptr,
out[0]->get_data_ptr<T>());
}
else if (node.get_input_size() == 5)
{
refDetOut.run(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const T>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const T>(),
args[4]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>());
}
else
{
throw ngraph_error("DetectionOutput layer supports only 3 or 5 inputs");
}
break;
}
case OP_TYPEID::ScatterNDUpdate_v3:
{
const op::ScatterNDUpdate* scatterNDUpd =
static_cast<const op::v3::ScatterNDUpdate*>(&node);
auto idxType = scatterNDUpd->get_input_element_type(1);
if (idxType == element::i32)
{
reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else if (idxType == element::i64)
{
reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else
{
throw ngraph_error(
"ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
}
break;
}
case OP_TYPEID::GatherTree_v1:
{
reference::gather_tree(args[0]->get_data_ptr<const char>(),
args[1]->get_data_ptr<const char>(),
args[2]->get_data_ptr<const char>(),
args[3]->get_data_ptr<const char>(),
out[0]->get_data_ptr<char>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2),
node.get_input_shape(3),
args[1]->get_element_type());
break;
}
// Fused ops are not supported in the interpreter; they need to be decomposed before execution.
case OP_TYPEID::DepthToSpace:
case OP_TYPEID::FakeQuantize:
case OP_TYPEID::Gather:
case OP_TYPEID::Gelu:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolution:
case OP_TYPEID::GroupConvolutionBackpropData:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::Interpolate:
case OP_TYPEID::LSTMSequence:
case OP_TYPEID::MVN:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::ScatterUpdate_v3:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
case OP_TYPEID::SpaceToDepth:
case OP_TYPEID::Split:
case OP_TYPEID::SquaredDifference:
case OP_TYPEID::StopGradient:
case OP_TYPEID::TensorIterator:
case OP_TYPEID::Tile:
case OP_TYPEID::UnknownOp:
throw unsupported_op("Unsupported op '" + node.description() + "'");
case OP_TYPEID::Add:
case OP_TYPEID::Broadcast:
case OP_TYPEID::Clamp:
case OP_TYPEID::Concat:
case OP_TYPEID::Constant:
case OP_TYPEID::Divide:
case OP_TYPEID::Equal:
case OP_TYPEID::Greater:
case OP_TYPEID::GreaterEq:
case OP_TYPEID::Less:
case OP_TYPEID::LessEq:
case OP_TYPEID::LessEqual_v1:
case OP_TYPEID::LogicalAnd_v1:
case OP_TYPEID::LogicalOr_v1:
case OP_TYPEID::LogicalXor_v1:
case OP_TYPEID::MatMul:
case OP_TYPEID::Max:
case OP_TYPEID::Maximum:
case OP_TYPEID::Min:
case OP_TYPEID::Minimum:
case OP_TYPEID::Multiply:
case OP_TYPEID::NonZero_v3:
case OP_TYPEID::NotEqual:
case OP_TYPEID::Or:
case OP_TYPEID::Power:
case OP_TYPEID::Product:
case OP_TYPEID::Range:
case OP_TYPEID::Reshape:
case OP_TYPEID::Result:
case OP_TYPEID::ShapeOf_v3:
case OP_TYPEID::ShapeOf:
case OP_TYPEID::Softmax:
case OP_TYPEID::Squeeze:
case OP_TYPEID::Sum:
case OP_TYPEID::Subtract:
case OP_TYPEID::Unsqueeze:
case OP_TYPEID::Xor:
case OP_TYPEID::Slice:
// These ops are handled by op evaluators, so there is nothing to do here.
break;
#if defined(__GNUC__) && !(__GNUC__ == 4 && __GNUC_MINOR__ == 8)
#pragma GCC diagnostic pop
#endif
}
}
};
NGRAPH_SUPPRESS_DEPRECATED_END