[TEST] Several more Loop test with static shapes
Signed-off-by: Alexander Peskov <alexander.peskov@intel.com>
This commit is contained in:
parent
3160290e13
commit
d7e3e92b64
@ -3,7 +3,6 @@
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "single_layer_tests/loop.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
@ -12,9 +11,9 @@ using namespace LayerTestsDefinitions;
|
||||
namespace {
|
||||
// without clip values increase rapidly, so use only seq_lenghts = 2
|
||||
std::vector<bool> execute_first_iteration{true};
|
||||
std::vector<bool> is_body_condition_const{true, false};
|
||||
std::vector<bool> body_condition{true, false}; // works only if is_body_condition_const == true
|
||||
std::vector<int64_t> trip_count{1, 10, -1}; // -1 means infinity
|
||||
std::vector<bool> is_body_condition_const{true/*, false*/};
|
||||
std::vector<bool> body_condition{true/*, false*/}; // works only if is_body_condition_const == true
|
||||
std::vector<int64_t> trip_count{1, 10/*, -1*/}; // -1 means infinity
|
||||
std::vector<std::vector<std::pair<std::vector<size_t>, LOOP_IN_TYPE>>> inputs = {
|
||||
{{{32, 1, 10}, LOOP_IN_TYPE::INVARIANT}, {{32, 1, 10}, LOOP_IN_TYPE::INVARIANT}, {{32, 1, 10}, LOOP_IN_TYPE::MERGED}},
|
||||
};
|
||||
@ -31,4 +30,27 @@ namespace {
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
LoopTest::getTestCaseName);
|
||||
|
||||
static const std::vector<std::tuple<bool, int64_t, int64_t, int64_t>> static_loop_types = {
|
||||
// static_trip_count | max | dynamic_exit | axis
|
||||
{ true , 5, -1, -1 }, // n_iter 5, no dynamic exit
|
||||
{ true , 5, 3, -1 }, // n_iter 3, dynamic exit on 3
|
||||
{ true , 5, 7, -1 }, // n_iter 5, dynamic exit not reached
|
||||
{ true , -1, 5, -1 }, // n_iter 5, inf loop with dynamic exit on 5
|
||||
{ true , 5, -1, 1 }, // n_iter 5, const for loop with auto concatenated out
|
||||
{ false , 5, -1, -1 }, // |
|
||||
{ false , 5, 3, -1 }, // | same with dynamic trip count
|
||||
{ false , 5, 7, -1 }, // |
|
||||
{ false , -1, 5, -1 }, // |
|
||||
};
|
||||
|
||||
using namespace testing;
|
||||
INSTANTIATE_TEST_CASE_P(smoke_StaticShapeLoop, StaticShapeLoopTest,
|
||||
Combine(
|
||||
Values(true),
|
||||
ValuesIn(static_loop_types),
|
||||
Values<int64_t>(7),
|
||||
Values<InferenceEngine::SizeVector>({2, 1, 4}),
|
||||
Values<InferenceEngine::Precision>(InferenceEngine::Precision::FP32, InferenceEngine::Precision::I32),
|
||||
Values(CommonTestUtils::DEVICE_CPU)));
|
||||
} // namespace
|
||||
|
@ -29,7 +29,7 @@ using LoopParams = typename std::tuple<
|
||||
std::string>; // Device name
|
||||
|
||||
class LoopTest : public testing::WithParamInterface<LoopParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<LoopParams> &obj);
|
||||
|
||||
@ -37,4 +37,46 @@ protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
|
||||
using StaticShapeLoopParams = typename std::tuple<
|
||||
bool,
|
||||
std::tuple<
|
||||
bool,
|
||||
int64_t,
|
||||
int64_t,
|
||||
int64_t
|
||||
>,
|
||||
int64_t,
|
||||
InferenceEngine::SizeVector,
|
||||
InferenceEngine::Precision,
|
||||
std::string
|
||||
>;
|
||||
|
||||
/**
|
||||
* Test case with static SHAPE version of loop operation.
|
||||
* Total iteration count is dynamic.
|
||||
*/
|
||||
class StaticShapeLoopTest : public testing::WithParamInterface<StaticShapeLoopParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<StaticShapeLoopParams> &obj);
|
||||
InferenceEngine::Blob::Ptr GenerateInput(const InferenceEngine::InputInfo &info) const override;
|
||||
std::vector<std::vector<std::uint8_t>> CalculateRefs() override;
|
||||
|
||||
private:
|
||||
bool static_iter_num; // trip count provided by constant node
|
||||
bool static_continue_cond; // initial_cond provided by constant node
|
||||
int64_t max_iter_num; // -1 means infinity loop (expected dynamic exit condition in body)
|
||||
int64_t dynamic_exit; // -1 means always true
|
||||
int64_t axis; // -1 means no auto concatenation
|
||||
int64_t start_value;
|
||||
InferenceEngine::SizeVector data_shape;
|
||||
InferenceEngine::Precision data_prc;
|
||||
|
||||
int64_t actual_n_iter();
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
@ -157,5 +157,139 @@ namespace LayerTestsDefinitions {
|
||||
|
||||
TEST_P(LoopTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
}
|
||||
|
||||
void StaticShapeLoopTest::SetUp() {
|
||||
SKIP_IF_CURRENT_TEST_IS_DISABLED()
|
||||
SetRefMode(LayerTestsUtils::IE);
|
||||
|
||||
auto args_papck = std::tie(static_iter_num, max_iter_num, dynamic_exit, axis);
|
||||
std::tie(
|
||||
static_continue_cond,
|
||||
args_papck,
|
||||
start_value,
|
||||
data_shape,
|
||||
data_prc,
|
||||
targetDevice) = GetParam();
|
||||
|
||||
const auto prc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(data_prc);
|
||||
const auto ngShape = ngraph::Shape{data_shape};
|
||||
const auto scalarShape = ngraph::Shape{};
|
||||
|
||||
ngraph::ParameterVector params{};
|
||||
auto cond_input_create = [¶ms] (ngraph::element::Type prc, const ngraph::Shape &shape, int value = 0, bool is_static = false)
|
||||
-> std::shared_ptr<ngraph::Node> {
|
||||
if (is_static)
|
||||
return std::make_shared<ngraph::opset5::Constant>(prc, shape, value);
|
||||
|
||||
auto input = std::make_shared<ngraph::op::Parameter>(prc, shape);
|
||||
params.push_back(input);
|
||||
return input;
|
||||
};
|
||||
|
||||
auto start = cond_input_create(prc, ngShape);
|
||||
auto count = cond_input_create(ngraph::element::i64, scalarShape, max_iter_num, static_iter_num);
|
||||
auto skip = cond_input_create(ngraph::element::boolean, scalarShape, true, static_continue_cond);
|
||||
|
||||
//
|
||||
// count skip start count skip start
|
||||
// / /
|
||||
// ___*___*____ __________*___*____ | idx | data | out |
|
||||
// | idx in | | ex_val idx in | | 0 | 7 | 7 |
|
||||
// | | / | | | / | / | | 1 | 7 | 8 |
|
||||
// | add | | less add | | 2 | 8 | 10 |
|
||||
// | | true | | | | | | 3 | 10 | 13 |
|
||||
// | | | | | | | | ~~~~~ * * * ~~~~~
|
||||
// | out cnd | | cnd out |
|
||||
// |___*____*___| |____*_____*________|
|
||||
// Full loop Dynamic exit loop
|
||||
// n_iter = count n_iter = ex_val
|
||||
//
|
||||
auto b_indx = std::make_shared<ngraph::op::Parameter>(ngraph::element::i64, ngraph::Shape{});
|
||||
auto b_data = std::make_shared<ngraph::op::Parameter>(prc, ngShape);
|
||||
auto b_indx_cast = std::make_shared<ngraph::op::Convert>(b_indx, prc);
|
||||
auto b_add = std::make_shared<ngraph::op::Add>(b_data, b_indx_cast, ngraph::op::AutoBroadcastSpec::NUMPY);
|
||||
|
||||
std::shared_ptr<ngraph::Node> b_cond;
|
||||
if (dynamic_exit == -1) {
|
||||
b_cond = std::make_shared<ngraph::opset5::Constant>(ngraph::element::boolean, ngraph::Shape{}, true);
|
||||
} else {
|
||||
auto b_exit_value = std::make_shared<ngraph::opset5::Constant>(ngraph::element::i64, scalarShape, dynamic_exit);
|
||||
b_cond = std::make_shared<ngraph::opset5::Less>(b_indx, b_exit_value);
|
||||
}
|
||||
|
||||
auto body = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector {b_cond, b_add}, // TODO: check with reverse
|
||||
ngraph::ParameterVector {b_indx, b_data}); // TODO: check with reverse
|
||||
|
||||
auto loop = std::make_shared<ngraph::opset5::Loop>(count, skip);
|
||||
loop->set_function(body);
|
||||
loop->set_special_body_ports({0, 0});
|
||||
loop->set_merged_input(b_data, start, b_add);
|
||||
if (axis == -1)
|
||||
loop->get_iter_value(b_add, -1);
|
||||
else
|
||||
loop->get_concatenated_slices(b_add, 0, 1, 1, -1, axis);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(
|
||||
ngraph::OutputVector {loop},
|
||||
params);
|
||||
}
|
||||
|
||||
InferenceEngine::Blob::Ptr StaticShapeLoopTest::GenerateInput(const InferenceEngine::InputInfo &info) const {
|
||||
auto tdesc = info.getTensorDesc();
|
||||
auto blob = make_blob_with_precision(tdesc);
|
||||
blob->allocate();
|
||||
|
||||
if (tdesc.getLayout() == InferenceEngine::SCALAR) {
|
||||
auto scalar_1d = CommonTestUtils::make_reshape_view(blob, {1});
|
||||
CommonTestUtils::fill_data_with_broadcast(scalar_1d, 0, {static_cast<float>(max_iter_num)});
|
||||
} else {
|
||||
CommonTestUtils::fill_data_with_broadcast(blob, 0, {static_cast<float>(start_value)});
|
||||
}
|
||||
|
||||
return blob;
|
||||
}
|
||||
|
||||
int64_t StaticShapeLoopTest::actual_n_iter() {
|
||||
constexpr auto INF_N_ITER = std::numeric_limits<int64_t>::max();
|
||||
IE_ASSERT(dynamic_exit != -1 || max_iter_num != -1);
|
||||
|
||||
// dynamic_exit + 1 - because loop body looks like do-while loop with post condition check.
|
||||
return std::min(dynamic_exit == -1 ? INF_N_ITER : dynamic_exit + 1,
|
||||
max_iter_num == -1 ? INF_N_ITER : max_iter_num);
|
||||
}
|
||||
|
||||
// Predefined ref output
|
||||
std::vector<std::vector<std::uint8_t>> StaticShapeLoopTest::CalculateRefs() {
|
||||
bool auto_concat_out = (axis != -1);
|
||||
const auto n_iter = actual_n_iter();
|
||||
|
||||
auto ref_shape = data_shape;
|
||||
if (auto_concat_out)
|
||||
ref_shape[axis] *= n_iter;
|
||||
|
||||
using namespace CommonTestUtils;
|
||||
InferenceEngine::TensorDesc tdesc {data_prc, ref_shape, InferenceEngine::TensorDesc::getLayoutByDims(ref_shape)};
|
||||
std::vector<uint8_t> res(byte_size(tdesc));
|
||||
auto out = make_blob_with_precision(tdesc, res.data());
|
||||
|
||||
std::vector<float> vals(n_iter);
|
||||
float val = start_value;
|
||||
for (int i = 0; i < n_iter; i++) {
|
||||
val += i;
|
||||
vals[i] = val;
|
||||
}
|
||||
|
||||
if (auto_concat_out)
|
||||
fill_data_with_broadcast(out, axis, vals);
|
||||
else
|
||||
fill_data_with_broadcast(out, 0, {val});
|
||||
|
||||
return {res};
|
||||
}
|
||||
|
||||
TEST_P(StaticShapeLoopTest, CompareWithRefs) {
|
||||
Run();
|
||||
}
|
||||
} // namespace LayerTestsDefinitions
|
||||
|
Loading…
Reference in New Issue
Block a user