Use new evaluate in template plugin (#15753)

* Use new evaluate method in template plugin

* Add tensor at the end of each iteration

* Remove class TemporaryOverrideOutputs

* Set shape of tensor after evaluate

* Revert "Remove class TemporaryOverrideOutputs"

This reverts commit e345ba9188.

* Update tensors when evaluate passes

* Copy Tensor data when the HostTensor was initialized

* Set shape to output tensor in TemporaryOverrideOutputs

* Fix code style

* Add test

* Remove unused code

* Create reshape with scalar when shape is empty

* Reshape, special_zero = true

* Revert "Create reshape with scalar when shape is empty"

This reverts commit 0f901f419a.

* Use Shape with size zero and value max_int for dynamic tensors

* Restore Shape{0} for dynamic tensors

* Revert "Restore Shape{0} for dynamic tensors"

This reverts commit cb2d0e58eb.

* Temporary remove the test

* Use shape{0} for dynamic tensors

* Revert "Use shape{0} for dynamic tensors"

This reverts commit 08460a486b.

* Use Shape{0} for dynamic tensors

* Use new evaluate in template plugin
- Add tensor conversion between ov::Tensor <-> HostTensor
- Add shape utils to create a special-case shape that marks a tensor as dynamic
- Keep the utils in the dev API to avoid duplication

* Move the workaround for set_shape into ov::Tensor.

* Remove dynamic shape from or_tensor helper

* Mark tensor conversion utils as deprecated
- make the shape util core-internal only
- update transpose test to not use deprecated functions

* Add missing deprecate suppression macro

---------

Co-authored-by: Artur Kulikowski <artur.kulikowski@intel.com>
Pawel Raasz 2023-02-20 07:50:42 +01:00 committed by GitHub
parent 7cffe848d6
commit 69728cb4ef
13 changed files with 251 additions and 131 deletions

View File

@ -0,0 +1,54 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/runtime/tensor.hpp"
namespace ov {
namespace util {
/**
* @brief Wrap host tensor into ov::Tensor.
*
* @param t Input tensor for conversion.
* @return ov::Tensor which points to the host tensor's data. May return an unallocated or special dynamic tensor,
* depending on the input tensor's state.
*/
OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
OPENVINO_API Tensor wrap_tensor(const ngraph::HostTensorPtr& t);
/**
* @brief Wrap node output into ov::Tensor.
*
* @param output Node output from which the tensor is made.
* @return ov::Tensor created from the output's properties.
*/
OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
OPENVINO_API Tensor wrap_tensor(const Output<Node>& output);
/**
* @brief Make vector of wrapped tensors.
*
* @param tensors Input vector of host tensors to convert.
* @return ov::TensorVector; may contain unallocated or dynamic tensors, depending on the input tensors' properties.
*/
OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
OPENVINO_API TensorVector wrap_tensors(const std::vector<ngraph::HostTensorPtr>& tensors);
/**
* @brief Update output host tensors if they had a dynamic shape before evaluation (not allocated).
*
* Other tensors do not require an update, as they are created from the outputs and point to the same data blob.
*
* @param output_values Output host tensor vector to update.
* @param outputs Temporary ov::Tensor vector created from the outputs for evaluation.
*/
OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
OPENVINO_API void update_output_host_tensors(const std::vector<ngraph::HostTensorPtr>& output_values,
const ov::TensorVector& outputs);
} // namespace util
} // namespace ov
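
A minimal sketch of how these helpers compose around the new evaluate; the free function evaluate_via_tensors and its name are illustrative only, not part of this patch, and the usual OpenVINO core headers are assumed:

// Sketch only: wrap legacy host tensors, run the ov::Tensor-based evaluate,
// then propagate shape/type/data back to outputs that were dynamic.
OPENVINO_SUPPRESS_DEPRECATED_START
inline bool evaluate_via_tensors(const ov::Node& node,
                                 const std::vector<ngraph::HostTensorPtr>& host_inputs,
                                 const std::vector<ngraph::HostTensorPtr>& host_outputs) {
    auto inputs = ov::util::wrap_tensors(host_inputs);    // data views; no copy for static shapes
    auto outputs = ov::util::wrap_tensors(host_outputs);  // dynamic ones become sentinel tensors
    const bool ok = node.evaluate(outputs, inputs);
    if (ok)
        ov::util::update_output_host_tensors(host_outputs, outputs);
    return ok;
}
OPENVINO_SUPPRESS_DEPRECATED_END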

View File

@ -8,6 +8,8 @@
#include "ngraph/validation_util.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset10.hpp"
#include "shape_util.hpp"
#include "tensor_conversion_util.hpp"
namespace {
using namespace ov;
@ -80,18 +82,9 @@ ov::Tensor evaluate_bound(const Output<Node>& output, bool is_upper, bool invali
for (const auto& node : order) {
ov::TensorVector outputs;
for (const auto& out : node->outputs()) {
const auto& out_shape = out.get_partial_shape();
const auto& out_et = out.get_element_type();
if (out_et.is_dynamic()) {
outputs.emplace_back();
} else if (out_shape.is_static()) {
outputs.emplace_back(out_et, out_shape.to_shape());
} else if (out_shape.rank().is_static()) {
outputs.emplace_back(out_et, Shape(out_shape.rank().get_length()));
} else {
outputs.emplace_back(out_et, Shape{0});
}
OPENVINO_SUPPRESS_DEPRECATED_START
outputs.push_back(util::wrap_tensor(out));
OPENVINO_SUPPRESS_DEPRECATED_END
}
if (is_upper ? node->evaluate_upper(outputs) : node->evaluate_lower(outputs)) {
@ -176,11 +169,12 @@ ov::Tensor equality_mask(const ov::Tensor& tensor, const std::shared_ptr<op::v0:
}
ov::Tensor or_tensor(const ov::Tensor& lhs, const ov::Tensor& rhs) {
auto outs = ov::TensorVector{{lhs.get_element_type(), Shape{0}}};
op::v1::LogicalOr(std::make_shared<op::v0::Parameter>(lhs.get_element_type(), lhs.get_shape()),
std::make_shared<op::v0::Parameter>(rhs.get_element_type(), rhs.get_shape()),
ngraph::op::AutoBroadcastType::NUMPY)
.evaluate(outs, ov::TensorVector{lhs, rhs});
auto logical_or = op::v1::LogicalOr(std::make_shared<op::v0::Parameter>(lhs.get_element_type(), lhs.get_shape()),
std::make_shared<op::v0::Parameter>(rhs.get_element_type(), rhs.get_shape()),
op::AutoBroadcastType::NUMPY);
auto outs = ov::TensorVector{{lhs.get_element_type(), logical_or.get_output_shape(0)}};
logical_or.evaluate(outs, ov::TensorVector{lhs, rhs});
return outs.front();
}
@ -466,8 +460,10 @@ bool ov::default_label_evaluator(const Node* node,
for (size_t i = 0; i < outputs_count; ++i) {
const auto& partial_shape = node->get_output_partial_shape(i);
// Set shape for static or Shape{0} for dynamic to postpone memory allocation
auto shape = partial_shape.is_static() ? partial_shape.to_shape() : Shape{0};
// Set the static shape, or the special dynamic shape to postpone memory allocation.
OPENVINO_SUPPRESS_DEPRECATED_START
auto shape = partial_shape.is_static() ? partial_shape.to_shape() : util::make_dynamic_shape();
OPENVINO_SUPPRESS_DEPRECATED_END
outputs.emplace_back(element::from<label_t>(), shape);
}

View File

@ -27,6 +27,7 @@
#include "openvino/op/util/variable_extension.hpp"
#include "openvino/pass/manager.hpp"
#include "shared_node_info.hpp"
#include "tensor_conversion_util.hpp"
#include "transformations/smart_reshape/smart_reshape.hpp"
using namespace std;
@ -485,50 +486,15 @@ int64_t ov::Model::get_result_index(const Output<const Node>& value) const {
return -1;
}
namespace {
inline ov::Tensor create_tmp_tensor(const ngraph::HostTensorPtr& tensor) {
if (tensor->get_partial_shape().is_static()) {
ov::Shape shape = tensor->get_shape();
return ov::Tensor(tensor->get_element_type(), shape, tensor->get_data_ptr());
} else {
if (tensor->get_element_type().is_dynamic()) {
return {};
} else {
return ov::Tensor(tensor->get_element_type(), {0});
}
}
}
inline ov::TensorVector create_tmp_tensors(const ngraph::HostTensorVector& tensors) {
ov::TensorVector result;
result.reserve(tensors.size());
for (const auto& tensor : tensors) {
result.emplace_back(create_tmp_tensor(tensor));
}
return result;
}
inline void update_output_tensors(const ngraph::HostTensorVector& output_values, const ov::TensorVector& outputs) {
OPENVINO_ASSERT(output_values.size(), outputs.size());
for (size_t i = 0; i < outputs.size(); i++) {
const auto& tensor = output_values[i];
if (tensor->get_partial_shape().is_dynamic()) {
tensor->set_element_type(outputs[i].get_element_type());
tensor->set_shape(outputs[i].get_shape());
void* dst_data = tensor->get_data_ptr();
memcpy(dst_data, outputs[i].data(), tensor->get_size_in_bytes());
}
}
}
} // namespace
bool ov::Model::evaluate(const HostTensorVector& output_tensors,
const HostTensorVector& input_tensors,
EvaluationContext evaluation_context) const {
ov::TensorVector outputs = create_tmp_tensors(output_tensors);
ov::TensorVector inputs = create_tmp_tensors(input_tensors);
OPENVINO_SUPPRESS_DEPRECATED_START
auto outputs = ov::util::wrap_tensors(output_tensors);
auto inputs = ov::util::wrap_tensors(input_tensors);
bool sts = evaluate(outputs, inputs, std::move(evaluation_context));
update_output_tensors(output_tensors, outputs);
ov::util::update_output_host_tensors(output_tensors, outputs);
OPENVINO_SUPPRESS_DEPRECATED_END
return sts;
}
@ -560,13 +526,7 @@ bool ov::Model::evaluate(ov::TensorVector& output_tensors,
for (const auto& v : node->outputs()) {
auto it = output_tensor_map.find(v);
if (it == output_tensor_map.end()) {
if (v.get_partial_shape().is_dynamic() || v.get_element_type().is_dynamic()) {
ov::Tensor c = create_tmp_tensor(std::make_shared<HostTensor>(v));
output_tensors.push_back(c);
} else {
ov::Tensor c(v.get_element_type(), v.get_shape());
output_tensors.push_back(c);
}
output_tensors.push_back(util::wrap_tensor(v));
} else {
output_tensors.push_back(it->second);
}
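
For context, a sketch of the public entry point this change feeds; the tiny Abs model and its data are illustrative only:

// Sketch: evaluating a small model through the ov::Tensor-based overload.
auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{2, 2});
auto abs = std::make_shared<ov::op::v0::Abs>(param);
auto model = std::make_shared<ov::Model>(ov::OutputVector{abs}, ov::ParameterVector{param});

float in_data[] = {-1.0f, 2.0f, -3.0f, 4.0f};
ov::TensorVector inputs{ov::Tensor(ov::element::f32, ov::Shape{2, 2}, in_data)};
ov::TensorVector outputs{ov::Tensor(ov::element::f32, ov::Shape{2, 2})};
model->evaluate(outputs, inputs);  // outputs[0] now holds {1, 2, 3, 4}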

View File

@ -20,7 +20,9 @@
#include "ngraph/pattern/matcher.hpp"
#include "openvino/core/descriptor/input.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "shape_util.hpp"
#include "shared_node_info.hpp"
#include "tensor_conversion_util.hpp"
using namespace std;
@ -693,30 +695,23 @@ protected:
}
};
inline ov::Tensor create_tensor_from_output(const ov::Output<ov::Node>& output) {
if (output.get_element_type().is_dynamic()) {
return ov::Tensor();
} else if (output.get_partial_shape().is_dynamic()) {
return ov::Tensor(output.get_element_type(), {0});
inline ngraph::HostTensorPtr make_tmp_host_tensor(const ov::Tensor& t) {
OPENVINO_SUPPRESS_DEPRECATED_START
if (!t) {
return std::make_shared<DynamicTensor>(ov::element::dynamic);
} else if (ov::util::is_dynamic_shape(t.get_shape())) {
return std::make_shared<DynamicTensor>(t.get_element_type());
} else {
return std::make_shared<ngraph::runtime::HostTensor>(t.get_element_type(), t.get_shape(), t.data());
}
return ov::Tensor(output.get_element_type(), output.get_shape());
OPENVINO_SUPPRESS_DEPRECATED_END
}
inline ngraph::HostTensorVector create_tmp_tensors(const ov::TensorVector& tensors) {
ngraph::HostTensorVector result;
result.reserve(tensors.size());
for (const auto& tensor : tensors) {
if (!tensor || tensor.get_shape() == ov::Shape{0}) {
auto el_type = ov::element::dynamic;
if (tensor)
el_type = tensor.get_element_type();
// Create dynamic tensor
result.emplace_back(std::make_shared<DynamicTensor>(el_type));
} else {
result.emplace_back(std::make_shared<ngraph::runtime::HostTensor>(tensor.get_element_type(),
tensor.get_shape(),
tensor.data()));
}
result.push_back(make_tmp_host_tensor(tensor));
}
return result;
}
@ -800,11 +795,11 @@ bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& in
}
TensorVector output_tensors;
OPENVINO_SUPPRESS_DEPRECATED_START
for (const auto& output : outputs()) {
output_tensors.push_back(create_tensor_from_output(output));
output_tensors.push_back(ov::util::wrap_tensor(output));
}
OPENVINO_SUPPRESS_DEPRECATED_END
if (evaluate(output_tensors, input_tensors)) {
for (size_t i = 0; i < output_tensors.size(); ++i) {
output_values[i] = make_shared<ngraph::op::Constant>(output_tensors[i]);

View File

@ -16,6 +16,7 @@
#include "ngraph/op/select.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/divide.hpp"
#include "shape_util.hpp"
using namespace std;
using namespace ngraph;
@ -72,12 +73,12 @@ ov::Tensor equality_mask(const ov::Tensor& tensor, const shared_ptr<op::Constant
}
ov::Tensor or_tensor(const ov::Tensor& lhs, const ov::Tensor& rhs) {
auto outs = ov::TensorVector{{lhs.get_element_type(), Shape{0}}};
auto logical_or = op::v1::LogicalOr(std::make_shared<op::Parameter>(lhs.get_element_type(), lhs.get_shape()),
std::make_shared<op::Parameter>(rhs.get_element_type(), rhs.get_shape()),
ngraph::op::AutoBroadcastType::NUMPY);
op::v1::LogicalOr(std::make_shared<op::Parameter>(lhs.get_element_type(), lhs.get_shape()),
std::make_shared<op::Parameter>(rhs.get_element_type(), rhs.get_shape()),
ngraph::op::AutoBroadcastType::NUMPY)
.evaluate(outs, ov::TensorVector{lhs, rhs});
auto outs = ov::TensorVector{{lhs.get_element_type(), logical_or.get_output_shape(0)}};
logical_or.evaluate(outs, ov::TensorVector{lhs, rhs});
return outs.front();
}
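
The rewritten helper sizes outs from the op's inferred output shape rather than the old Shape{0} placeholder; a quick illustration of what that inference yields under NUMPY broadcasting (the shapes here are illustrative):

// {2,1,3} | {1,5,3} broadcasts to {2,5,3}, so the output tensor can be
// allocated with the right size before evaluate() runs.
auto a = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, ov::Shape{2, 1, 3});
auto b = std::make_shared<ov::op::v0::Parameter>(ov::element::boolean, ov::Shape{1, 5, 3});
auto logical_or = ov::op::v1::LogicalOr(a, b, ov::op::AutoBroadcastType::NUMPY);
OPENVINO_ASSERT(logical_or.get_output_shape(0) == (ov::Shape{2, 5, 3}));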

View File

@ -9,6 +9,7 @@
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/op/softsign.hpp"
#include "openvino/runtime/tensor.hpp"
#include "shape_util.hpp"
namespace {
template <ov::element::Type_t ET>
@ -90,6 +91,7 @@ bool ov::op::v9::SoftSign::evaluate(ov::TensorVector& outputs,
const auto& in = inputs[0];
auto& out = outputs[0];
out.set_shape(in.get_shape());
return evaluate_softsign(in, out);
}
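
The added set_shape() call is what materializes an output that arrived as an unallocated sentinel; a sketch of the end-to-end effect (the parameter, data, and shapes are illustrative):

// Sketch: output starts as the dynamic sentinel and is allocated inside evaluate().
auto arg = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape::dynamic());
auto softsign = std::make_shared<ov::op::v9::SoftSign>(arg);

float data[] = {-1.0f, 0.0f, 3.0f};
ov::TensorVector ins{ov::Tensor(ov::element::f32, ov::Shape{3}, data)};
OPENVINO_SUPPRESS_DEPRECATED_START
ov::TensorVector outs{ov::Tensor(ov::element::f32, ov::util::make_dynamic_shape())};
OPENVINO_SUPPRESS_DEPRECATED_END
softsign->evaluate(outs, ins);  // outs[0]: shape {3}, values x / (1 + |x|)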

View File

@ -9,6 +9,7 @@
#include "openvino/core/except.hpp"
#include "openvino/runtime/tensor.hpp"
#include "runtime/blob_allocator.hpp"
#include "shape_util.hpp"
namespace ov {
@ -98,7 +99,14 @@ element::Type Tensor::get_element_type() const {
}
void Tensor::set_shape(const ov::Shape& shape) {
OV_TENSOR_STATEMENT(_impl->setShape({shape.begin(), shape.end()}));
// Workaround for tensors converted from a host tensor with dynamic shape.
if (util::is_dynamic_shape(get_shape())) {
_impl = make_blob_with_precision(
{_impl->getTensorDesc().getPrecision(), shape, ie::TensorDesc::getLayoutByRank(shape.size())});
_impl->allocate();
} else {
OV_TENSOR_STATEMENT(_impl->setShape({shape.begin(), shape.end()}));
}
}
Shape Tensor::get_shape() const {

View File

@ -6,6 +6,8 @@
#include <algorithm>
#include "shape_util.hpp"
using namespace ngraph;
template <>
@ -73,3 +75,18 @@ PartialShape ngraph::inject_pairs(const PartialShape& shape,
return PartialShape{result_dims};
}
}
namespace ov {
namespace util {
Shape make_dynamic_shape() {
return Shape{0, std::numeric_limits<size_t>::max()};
}
bool is_dynamic_shape(const Shape& s) {
OPENVINO_SUPPRESS_DEPRECATED_START
static const auto dyn_shape = make_dynamic_shape();
OPENVINO_SUPPRESS_DEPRECATED_END
return s == dyn_shape;
}
} // namespace util
} // namespace ov

View File

@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/shape.hpp"
namespace ov {
namespace util {
/**
* @brief Makes a special version of a 2-D ov::Shape which is recognized as dynamic.
*
* This is a special case used for tensor <-> host tensor conversion to indicate that a tensor has a dynamic shape.
*
* @return 2-D shape with {0, SIZE_MAX}
*/
OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
Shape make_dynamic_shape();
/**
* @brief Check if Shape is marked as dynamic.
*
* @param s Shape to check.
* @return True if the shape is dynamic, otherwise false.
*/
OPENVINO_DEPRECATED("This function is deprecated and will be removed soon.")
bool is_dynamic_shape(const Shape& s);
} // namespace util
} // namespace ov
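
A sketch of the sentinel's contract, including the Tensor::set_shape workaround above; these helpers are core-internal, and the snippet only illustrates the intended semantics:

OPENVINO_SUPPRESS_DEPRECATED_START
const auto dyn = ov::util::make_dynamic_shape();             // Shape{0, SIZE_MAX}
OPENVINO_ASSERT(ov::util::is_dynamic_shape(dyn));
OPENVINO_ASSERT(!ov::util::is_dynamic_shape(ov::Shape{0}));  // plain {0} is a real empty tensor
OPENVINO_ASSERT(ov::shape_size(dyn) == 0);                   // zero elements, so nothing is allocated

ov::Tensor t(ov::element::f32, dyn);  // sentinel tensor, no usable data
t.set_shape(ov::Shape{2, 3});         // recognized as sentinel -> fresh blob allocated
OPENVINO_ASSERT(t.get_shape() == (ov::Shape{2, 3}));
OPENVINO_SUPPRESS_DEPRECATED_END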

View File

@ -0,0 +1,62 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "tensor_conversion_util.hpp"
#include "shape_util.hpp"
namespace ov {
namespace util {
OPENVINO_SUPPRESS_DEPRECATED_START
Tensor wrap_tensor(const ngraph::HostTensorPtr& t) {
const auto& et = t->get_element_type();
const auto& p_shape = t->get_partial_shape();
if (et.is_dynamic()) {
return {};
} else if (p_shape.is_static()) {
return {et, p_shape.to_shape(), t->get_data_ptr()};
} else {
return {et, make_dynamic_shape()};
}
}
Tensor wrap_tensor(const Output<Node>& output) {
const auto& et = output.get_element_type();
const auto& p_shape = output.get_partial_shape();
if (et.is_dynamic()) {
return {};
} else if (p_shape.is_static()) {
return {et, p_shape.to_shape()};
} else {
return {et, make_dynamic_shape()};
}
}
ov::TensorVector wrap_tensors(const std::vector<ngraph::HostTensorPtr>& tensors) {
ov::TensorVector out;
out.reserve(tensors.size());
for (const auto& ht : tensors) {
out.push_back(ov::util::wrap_tensor(ht));
}
return out;
}
void update_output_host_tensors(const std::vector<ngraph::HostTensorPtr>& output_values,
const ov::TensorVector& outputs) {
OPENVINO_ASSERT(output_values.size() == outputs.size());
for (size_t i = 0; i < output_values.size(); ++i) {
auto& ht = output_values[i];
auto& t = outputs[i];
if (ht->get_partial_shape().is_dynamic()) {
ht->set_element_type(t.get_element_type());
ht->set_shape(t.get_shape());
std::memcpy(ht->get_data_ptr(), t.data(), t.get_byte_size());
}
}
}
OPENVINO_SUPPRESS_DEPRECATED_END
} // namespace util
} // namespace ov

View File

@ -7,6 +7,7 @@
#include "gmock/gmock.h"
#include "openvino/opsets/opset9.hpp"
#include "sequnce_generator.hpp"
#include "transpose_shape_inference.hpp"
#include "type_prop.hpp"
namespace {
@ -21,26 +22,27 @@ using namespace ov;
using namespace ov::opset9;
using namespace testing;
using TestParam = std::tuple<std::vector<int32_t>, PartialShape>;
using TestParam = std::tuple<std::vector<int32_t>, Shape>;
class TransposeEvalBoundTest : public TestWithParam<TestParam> {
protected:
void SetUp() override {
std::tie(axes_order, p_shape) = GetParam();
std::tie(axes_order, shape) = GetParam();
std::generate_n(std::back_inserter(lower_values), shape_size(p_shape.get_min_shape()), SeqGen<int32_t>(-10));
std::generate_n(std::back_inserter(upper_values), shape_size(p_shape.get_min_shape()), SeqGen<int32_t>(20));
std::generate_n(std::back_inserter(lower_values), shape_size(shape), SeqGen<int32_t>(-10));
std::generate_n(std::back_inserter(upper_values), shape_size(shape), SeqGen<int32_t>(20));
lower_v_tensor = ov::Tensor(dtype, p_shape.get_min_shape(), lower_values.data());
upper_v_tensor = ov::Tensor(dtype, p_shape.get_min_shape(), upper_values.data());
lower_v_tensor = ov::Tensor(dtype, shape, lower_values.data());
upper_v_tensor = ov::Tensor(dtype, shape, upper_values.data());
axes_v_tensor = ov::Tensor(dtype, Shape{axes_order.size()}, axes_order.data());
arg = std::make_shared<Parameter>(dtype, p_shape);
arg = std::make_shared<Parameter>(dtype, shape);
order = std::make_shared<Parameter>(dtype, Shape{axes_order.size()});
transpose = std::make_shared<Transpose>(arg, order);
// prepare result tensors for evaluation
result = exp_result = TensorVector{Tensor(dtype, {0})};
auto a = std::vector<int64_t>(axes_order.begin(), axes_order.end());
result = exp_result = TensorVector{Tensor(dtype, op::v1::calc_output_shape(transpose.get(), shape, a))};
}
void node_set_lower_and_upper(Node* node, const ov::Tensor& lower, const ov::Tensor& upper) {
@ -53,7 +55,7 @@ protected:
}
}
PartialShape p_shape;
Shape shape;
element::Type dtype{element::from<int32_t>()};
element::Type label_dtype{element::from<label_t>()};
@ -69,19 +71,19 @@ protected:
INSTANTIATE_TEST_SUITE_P(evaluate_bound,
TransposeEvalBoundTest,
Values(std::make_tuple(std::vector<int32_t>{0}, PartialShape{4}),
std::make_tuple(std::vector<int32_t>{0, 1}, PartialShape{2, 5}),
std::make_tuple(std::vector<int32_t>{1, 0}, PartialShape{2, 5}),
std::make_tuple(std::vector<int32_t>{0, 1, 2}, PartialShape{2, 3, 1}),
std::make_tuple(std::vector<int32_t>{1, 2, 0}, PartialShape{2, 3, 1}),
std::make_tuple(std::vector<int32_t>{1, 3, 2, 0}, PartialShape{2, 3, 1, 5})),
Values(std::make_tuple(std::vector<int32_t>{0}, Shape{4}),
std::make_tuple(std::vector<int32_t>{0, 1}, Shape{2, 5}),
std::make_tuple(std::vector<int32_t>{1, 0}, Shape{2, 5}),
std::make_tuple(std::vector<int32_t>{0, 1, 2}, Shape{2, 3, 1}),
std::make_tuple(std::vector<int32_t>{1, 2, 0}, Shape{2, 3, 1}),
std::make_tuple(std::vector<int32_t>{1, 3, 2, 0}, Shape{2, 3, 1, 5})),
PrintToStringParamName());
TEST_P(TransposeEvalBoundTest, evaluate_lower) {
node_set_lower_and_upper(arg.get(), lower_v_tensor, upper_v_tensor);
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
const auto inputs = TensorVector{Tensor(dtype, p_shape.get_min_shape(), lower_values.data()),
const auto inputs = TensorVector{Tensor(dtype, shape, lower_values.data()),
Tensor(dtype, Shape{axes_order.size()}, axes_order.data())};
// evaluate expected values
const auto exp_evaluate = transpose->evaluate(exp_result, inputs);
@ -108,7 +110,7 @@ TEST_P(TransposeEvalBoundTest, evaluate_upper) {
node_set_lower_and_upper(arg.get(), lower_v_tensor, upper_v_tensor);
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
auto inputs = TensorVector{Tensor(dtype, p_shape.get_min_shape(), upper_values.data()),
auto inputs = TensorVector{Tensor(dtype, shape, upper_values.data()),
Tensor(dtype, Shape{axes_order.size()}, axes_order.data())};
// evaluate expected values
transpose->evaluate(exp_result, inputs);
@ -132,9 +134,9 @@ TEST_P(TransposeEvalBoundTest, evaluate_upper_but_order_has_no_bounds_set) {
}
TEST_P(TransposeEvalBoundTest, evaluate_label_but_empty_label_set) {
exp_result = TensorVector{Tensor(label_dtype, {0})};
exp_result = TensorVector{Tensor(label_dtype, exp_result.front().get_shape())};
labels.resize(shape_size(p_shape.get_shape()), 0);
labels.resize(shape_size(shape), 0);
arg->get_default_output().get_tensor().set_value_label(labels);
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
@ -143,23 +145,23 @@ TEST_P(TransposeEvalBoundTest, evaluate_label_but_empty_label_set) {
}
TEST_P(TransposeEvalBoundTest, evaluate_label_but_order_has_no_bound_set) {
exp_result = TensorVector{Tensor(label_dtype, {0})};
exp_result = TensorVector{Tensor(label_dtype, exp_result.front().get_shape())};
std::generate_n(std::back_inserter(labels), shape_size(p_shape.get_shape()), SeqGen<label_t>(30));
std::generate_n(std::back_inserter(labels), shape_size(shape), SeqGen<label_t>(30));
arg->get_default_output().get_tensor().set_value_label(labels);
ASSERT_FALSE(transpose->evaluate_label(out_labels));
}
TEST_P(TransposeEvalBoundTest, evaluate_label) {
exp_result = TensorVector{Tensor(label_dtype, {0})};
exp_result = TensorVector{Tensor(label_dtype, exp_result.front().get_shape())};
std::generate_n(std::back_inserter(labels), shape_size(p_shape.get_shape()), SeqGen<label_t>(5));
std::generate_n(std::back_inserter(labels), shape_size(shape), SeqGen<label_t>(5));
arg->get_default_output().get_tensor().set_value_label(labels);
node_set_lower_and_upper(order.get(), axes_v_tensor, axes_v_tensor);
auto inputs = TensorVector{Tensor(label_dtype, p_shape.get_shape(), labels.data()),
auto inputs = TensorVector{Tensor(label_dtype, shape, labels.data()),
Tensor(dtype, Shape{axes_order.size()}, axes_order.data())};
auto exp_eval_result = transpose->evaluate(exp_result, inputs);

View File

@ -360,7 +360,9 @@ TEST_F(OVExtensionTests, load_old_extension) {
TEST_F(OVExtensionTests, load_incorrect_extension) {
EXPECT_THROW(core.add_extension(getIncorrectExtensionPath()), ov::Exception);
}
TEST_F(OVExtensionTests, load_relative) {
EXPECT_NO_THROW(core.add_extension(getRelativeOVExtensionPath()));
}
#endif // defined(ENABLE_OV_IR_FRONTEND)

View File

@ -13,6 +13,7 @@
#include "ngraph/type/bfloat16.hpp"
#include "ngraph/type/float16.hpp"
#include "ngraph/util.hpp"
#include "tensor_conversion_util.hpp"
using namespace std;
using namespace ngraph;
@ -130,21 +131,6 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
op_outputs.push_back(host_tensor);
}
// get op type
element::Type type;
if (ov::is_type<op::Convert>(op) || ov::is_type<op::v0::PriorBox>(op) || ov::is_type<op::v8::PriorBox>(op)) {
type = op->get_input_element_type(0);
} else if (ov::is_type<op::v1::Equal>(op) || ov::is_type<op::v1::Greater>(op) ||
ov::is_type<op::v1::GreaterEqual>(op) || ov::is_type<op::v1::Less>(op) ||
ov::is_type<op::v1::LessEqual>(op) || ov::is_type<op::v1::NotEqual>(op)) {
// Get the type of the second input, not the first
// All BinaryElementwiseComparision ops have the same type for inputs
// Select has bool for first input and the type we are interested in for the second
type = op->get_input_element_type(1);
} else {
type = op->get_output_element_type(0);
}
if (m_performance_counters_enabled) {
m_timer_map[op].start();
}
@ -159,8 +145,13 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
}
}
const auto tensor_inputs = ov::util::wrap_tensors(op_inputs);
auto tensor_outputs = ov::util::wrap_tensors(op_outputs);
// Call evaluate for cloned_node with static shapes
if (!cloned_node->evaluate(op_outputs, op_inputs, eval_context)) {
if (cloned_node->evaluate(tensor_outputs, tensor_inputs, eval_context)) {
ov::util::update_output_host_tensors(op_outputs, tensor_outputs);
} else {
evaluate_node(cloned_node, op_outputs, op_inputs);
}
if (m_performance_counters_enabled) {