Refactor StaticShapeLoopLayerTest (#20963)
This commit is contained in:
parent
bcb38796ce
commit
4bde741de4
@ -2,139 +2,124 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "single_layer_tests/loop.hpp"
|
||||
#include "single_op_tests/loop.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace {
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::I32
|
||||
};
|
||||
using ov::test::StaticShapeLoopLayerTest;
|
||||
|
||||
std::map<std::string, std::string> netConfigurations = {
|
||||
{GPUConfigParams::KEY_GPU_ENABLE_LOOP_UNROLLING, PluginConfigParams::NO}
|
||||
};
|
||||
std::vector<ov::element::Type> model_types = {
|
||||
ov::element::f32,
|
||||
ov::element::i32
|
||||
};
|
||||
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_axis_0 {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 10, -1, 0 }, // n_iter 10, no dynamic exit
|
||||
};
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_axis_0 {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 10, -1, 0 }, // n_iter 10, no dynamic exit
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::SizeVector> inputs_0 = {
|
||||
{1, 4, 2}
|
||||
};
|
||||
std::vector<ov::Shape> inputs_0 = {
|
||||
{1, 4, 2}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_axis_0, StaticShapeLoopTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_axis_0),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_0),
|
||||
/* data_prc */ testing::ValuesIn(netPrecisions),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU),
|
||||
/* configuration */ testing::Values<std::map<std::string, std::string>>(netConfigurations)),
|
||||
StaticShapeLoopTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_axis_0, StaticShapeLoopLayerTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_axis_0),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_0),
|
||||
/* data_prc */ testing::ValuesIn(model_types),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU)),
|
||||
StaticShapeLoopLayerTest::getTestCaseName);
|
||||
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_1 {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 5, -1, 1 }, // n_iter 5, no dynamic exit
|
||||
};
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_1 {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 5, -1, 1 }, // n_iter 5, no dynamic exit
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::SizeVector> inputs_1 = {
|
||||
{2, 1, 4, 6}
|
||||
};
|
||||
std::vector<ov::Shape> inputs_1 = {
|
||||
{2, 1, 4, 6}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_axis_1, StaticShapeLoopTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_1),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_1),
|
||||
/* data_prc */ testing::ValuesIn(netPrecisions),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU),
|
||||
/* configuration */ testing::Values<std::map<std::string, std::string>>(netConfigurations)),
|
||||
StaticShapeLoopTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_axis_1, StaticShapeLoopLayerTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_1),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_1),
|
||||
/* data_prc */ testing::ValuesIn(model_types),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU)),
|
||||
StaticShapeLoopLayerTest::getTestCaseName);
|
||||
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_2 {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 10, -1, 2 }, // n_iter 10, no dynamic exit
|
||||
};
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_2 {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 10, -1, 2 }, // n_iter 10, no dynamic exit
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::SizeVector> inputs_2 = {
|
||||
{2, 4, 1, 6}
|
||||
};
|
||||
std::vector<ov::Shape> inputs_2 = {
|
||||
{2, 4, 1, 6}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_axis_2, StaticShapeLoopTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_2),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_2),
|
||||
/* data_prc */ testing::ValuesIn(netPrecisions),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU),
|
||||
/* configuration */ testing::Values<std::map<std::string, std::string>>(netConfigurations)),
|
||||
StaticShapeLoopTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_axis_2, StaticShapeLoopLayerTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_2),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_2),
|
||||
/* data_prc */ testing::ValuesIn(model_types),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU)),
|
||||
StaticShapeLoopLayerTest::getTestCaseName);
|
||||
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_no_auto_concat {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 10, -1, -1 }, // n_iter 5, no dynamic exit
|
||||
};
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_no_auto_concat {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 10, -1, -1 }, // n_iter 5, no dynamic exit
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::SizeVector> inputs_no_auto_concat = {
|
||||
{4, 20, 12}
|
||||
};
|
||||
std::vector<ov::Shape> inputs_no_auto_concat = {
|
||||
{4, 20, 12}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_no_auto_concat, StaticShapeLoopTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_no_auto_concat),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_no_auto_concat),
|
||||
/* data_prc */ testing::ValuesIn(netPrecisions),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU),
|
||||
/* configuration */ testing::Values<std::map<std::string, std::string>>(netConfigurations)),
|
||||
StaticShapeLoopTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_no_auto_concat, StaticShapeLoopLayerTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_no_auto_concat),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_no_auto_concat),
|
||||
/* data_prc */ testing::ValuesIn(model_types),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU)),
|
||||
StaticShapeLoopLayerTest::getTestCaseName);
|
||||
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_dynamic_exit {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 5, 3, -1 }, // n_iter 3, dynamic exit on 3
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 5, 7, 1 }, // n_iter 5, dynamic exit not reached
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , -1, 5, -1 }, // n_iter 5, inf loop with dynamic exit on 5
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ false , 5, 3, -1 }, // | same with dynamic trip count
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ false , 5, 7, 1 }, // |
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ false , -1, 5, -1 } // |
|
||||
};
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types_dynamic_exit {
|
||||
// GCC4.8 limitation: have to specify type of each element in list
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 5, 3, -1 }, // n_iter 3, dynamic exit on 3
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , 5, 7, 1 }, // n_iter 5, dynamic exit not reached
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ true , -1, 5, -1 }, // n_iter 5, inf loop with dynamic exit on 5
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ false , 5, 3, -1 }, // | same with dynamic trip count
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ false , 5, 7, 1 }, // |
|
||||
std::tuple<bool, int64_t, int64_t, int64_t>{ false , -1, 5, -1 } // |
|
||||
};
|
||||
|
||||
std::vector<InferenceEngine::SizeVector> inputs_dynamic_exit = {
|
||||
{4, 1, 2}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_dynamic_exit, StaticShapeLoopTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_dynamic_exit),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_dynamic_exit),
|
||||
/* data_prc */ testing::ValuesIn(netPrecisions),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU),
|
||||
/* configuration */ testing::Values<std::map<std::string, std::string>>(netConfigurations)),
|
||||
StaticShapeLoopTest::getTestCaseName);
|
||||
std::vector<ov::Shape> inputs_dynamic_exit = {
|
||||
{4, 1, 2}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_StaticShapeLoop_dynamic_exit, StaticShapeLoopLayerTest,
|
||||
testing::Combine(
|
||||
/* unrolling */ testing::ValuesIn(std::vector<bool>{false}),
|
||||
/* static_continue_cond */ testing::Values(true),
|
||||
/* args_papck */ testing::ValuesIn(static_loop_types_dynamic_exit),
|
||||
/* start_value */ testing::Values<int64_t>(0),
|
||||
/* data_shape */ testing::ValuesIn(inputs_dynamic_exit),
|
||||
/* data_prc */ testing::ValuesIn(model_types),
|
||||
/* device */ testing::Values<std::string>(ov::test::utils::DEVICE_GPU)),
|
||||
StaticShapeLoopLayerTest::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -11,5 +11,9 @@ namespace test {
|
||||
TEST_P(LoopLayerTest, Inference) {
|
||||
run();
|
||||
}
|
||||
|
||||
TEST_P(StaticShapeLoopLayerTest, Inference) {
|
||||
run();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
@ -32,6 +32,33 @@ class LoopLayerTest : public testing::WithParamInterface<LoopParams>,
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<LoopParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
using StaticShapeLoopParams = typename std::tuple<
|
||||
bool,
|
||||
bool,
|
||||
std::tuple<
|
||||
bool,
|
||||
int64_t,
|
||||
int64_t,
|
||||
int64_t
|
||||
>,
|
||||
int64_t,
|
||||
ov::Shape,
|
||||
ov::element::Type,
|
||||
std::string>;
|
||||
|
||||
/**
|
||||
* Test case with static SHAPE version of loop operation.
|
||||
* Total iteration count is dynamic.
|
||||
*/
|
||||
class StaticShapeLoopLayerTest : public testing::WithParamInterface<StaticShapeLoopParams>,
|
||||
virtual public ov::test::SubgraphBaseStaticTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<StaticShapeLoopParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
@ -4,11 +4,15 @@
|
||||
|
||||
#include "shared_test_classes/single_op/loop.hpp"
|
||||
|
||||
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/loop.hpp"
|
||||
#include "openvino/op/less.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
#include "transformations/control_flow/unroll_tensor_iterator.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
@ -121,5 +125,132 @@ void LoopLayerTest::SetUp() {
|
||||
auto result2 = std::make_shared<ov::op::v0::Result>(out2);
|
||||
function = std::make_shared<ov::Model>(ov::ResultVector{result0, result1, result2}, params, "loop");
|
||||
}
|
||||
|
||||
std::string StaticShapeLoopLayerTest::getTestCaseName(const testing::TestParamInfo<StaticShapeLoopParams> &obj) {
|
||||
bool unrolling;
|
||||
bool static_iter_num;
|
||||
bool static_continue_cond;
|
||||
int64_t max_iter_num;
|
||||
int64_t dynamic_exit;
|
||||
int64_t axis;
|
||||
int64_t start_value;
|
||||
ov::Shape data_shape;
|
||||
ov::element::Type model_type;
|
||||
std::string target_device;
|
||||
auto args_papck = std::tie(static_iter_num, max_iter_num, dynamic_exit, axis);
|
||||
std::tie(
|
||||
unrolling,
|
||||
static_continue_cond,
|
||||
args_papck,
|
||||
start_value,
|
||||
data_shape,
|
||||
model_type,
|
||||
target_device) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "unrolling=" << std::to_string(unrolling) << "_";
|
||||
result << "static_iter_num=" << std::to_string(static_iter_num) << "_";
|
||||
result << "static_continue_cond=" << std::to_string(static_continue_cond) << "_";
|
||||
result << "max_iter_num=" << std::to_string(max_iter_num) << "_";
|
||||
result << "dynamic_exit=" << std::to_string(dynamic_exit) << "_";
|
||||
result << "axis=" << std::to_string(axis) << "_";
|
||||
result << "start_value=" << std::to_string(start_value) << "_";
|
||||
result << "max_iter_num=" << std::to_string(max_iter_num) << "_";
|
||||
result << "IS=" << ov::test::utils::vec2str(data_shape) << "_";
|
||||
result << "modelType=" << model_type.get_type_name() << "_";
|
||||
result << "targetDevice=" << target_device << "_";
|
||||
|
||||
auto res_str = result.str();
|
||||
std::replace(res_str.begin(), res_str.end(), '-', '_');
|
||||
return res_str;
|
||||
}
|
||||
|
||||
void StaticShapeLoopLayerTest::SetUp() {
|
||||
bool unrolling;
|
||||
bool static_iter_num;
|
||||
bool static_continue_cond;
|
||||
int64_t max_iter_num;
|
||||
int64_t dynamic_exit;
|
||||
int64_t axis;
|
||||
int64_t start_value;
|
||||
ov::Shape data_shape;
|
||||
ov::element::Type model_type;
|
||||
auto args_papck = std::tie(static_iter_num, max_iter_num, dynamic_exit, axis);
|
||||
std::tie(
|
||||
unrolling,
|
||||
static_continue_cond,
|
||||
args_papck,
|
||||
start_value,
|
||||
data_shape,
|
||||
model_type,
|
||||
targetDevice) = GetParam();
|
||||
|
||||
const auto ngShape = ov::Shape{data_shape};
|
||||
const auto scalarShape = ov::Shape{};
|
||||
|
||||
ngraph::ParameterVector params{};
|
||||
auto cond_input_create = [¶ms] (ov::element::Type model_type, const ov::Shape &shape, int value = 0, bool is_static = false)
|
||||
-> std::shared_ptr<ov::Node> {
|
||||
if (is_static)
|
||||
return std::make_shared<ov::op::v0::Constant>(model_type, shape, value);
|
||||
|
||||
auto input = std::make_shared<ov::op::v0::Parameter>(model_type, shape);
|
||||
params.push_back(input);
|
||||
return input;
|
||||
};
|
||||
|
||||
auto start = cond_input_create(model_type, ngShape);
|
||||
auto count = cond_input_create(ov::element::i64, scalarShape, max_iter_num, static_iter_num);
|
||||
auto skip = cond_input_create(ov::element::boolean, scalarShape, true, static_continue_cond);
|
||||
|
||||
//
|
||||
// count skip start count skip start
|
||||
// / /
|
||||
// ___*___*____ __________*___*____ | idx | data | out |
|
||||
// | idx in | | ex_val idx in | | 0 | 7 | 7 |
|
||||
// | | / | | | / | / | | 1 | 7 | 8 |
|
||||
// | add | | less add | | 2 | 8 | 10 |
|
||||
// | | true | | | | | | 3 | 10 | 13 |
|
||||
// | | | | | | | | ~~~~~ * * * ~~~~~
|
||||
// | out cnd | | cnd out |
|
||||
// |___*____*___| |____*_____*________|
|
||||
// Full loop Dynamic exit loop
|
||||
// n_iter = count n_iter = ex_val
|
||||
//
|
||||
auto b_indx = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{});
|
||||
auto b_data = std::make_shared<ov::op::v0::Parameter>(model_type, ngShape);
|
||||
auto b_indx_cast = std::make_shared<ov::op::v0::Convert>(b_indx, model_type);
|
||||
auto b_add = std::make_shared<ov::op::v1::Add>(b_data, b_indx_cast);
|
||||
|
||||
std::shared_ptr<ov::Node> b_cond;
|
||||
if (dynamic_exit == -1) {
|
||||
b_cond = std::make_shared<ov::op::v0::Constant>(ov::element::boolean, ov::Shape{}, true);
|
||||
} else {
|
||||
auto b_exit_value = std::make_shared<ov::op::v0::Constant>(ov::element::i64, scalarShape, dynamic_exit);
|
||||
b_cond = std::make_shared<ov::op::v1::Less>(b_indx, b_exit_value);
|
||||
}
|
||||
|
||||
auto body = std::make_shared<ov::Model>(
|
||||
ov::OutputVector {b_cond, b_add}, // TODO: check with reverse
|
||||
ov::ParameterVector {b_indx, b_data}); // TODO: check with reverse
|
||||
|
||||
auto loop = std::make_shared<ov::op::v5::Loop>(count, skip);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports({0, 0});
|
||||
loop->set_merged_input(b_data, start, b_add);
|
||||
if (axis == -1)
|
||||
loop->get_iter_value(b_add, -1);
|
||||
else
|
||||
loop->get_concatenated_slices(b_add, 0, 1, 1, -1, axis);
|
||||
|
||||
function = std::make_shared<ov::Model>(
|
||||
ov::OutputVector {loop},
|
||||
params);
|
||||
if (unrolling) {
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::UnrollTensorIterator>();
|
||||
manager.run_passes(function);
|
||||
}
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ov
|
||||
|
Loading…
Reference in New Issue
Block a user