Add slt in template plugin/experimental detectron grid generator (#8923)
* Remove fp16 of Convert layer test from skip_tests.config.cpp as it works now * update repo * add op reference test of ExperimentalDetectronPriorGridGenerator * implement actual_comparision_size for compare * update slt for actual comparison size and add visitor api test * fixed clang error
This commit is contained in:
parent
92760949bf
commit
9faf661250
@ -71,92 +71,95 @@ void CommonReferenceTest::Validate() {
|
||||
|
||||
ASSERT_EQ(refOutData.size(), actualOutData.size());
|
||||
for (size_t i = 0; i < refOutData.size(); i++) {
|
||||
ValidateBlobs(refOutData[i], actualOutData[i], threshold, abs_threshold);
|
||||
ValidateBlobs(refOutData[i], actualOutData[i], threshold, abs_threshold, actual_comparision_size);
|
||||
}
|
||||
}
|
||||
|
||||
void CommonReferenceTest::ValidateBlobs(const ov::runtime::Tensor& refBlob, const ov::runtime::Tensor& outBlob,
|
||||
float threshold, float abs_threshold) {
|
||||
float threshold, float abs_threshold, size_t actual_comparision_size) {
|
||||
ASSERT_EQ(refBlob.get_element_type(), outBlob.get_element_type());
|
||||
ASSERT_EQ(refBlob.get_byte_size(), outBlob.get_byte_size());
|
||||
|
||||
if (actual_comparision_size == 0)
|
||||
actual_comparision_size = refBlob.get_size();
|
||||
|
||||
const auto& element_type = refBlob.get_element_type();
|
||||
switch (element_type) {
|
||||
case ov::element::bf16:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<ov::bfloat16, ov::bfloat16>(
|
||||
refBlob.data<const ov::bfloat16>(), outBlob.data<const ov::bfloat16>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::f16:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<ov::float16, ov::float16>(
|
||||
refBlob.data<const ov::float16>(), outBlob.data<const ov::float16>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::f32:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<float, float>(
|
||||
refBlob.data<const float>(), outBlob.data<const float>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::f64:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<double, double>(
|
||||
refBlob.data<const double>(), outBlob.data<const double>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::i8:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<int8_t, int8_t>(
|
||||
refBlob.data<const int8_t>(), outBlob.data<const int8_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::i16:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<int16_t, int16_t>(
|
||||
refBlob.data<const int16_t>(), outBlob.data<const int16_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::i32:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<int32_t, int32_t>(
|
||||
refBlob.data<const int32_t>(), outBlob.data<const int32_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::i64:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<int64_t, int64_t>(
|
||||
refBlob.data<const int64_t>(), outBlob.data<const int64_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::boolean:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<bool, bool>(
|
||||
refBlob.data<const bool>(), outBlob.data<const bool>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::u8:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<uint8_t, uint8_t>(
|
||||
refBlob.data<const uint8_t>(), outBlob.data<const uint8_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::u16:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<uint16_t, uint16_t>(
|
||||
refBlob.data<const uint16_t>(), outBlob.data<const uint16_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::u32:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<uint32_t, uint32_t>(
|
||||
refBlob.data<const uint32_t>(), outBlob.data<const uint32_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::u64:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<uint64_t, uint64_t>(
|
||||
refBlob.data<const uint64_t>(), outBlob.data<const uint64_t>(),
|
||||
refBlob.get_size(), threshold, abs_threshold);
|
||||
actual_comparision_size, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::i4:
|
||||
case ov::element::u4:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<int8_t, int8_t>(
|
||||
static_cast<const int8_t*>(refBlob.data()), static_cast<const int8_t*>(outBlob.data()),
|
||||
refBlob.get_size() / 2, threshold, abs_threshold);
|
||||
actual_comparision_size / 2, threshold, abs_threshold);
|
||||
break;
|
||||
case ov::element::u1:
|
||||
LayerTestsUtils::LayerTestsCommon::Compare<int8_t, int8_t>(
|
||||
static_cast<const int8_t*>(refBlob.data()), static_cast<const int8_t*>(outBlob.data()),
|
||||
refBlob.get_size() / 8, threshold, abs_threshold);
|
||||
actual_comparision_size / 8, threshold, abs_threshold);
|
||||
break;
|
||||
default:
|
||||
FAIL() << "Comparator for " << element_type << " element type isn't supported";
|
||||
|
@ -23,7 +23,7 @@ public:
|
||||
virtual void Validate();
|
||||
|
||||
static void ValidateBlobs(const ov::runtime::Tensor& refBlob, const ov::runtime::Tensor& outBlob,
|
||||
float threshold, float abs_threshold);
|
||||
float threshold, float abs_threshold, size_t actual_comparision_size = 0);
|
||||
|
||||
protected:
|
||||
const std::string targetDevice;
|
||||
@ -37,6 +37,7 @@ protected:
|
||||
std::vector<ov::runtime::Tensor> actualOutData;
|
||||
float threshold = 1e-2f; // Relative diff
|
||||
float abs_threshold = -1.f; // Absolute diff (not used when negative)
|
||||
size_t actual_comparision_size = 0; // For ref output data is smaller than output blob size
|
||||
};
|
||||
|
||||
template <class T>
|
||||
|
@ -0,0 +1,226 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "openvino/op/experimental_detectron_prior_grid_generator.hpp"
|
||||
#include "base_reference_test.hpp"
|
||||
|
||||
using namespace reference_tests;
|
||||
using namespace ov;
|
||||
|
||||
using Attrs = op::v6::ExperimentalDetectronPriorGridGenerator::Attributes;
|
||||
|
||||
namespace {
|
||||
struct ExperimentalPGGParams {
|
||||
template <class IT>
|
||||
ExperimentalPGGParams(const Attrs& attrs,
|
||||
const PartialShape& priorsShape,
|
||||
const PartialShape& featureMapShape,
|
||||
const PartialShape& imageSizeInfoShape,
|
||||
const Shape& outRefShape,
|
||||
const element::Type& iType,
|
||||
const std::vector<IT>& priorsValues,
|
||||
const std::vector<IT>& refValues,
|
||||
const std::string& testcaseName = "")
|
||||
: attrs(attrs),
|
||||
priorsShape(priorsShape),
|
||||
featureMapShape(featureMapShape),
|
||||
imageSizeInfoShape(imageSizeInfoShape),
|
||||
outRefShape(outRefShape),
|
||||
inType(iType),
|
||||
outType(iType),
|
||||
priorsData(CreateTensor(iType, priorsValues)),
|
||||
refData(CreateTensor(outRefShape, iType, refValues)),
|
||||
testcaseName(testcaseName) {
|
||||
std::vector<IT> featureMapValues(shape_size(featureMapShape.get_shape()));
|
||||
std::iota(featureMapValues.begin(), featureMapValues.end(), 0);
|
||||
featureMapData = CreateTensor(iType, featureMapValues);
|
||||
|
||||
std::vector<IT> imageSizeInfoValues(shape_size(imageSizeInfoShape.get_shape()));
|
||||
std::iota(imageSizeInfoValues.begin(), imageSizeInfoValues.end(), 0);
|
||||
imageSizeInfoData = CreateTensor(iType, imageSizeInfoValues);
|
||||
|
||||
if (shape_size(outRefShape) > refValues.size())
|
||||
actualComparisonSize = refValues.size();
|
||||
else
|
||||
actualComparisonSize = 0;
|
||||
}
|
||||
|
||||
Attrs attrs;
|
||||
PartialShape priorsShape;
|
||||
PartialShape featureMapShape;
|
||||
PartialShape imageSizeInfoShape;
|
||||
Shape outRefShape;
|
||||
size_t actualComparisonSize;
|
||||
ov::element::Type inType;
|
||||
ov::element::Type outType;
|
||||
ov::runtime::Tensor priorsData;
|
||||
ov::runtime::Tensor featureMapData;
|
||||
ov::runtime::Tensor imageSizeInfoData;
|
||||
ov::runtime::Tensor refData;
|
||||
std::string testcaseName;
|
||||
};
|
||||
|
||||
class ReferenceExperimentalPGGLayerTest : public testing::TestWithParam<ExperimentalPGGParams>, public CommonReferenceTest {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto params = GetParam();
|
||||
function = CreateFunction(params);
|
||||
inputData = {params.priorsData, params.featureMapData, params.imageSizeInfoData};
|
||||
refOutData = {params.refData};
|
||||
|
||||
if (params.actualComparisonSize > 0)
|
||||
actual_comparision_size = params.actualComparisonSize;
|
||||
}
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<ExperimentalPGGParams>& obj) {
|
||||
auto param = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "priorsShape=" << param.priorsShape << "_";
|
||||
result << "featureMapShape=" << param.featureMapShape << "_";
|
||||
result << "imageSizeInfoShape=" << param.imageSizeInfoShape << "_";
|
||||
result << "iType=" << param.inType << "_";
|
||||
result << "oType=" << param.outType << "_";
|
||||
result << "flatten=" << param.attrs.flatten << "_";
|
||||
result << "h=" << param.attrs.h << "_";
|
||||
result << "w=" << param.attrs.w << "_";
|
||||
result << "stride_x=" << param.attrs.stride_x << "_";
|
||||
result << "stride_y=" << param.attrs.stride_y;
|
||||
if (param.testcaseName != "")
|
||||
result << "_" << param.testcaseName;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
private:
|
||||
static std::shared_ptr<Function> CreateFunction(const ExperimentalPGGParams& params) {
|
||||
const auto priors = std::make_shared<op::v0::Parameter>(params.inType, params.priorsShape);
|
||||
const auto featureMap = std::make_shared<op::v0::Parameter>(params.inType, params.featureMapShape);
|
||||
const auto im_info = std::make_shared<op::v0::Parameter>(params.inType, params.imageSizeInfoShape);
|
||||
const auto ExperimentalPGG = std::make_shared<op::v6::ExperimentalDetectronPriorGridGenerator>(priors,
|
||||
featureMap,
|
||||
im_info,
|
||||
params.attrs);
|
||||
return std::make_shared<ov::Function>(NodeVector {ExperimentalPGG}, ParameterVector {priors, featureMap, im_info});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ReferenceExperimentalPGGLayerTest, CompareWithRefs) {
|
||||
Exec();
|
||||
}
|
||||
|
||||
template <element::Type_t IN_ET>
|
||||
std::vector<ExperimentalPGGParams> generateExperimentalPGGFloatParams() {
|
||||
using T = typename element_type_traits<IN_ET>::value_type;
|
||||
|
||||
std::vector<ExperimentalPGGParams> experimentalPGGParams {
|
||||
ExperimentalPGGParams(Attrs{true, 0, 0, 4.0f, 4.0f},
|
||||
{3, 4},
|
||||
{1, 16, 4, 5},
|
||||
{1, 3, 100, 200},
|
||||
{60, 4},
|
||||
IN_ET,
|
||||
std::vector<T>{-24.5, -12.5, 24.5, 12.5, -16.5, -16.5, 16.5, 16.5, -12.5, -24.5, 12.5, 24.5},
|
||||
std::vector<T>{-22.5, -10.5, 26.5, 14.5, -14.5, -14.5, 18.5, 18.5, -10.5, -22.5, 14.5, 26.5, -18.5, -10.5, 30.5, 14.5,
|
||||
-10.5, -14.5, 22.5, 18.5, -6.5, -22.5, 18.5, 26.5, -14.5, -10.5, 34.5, 14.5, -6.5, -14.5, 26.5, 18.5,
|
||||
-2.5, -22.5, 22.5, 26.5, -10.5, -10.5, 38.5, 14.5, -2.5, -14.5, 30.5, 18.5, 1.5, -22.5, 26.5, 26.5,
|
||||
-6.5, -10.5, 42.5, 14.5, 1.5, -14.5, 34.5, 18.5, 5.5, -22.5, 30.5, 26.5, -22.5, -6.5, 26.5, 18.5,
|
||||
-14.5, -10.5, 18.5, 22.5, -10.5, -18.5, 14.5, 30.5, -18.5, -6.5, 30.5, 18.5, -10.5, -10.5, 22.5, 22.5,
|
||||
-6.5, -18.5, 18.5, 30.5, -14.5, -6.5, 34.5, 18.5, -6.5, -10.5, 26.5, 22.5, -2.5, -18.5, 22.5, 30.5,
|
||||
-10.5, -6.5, 38.5, 18.5, -2.5, -10.5, 30.5, 22.5, 1.5, -18.5, 26.5, 30.5, -6.5, -6.5, 42.5, 18.5,
|
||||
1.5, -10.5, 34.5, 22.5, 5.5, -18.5, 30.5, 30.5, -22.5, -2.5, 26.5, 22.5, -14.5, -6.5, 18.5, 26.5,
|
||||
-10.5, -14.5, 14.5, 34.5, -18.5, -2.5, 30.5, 22.5, -10.5, -6.5, 22.5, 26.5, -6.5, -14.5, 18.5, 34.5,
|
||||
-14.5, -2.5, 34.5, 22.5, -6.5, -6.5, 26.5, 26.5, -2.5, -14.5, 22.5, 34.5, -10.5, -2.5, 38.5, 22.5,
|
||||
-2.5, -6.5, 30.5, 26.5, 1.5, -14.5, 26.5, 34.5, -6.5, -2.5, 42.5, 22.5, 1.5, -6.5, 34.5, 26.5,
|
||||
5.5, -14.5, 30.5, 34.5, -22.5, 1.5, 26.5, 26.5, -14.5, -2.5, 18.5, 30.5, -10.5, -10.5, 14.5, 38.5,
|
||||
-18.5, 1.5, 30.5, 26.5, -10.5, -2.5, 22.5, 30.5, -6.5, -10.5, 18.5, 38.5, -14.5, 1.5, 34.5, 26.5,
|
||||
-6.5, -2.5, 26.5, 30.5, -2.5, -10.5, 22.5, 38.5, -10.5, 1.5, 38.5, 26.5, -2.5, -2.5, 30.5, 30.5,
|
||||
1.5, -10.5, 26.5, 38.5, -6.5, 1.5, 42.5, 26.5, 1.5, -2.5, 34.5, 30.5, 5.5, -10.5, 30.5, 38.5}),
|
||||
ExperimentalPGGParams(Attrs{false, 0, 0, 8.0f, 8.0f},
|
||||
{3, 4},
|
||||
{1, 16, 3, 7},
|
||||
{1, 3, 100, 200},
|
||||
{3, 7, 3, 4},
|
||||
IN_ET,
|
||||
std::vector<T>{-44.5, -24.5, 44.5, 24.5, -32.5, -32.5, 32.5, 32.5, -24.5, -44.5, 24.5, 44.5},
|
||||
std::vector<T>{-40.5, -20.5, 48.5, 28.5, -28.5, -28.5, 36.5, 36.5, -20.5, -40.5, 28.5, 48.5, -32.5, -20.5, 56.5, 28.5,
|
||||
-20.5, -28.5, 44.5, 36.5, -12.5, -40.5, 36.5, 48.5, -24.5, -20.5, 64.5, 28.5, -12.5, -28.5, 52.5, 36.5,
|
||||
-4.5, -40.5, 44.5, 48.5, -16.5, -20.5, 72.5, 28.5, -4.5, -28.5, 60.5, 36.5, 3.5, -40.5, 52.5, 48.5,
|
||||
-8.5, -20.5, 80.5, 28.5, 3.5, -28.5, 68.5, 36.5, 11.5, -40.5, 60.5, 48.5, -0.5, -20.5, 88.5, 28.5,
|
||||
11.5, -28.5, 76.5, 36.5, 19.5, -40.5, 68.5, 48.5, 7.5, -20.5, 96.5, 28.5, 19.5, -28.5, 84.5, 36.5,
|
||||
27.5, -40.5, 76.5, 48.5, -40.5, -12.5, 48.5, 36.5, -28.5, -20.5, 36.5, 44.5, -20.5, -32.5, 28.5, 56.5,
|
||||
-32.5, -12.5, 56.5, 36.5, -20.5, -20.5, 44.5, 44.5, -12.5, -32.5, 36.5, 56.5, -24.5, -12.5, 64.5, 36.5,
|
||||
-12.5, -20.5, 52.5, 44.5, -4.5, -32.5, 44.5, 56.5, -16.5, -12.5, 72.5, 36.5, -4.5, -20.5, 60.5, 44.5,
|
||||
3.5, -32.5, 52.5, 56.5, -8.5, -12.5, 80.5, 36.5, 3.5, -20.5, 68.5, 44.5, 11.5, -32.5, 60.5, 56.5,
|
||||
-0.5, -12.5, 88.5, 36.5, 11.5, -20.5, 76.5, 44.5, 19.5, -32.5, 68.5, 56.5, 7.5, -12.5, 96.5, 36.5,
|
||||
19.5, -20.5, 84.5, 44.5, 27.5, -32.5, 76.5, 56.5, -40.5, -4.5, 48.5, 44.5, -28.5, -12.5, 36.5, 52.5,
|
||||
-20.5, -24.5, 28.5, 64.5, -32.5, -4.5, 56.5, 44.5, -20.5, -12.5, 44.5, 52.5, -12.5, -24.5, 36.5, 64.5,
|
||||
-24.5, -4.5, 64.5, 44.5, -12.5, -12.5, 52.5, 52.5, -4.5, -24.5, 44.5, 64.5, -16.5, -4.5, 72.5, 44.5,
|
||||
-4.5, -12.5, 60.5, 52.5, 3.5, -24.5, 52.5, 64.5, -8.5, -4.5, 80.5, 44.5, 3.5, -12.5, 68.5, 52.5,
|
||||
11.5, -24.5, 60.5, 64.5, -0.5, -4.5, 88.5, 44.5, 11.5, -12.5, 76.5, 52.5, 19.5, -24.5, 68.5, 64.5,
|
||||
7.5, -4.5, 96.5, 44.5, 19.5, -12.5, 84.5, 52.5, 27.5, -24.5, 76.5, 64.5}),
|
||||
ExperimentalPGGParams(Attrs{true, 3, 6, 64.0f, 64.0f},
|
||||
{3, 4},
|
||||
{1, 16, 100, 100},
|
||||
{1, 3, 100, 200},
|
||||
{30000, 4},
|
||||
IN_ET,
|
||||
std::vector<T>{-364.5, -184.5, 364.5, 184.5, -256.5, -256.5, 256.5, 256.5, -180.5, -360.5, 180.5, 360.5},
|
||||
std::vector<T>{-332.5, -152.5, 396.5, 216.5, -224.5, -224.5, 288.5, 288.5, -148.5, -328.5, 212.5, 392.5, -268.5, -152.5,
|
||||
460.5, 216.5, -160.5, -224.5, 352.5, 288.5, -84.5, -328.5, 276.5, 392.5, -204.5, -152.5, 524.5, 216.5,
|
||||
-96.5, -224.5, 416.5, 288.5, -20.5, -328.5, 340.5, 392.5, -140.5, -152.5, 588.5, 216.5, -32.5, -224.5,
|
||||
480.5, 288.5, 43.5, -328.5, 404.5, 392.5, -76.5, -152.5, 652.5, 216.5, 31.5, -224.5, 544.5, 288.5,
|
||||
107.5, -328.5, 468.5, 392.5, -12.5, -152.5, 716.5, 216.5, 95.5, -224.5, 608.5, 288.5, 171.5, -328.5,
|
||||
532.5, 392.5, -332.5, -88.5, 396.5, 280.5, -224.5, -160.5, 288.5, 352.5, -148.5, -264.5, 212.5, 456.5,
|
||||
-268.5, -88.5, 460.5, 280.5, -160.5, -160.5, 352.5, 352.5, -84.5, -264.5, 276.5, 456.5, -204.5, -88.5,
|
||||
524.5, 280.5, -96.5, -160.5, 416.5, 352.5, -20.5, -264.5, 340.5, 456.5, -140.5, -88.5, 588.5, 280.5,
|
||||
-32.5, -160.5, 480.5, 352.5, 43.5, -264.5, 404.5, 456.5, -76.5, -88.5, 652.5, 280.5, 31.5, -160.5,
|
||||
544.5, 352.5, 107.5, -264.5, 468.5, 456.5, -12.5, -88.5, 716.5, 280.5, 95.5, -160.5, 608.5, 352.5,
|
||||
171.5, -264.5, 532.5, 456.5, -332.5, -24.5, 396.5, 344.5, -224.5, -96.5, 288.5, 416.5, -148.5, -200.5,
|
||||
212.5, 520.5, -268.5, -24.5, 460.5, 344.5, -160.5, -96.5, 352.5, 416.5, -84.5, -200.5, 276.5, 520.5,
|
||||
-204.5, -24.5, 524.5, 344.5, -96.5, -96.5, 416.5, 416.5, -20.5, -200.5, 340.5, 520.5, -140.5, -24.5,
|
||||
588.5, 344.5, -32.5, -96.5, 480.5, 416.5, 43.5, -200.5, 404.5, 520.5, -76.5, -24.5, 652.5, 344.5,
|
||||
31.5, -96.5, 544.5, 416.5, 107.5, -200.5, 468.5, 520.5, -12.5, -24.5, 716.5, 344.5, 95.5, -96.5,
|
||||
608.5, 416.5, 171.5, -200.5, 532.5, 520.5}),
|
||||
ExperimentalPGGParams(Attrs{false, 5, 3, 32.0f, 32.0f},
|
||||
{3, 4},
|
||||
{1, 16, 100, 100},
|
||||
{1, 3, 100, 200},
|
||||
{100, 100, 3, 4},
|
||||
IN_ET,
|
||||
std::vector<T>{-180.5, -88.5, 180.5, 88.5, -128.5, -128.5, 128.5, 128.5, -92.5, -184.5, 92.5, 184.5},
|
||||
std::vector<T>{-164.5, -72.5, 196.5, 104.5, -112.5, -112.5, 144.5, 144.5, -76.5, -168.5, 108.5, 200.5, -132.5, -72.5,
|
||||
228.5, 104.5, -80.5, -112.5, 176.5, 144.5, -44.5, -168.5, 140.5, 200.5, -100.5, -72.5, 260.5, 104.5,
|
||||
-48.5, -112.5, 208.5, 144.5, -12.5, -168.5, 172.5, 200.5, -164.5, -40.5, 196.5, 136.5, -112.5, -80.5,
|
||||
144.5, 176.5, -76.5, -136.5, 108.5, 232.5, -132.5, -40.5, 228.5, 136.5, -80.5, -80.5, 176.5, 176.5,
|
||||
-44.5, -136.5, 140.5, 232.5, -100.5, -40.5, 260.5, 136.5, -48.5, -80.5, 208.5, 176.5, -12.5, -136.5,
|
||||
172.5, 232.5, -164.5, -8.5, 196.5, 168.5, -112.5, -48.5, 144.5, 208.5, -76.5, -104.5, 108.5, 264.5,
|
||||
-132.5, -8.5, 228.5, 168.5, -80.5, -48.5, 176.5, 208.5, -44.5, -104.5, 140.5, 264.5, -100.5, -8.5,
|
||||
260.5, 168.5, -48.5, -48.5, 208.5, 208.5, -12.5, -104.5, 172.5, 264.5, -164.5, 23.5, 196.5, 200.5,
|
||||
-112.5, -16.5, 144.5, 240.5, -76.5, -72.5, 108.5, 296.5, -132.5, 23.5, 228.5, 200.5, -80.5, -16.5,
|
||||
176.5, 240.5, -44.5, -72.5, 140.5, 296.5, -100.5, 23.5, 260.5, 200.5, -48.5, -16.5, 208.5, 240.5,
|
||||
-12.5, -72.5, 172.5, 296.5, -164.5, 55.5, 196.5, 232.5, -112.5, 15.5, 144.5, 272.5, -76.5, -40.5,
|
||||
108.5, 328.5, -132.5, 55.5, 228.5, 232.5, -80.5, 15.5, 176.5, 272.5, -44.5, -40.5, 140.5, 328.5,
|
||||
-100.5, 55.5, 260.5, 232.5, -48.5, 15.5, 208.5, 272.5, -12.5, -40.5, 172.5, 328.5}),
|
||||
};
|
||||
return experimentalPGGParams;
|
||||
}
|
||||
|
||||
std::vector<ExperimentalPGGParams> generateExperimentalPGGCombinedParams() {
|
||||
const std::vector<std::vector<ExperimentalPGGParams>> experimentalPGGTypeParams {
|
||||
generateExperimentalPGGFloatParams<element::Type_t::f64>(),
|
||||
generateExperimentalPGGFloatParams<element::Type_t::f32>(),
|
||||
generateExperimentalPGGFloatParams<element::Type_t::f16>(),
|
||||
generateExperimentalPGGFloatParams<element::Type_t::bf16>(),
|
||||
};
|
||||
std::vector<ExperimentalPGGParams> combinedParams;
|
||||
|
||||
for (const auto& params : experimentalPGGTypeParams) {
|
||||
combinedParams.insert(combinedParams.end(), params.begin(), params.end());
|
||||
}
|
||||
return combinedParams;
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_ExperimentalDetectronPriorGridGenerator_With_Hardcoded_Refs, ReferenceExperimentalPGGLayerTest,
|
||||
testing::ValuesIn(generateExperimentalPGGCombinedParams()), ReferenceExperimentalPGGLayerTest::getTestCaseName);
|
||||
} // namespace
|
@ -106,6 +106,8 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*ReferenceMulticlassNmsTest.*esiType=i64.*evoType=i64.*)",
|
||||
// CVS-64096
|
||||
R"(.*ReferenceNonMaxSuppressionTest.*esiType=i32.*evoType=i32.*)",
|
||||
// CVS-64102
|
||||
R"(.*ReferenceExperimentalPGGLayerTest.*iType=bf16.*stride_x=(32|64).*)",
|
||||
};
|
||||
|
||||
#ifdef _WIN32
|
||||
|
@ -318,6 +318,7 @@ set(SRC
|
||||
visitors/op/exp.cpp
|
||||
visitors/op/experimental_detectron_detection_output.cpp
|
||||
visitors/op/experimental_detectron_generate_proposals.cpp
|
||||
visitors/op/experimental_detectron_prior_grid_generator.cpp
|
||||
visitors/op/experimental_detectron_topkrois.cpp
|
||||
visitors/op/extractimagepatches.cpp
|
||||
visitors/op/fake_quantize.cpp
|
||||
@ -489,7 +490,6 @@ set(MULTI_TEST_SRC
|
||||
backend/builder_reduce_ops_opset1.in.cpp
|
||||
backend/dyn_reshape.in.cpp
|
||||
backend/dynamic.in.cpp
|
||||
backend/experimental_detectron_prior_grid.in.cpp
|
||||
backend/function_name.in.cpp
|
||||
backend/interpolate.in.cpp
|
||||
backend/multiple_backends.in.cpp
|
||||
|
@ -1,172 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2021 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
// clang-format off
|
||||
#ifdef ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||
#define DEFAULT_FLOAT_TOLERANCE_BITS ${BACKEND_NAME}_FLOAT_TOLERANCE_BITS
|
||||
#endif
|
||||
|
||||
#ifdef ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||
#define DEFAULT_DOUBLE_TOLERANCE_BITS ${BACKEND_NAME}_DOUBLE_TOLERANCE_BITS
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "runtime/backend.hpp"
|
||||
#include "ngraph/runtime/tensor.hpp"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/all_close.hpp"
|
||||
#include "util/all_close_f.hpp"
|
||||
#include "util/ndarray.hpp"
|
||||
#include "engines_util/random.hpp"
|
||||
#include "util/test_control.hpp"
|
||||
#include "engines_util/execute_tools.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
using Attrs = op::v6::ExperimentalDetectronPriorGridGenerator::Attributes;
|
||||
using GridGenerator = op::v6::ExperimentalDetectronPriorGridGenerator;
|
||||
|
||||
static string s_manifest = "${MANIFEST}";
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, experimental_detectron_prior_grid_eval) {
|
||||
std::vector<std::vector<float>> priors_value = {
|
||||
{-24.5, -12.5, 24.5, 12.5, -16.5, -16.5, 16.5, 16.5, -12.5, -24.5, 12.5, 24.5},
|
||||
{-44.5, -24.5, 44.5, 24.5, -32.5, -32.5, 32.5, 32.5, -24.5, -44.5, 24.5, 44.5},
|
||||
{-364.5, -184.5, 364.5, 184.5, -256.5, -256.5, 256.5, 256.5, -180.5, -360.5, 180.5, 360.5},
|
||||
{-180.5, -88.5, 180.5, 88.5, -128.5, -128.5, 128.5, 128.5, -92.5, -184.5, 92.5, 184.5}};
|
||||
|
||||
struct ShapesAndAttrs {
|
||||
Attrs attrs;
|
||||
Shape priors_shape;
|
||||
Shape feature_map_shape;
|
||||
Shape im_data_shape;
|
||||
Shape ref_out_shape;
|
||||
};
|
||||
|
||||
std::vector<ShapesAndAttrs> shapes_and_attrs = {
|
||||
{{true, 0, 0, 4.0f, 4.0f}, {3, 4}, {1, 16, 4, 5}, {1, 3, 100, 200}, {60, 4}},
|
||||
{{false, 0, 0, 8.0f, 8.0f}, {3, 4}, {1, 16, 3, 7}, {1, 3, 100, 200}, {3, 7, 3, 4}},
|
||||
{{true, 3, 6, 64.0f, 64.0f}, {3, 4}, {1, 16, 100, 100}, {1, 3, 100, 200}, {30000, 4}},
|
||||
{{false, 5, 3, 32.0f, 32.0f}, {3, 4}, {1, 16, 100, 100}, {1, 3, 100, 200}, {100, 100, 3, 4}}};
|
||||
|
||||
std::vector<std::vector<float>> expected_results = {
|
||||
{-22.5, -10.5, 26.5, 14.5, -14.5, -14.5, 18.5, 18.5, -10.5, -22.5, 14.5, 26.5, -18.5, -10.5, 30.5, 14.5,
|
||||
-10.5, -14.5, 22.5, 18.5, -6.5, -22.5, 18.5, 26.5, -14.5, -10.5, 34.5, 14.5, -6.5, -14.5, 26.5, 18.5,
|
||||
-2.5, -22.5, 22.5, 26.5, -10.5, -10.5, 38.5, 14.5, -2.5, -14.5, 30.5, 18.5, 1.5, -22.5, 26.5, 26.5,
|
||||
-6.5, -10.5, 42.5, 14.5, 1.5, -14.5, 34.5, 18.5, 5.5, -22.5, 30.5, 26.5, -22.5, -6.5, 26.5, 18.5,
|
||||
-14.5, -10.5, 18.5, 22.5, -10.5, -18.5, 14.5, 30.5, -18.5, -6.5, 30.5, 18.5, -10.5, -10.5, 22.5, 22.5,
|
||||
-6.5, -18.5, 18.5, 30.5, -14.5, -6.5, 34.5, 18.5, -6.5, -10.5, 26.5, 22.5, -2.5, -18.5, 22.5, 30.5,
|
||||
-10.5, -6.5, 38.5, 18.5, -2.5, -10.5, 30.5, 22.5, 1.5, -18.5, 26.5, 30.5, -6.5, -6.5, 42.5, 18.5,
|
||||
1.5, -10.5, 34.5, 22.5, 5.5, -18.5, 30.5, 30.5, -22.5, -2.5, 26.5, 22.5, -14.5, -6.5, 18.5, 26.5,
|
||||
-10.5, -14.5, 14.5, 34.5, -18.5, -2.5, 30.5, 22.5, -10.5, -6.5, 22.5, 26.5, -6.5, -14.5, 18.5, 34.5,
|
||||
-14.5, -2.5, 34.5, 22.5, -6.5, -6.5, 26.5, 26.5, -2.5, -14.5, 22.5, 34.5, -10.5, -2.5, 38.5, 22.5,
|
||||
-2.5, -6.5, 30.5, 26.5, 1.5, -14.5, 26.5, 34.5, -6.5, -2.5, 42.5, 22.5, 1.5, -6.5, 34.5, 26.5,
|
||||
5.5, -14.5, 30.5, 34.5, -22.5, 1.5, 26.5, 26.5, -14.5, -2.5, 18.5, 30.5, -10.5, -10.5, 14.5, 38.5,
|
||||
-18.5, 1.5, 30.5, 26.5, -10.5, -2.5, 22.5, 30.5, -6.5, -10.5, 18.5, 38.5, -14.5, 1.5, 34.5, 26.5,
|
||||
-6.5, -2.5, 26.5, 30.5, -2.5, -10.5, 22.5, 38.5, -10.5, 1.5, 38.5, 26.5, -2.5, -2.5, 30.5, 30.5,
|
||||
1.5, -10.5, 26.5, 38.5, -6.5, 1.5, 42.5, 26.5, 1.5, -2.5, 34.5, 30.5, 5.5, -10.5, 30.5, 38.5},
|
||||
{-40.5, -20.5, 48.5, 28.5, -28.5, -28.5, 36.5, 36.5, -20.5, -40.5, 28.5, 48.5, -32.5, -20.5, 56.5, 28.5,
|
||||
-20.5, -28.5, 44.5, 36.5, -12.5, -40.5, 36.5, 48.5, -24.5, -20.5, 64.5, 28.5, -12.5, -28.5, 52.5, 36.5,
|
||||
-4.5, -40.5, 44.5, 48.5, -16.5, -20.5, 72.5, 28.5, -4.5, -28.5, 60.5, 36.5, 3.5, -40.5, 52.5, 48.5,
|
||||
-8.5, -20.5, 80.5, 28.5, 3.5, -28.5, 68.5, 36.5, 11.5, -40.5, 60.5, 48.5, -0.5, -20.5, 88.5, 28.5,
|
||||
11.5, -28.5, 76.5, 36.5, 19.5, -40.5, 68.5, 48.5, 7.5, -20.5, 96.5, 28.5, 19.5, -28.5, 84.5, 36.5,
|
||||
27.5, -40.5, 76.5, 48.5, -40.5, -12.5, 48.5, 36.5, -28.5, -20.5, 36.5, 44.5, -20.5, -32.5, 28.5, 56.5,
|
||||
-32.5, -12.5, 56.5, 36.5, -20.5, -20.5, 44.5, 44.5, -12.5, -32.5, 36.5, 56.5, -24.5, -12.5, 64.5, 36.5,
|
||||
-12.5, -20.5, 52.5, 44.5, -4.5, -32.5, 44.5, 56.5, -16.5, -12.5, 72.5, 36.5, -4.5, -20.5, 60.5, 44.5,
|
||||
3.5, -32.5, 52.5, 56.5, -8.5, -12.5, 80.5, 36.5, 3.5, -20.5, 68.5, 44.5, 11.5, -32.5, 60.5, 56.5,
|
||||
-0.5, -12.5, 88.5, 36.5, 11.5, -20.5, 76.5, 44.5, 19.5, -32.5, 68.5, 56.5, 7.5, -12.5, 96.5, 36.5,
|
||||
19.5, -20.5, 84.5, 44.5, 27.5, -32.5, 76.5, 56.5, -40.5, -4.5, 48.5, 44.5, -28.5, -12.5, 36.5, 52.5,
|
||||
-20.5, -24.5, 28.5, 64.5, -32.5, -4.5, 56.5, 44.5, -20.5, -12.5, 44.5, 52.5, -12.5, -24.5, 36.5, 64.5,
|
||||
-24.5, -4.5, 64.5, 44.5, -12.5, -12.5, 52.5, 52.5, -4.5, -24.5, 44.5, 64.5, -16.5, -4.5, 72.5, 44.5,
|
||||
-4.5, -12.5, 60.5, 52.5, 3.5, -24.5, 52.5, 64.5, -8.5, -4.5, 80.5, 44.5, 3.5, -12.5, 68.5, 52.5,
|
||||
11.5, -24.5, 60.5, 64.5, -0.5, -4.5, 88.5, 44.5, 11.5, -12.5, 76.5, 52.5, 19.5, -24.5, 68.5, 64.5,
|
||||
7.5, -4.5, 96.5, 44.5, 19.5, -12.5, 84.5, 52.5, 27.5, -24.5, 76.5, 64.5},
|
||||
{-332.5, -152.5, 396.5, 216.5, -224.5, -224.5, 288.5, 288.5, -148.5, -328.5, 212.5, 392.5, -268.5, -152.5,
|
||||
460.5, 216.5, -160.5, -224.5, 352.5, 288.5, -84.5, -328.5, 276.5, 392.5, -204.5, -152.5, 524.5, 216.5,
|
||||
-96.5, -224.5, 416.5, 288.5, -20.5, -328.5, 340.5, 392.5, -140.5, -152.5, 588.5, 216.5, -32.5, -224.5,
|
||||
480.5, 288.5, 43.5, -328.5, 404.5, 392.5, -76.5, -152.5, 652.5, 216.5, 31.5, -224.5, 544.5, 288.5,
|
||||
107.5, -328.5, 468.5, 392.5, -12.5, -152.5, 716.5, 216.5, 95.5, -224.5, 608.5, 288.5, 171.5, -328.5,
|
||||
532.5, 392.5, -332.5, -88.5, 396.5, 280.5, -224.5, -160.5, 288.5, 352.5, -148.5, -264.5, 212.5, 456.5,
|
||||
-268.5, -88.5, 460.5, 280.5, -160.5, -160.5, 352.5, 352.5, -84.5, -264.5, 276.5, 456.5, -204.5, -88.5,
|
||||
524.5, 280.5, -96.5, -160.5, 416.5, 352.5, -20.5, -264.5, 340.5, 456.5, -140.5, -88.5, 588.5, 280.5,
|
||||
-32.5, -160.5, 480.5, 352.5, 43.5, -264.5, 404.5, 456.5, -76.5, -88.5, 652.5, 280.5, 31.5, -160.5,
|
||||
544.5, 352.5, 107.5, -264.5, 468.5, 456.5, -12.5, -88.5, 716.5, 280.5, 95.5, -160.5, 608.5, 352.5,
|
||||
171.5, -264.5, 532.5, 456.5, -332.5, -24.5, 396.5, 344.5, -224.5, -96.5, 288.5, 416.5, -148.5, -200.5,
|
||||
212.5, 520.5, -268.5, -24.5, 460.5, 344.5, -160.5, -96.5, 352.5, 416.5, -84.5, -200.5, 276.5, 520.5,
|
||||
-204.5, -24.5, 524.5, 344.5, -96.5, -96.5, 416.5, 416.5, -20.5, -200.5, 340.5, 520.5, -140.5, -24.5,
|
||||
588.5, 344.5, -32.5, -96.5, 480.5, 416.5, 43.5, -200.5, 404.5, 520.5, -76.5, -24.5, 652.5, 344.5,
|
||||
31.5, -96.5, 544.5, 416.5, 107.5, -200.5, 468.5, 520.5, -12.5, -24.5, 716.5, 344.5, 95.5, -96.5,
|
||||
608.5, 416.5, 171.5, -200.5, 532.5, 520.5},
|
||||
{-164.5, -72.5, 196.5, 104.5, -112.5, -112.5, 144.5, 144.5, -76.5, -168.5, 108.5, 200.5, -132.5, -72.5,
|
||||
228.5, 104.5, -80.5, -112.5, 176.5, 144.5, -44.5, -168.5, 140.5, 200.5, -100.5, -72.5, 260.5, 104.5,
|
||||
-48.5, -112.5, 208.5, 144.5, -12.5, -168.5, 172.5, 200.5, -164.5, -40.5, 196.5, 136.5, -112.5, -80.5,
|
||||
144.5, 176.5, -76.5, -136.5, 108.5, 232.5, -132.5, -40.5, 228.5, 136.5, -80.5, -80.5, 176.5, 176.5,
|
||||
-44.5, -136.5, 140.5, 232.5, -100.5, -40.5, 260.5, 136.5, -48.5, -80.5, 208.5, 176.5, -12.5, -136.5,
|
||||
172.5, 232.5, -164.5, -8.5, 196.5, 168.5, -112.5, -48.5, 144.5, 208.5, -76.5, -104.5, 108.5, 264.5,
|
||||
-132.5, -8.5, 228.5, 168.5, -80.5, -48.5, 176.5, 208.5, -44.5, -104.5, 140.5, 264.5, -100.5, -8.5,
|
||||
260.5, 168.5, -48.5, -48.5, 208.5, 208.5, -12.5, -104.5, 172.5, 264.5, -164.5, 23.5, 196.5, 200.5,
|
||||
-112.5, -16.5, 144.5, 240.5, -76.5, -72.5, 108.5, 296.5, -132.5, 23.5, 228.5, 200.5, -80.5, -16.5,
|
||||
176.5, 240.5, -44.5, -72.5, 140.5, 296.5, -100.5, 23.5, 260.5, 200.5, -48.5, -16.5, 208.5, 240.5,
|
||||
-12.5, -72.5, 172.5, 296.5, -164.5, 55.5, 196.5, 232.5, -112.5, 15.5, 144.5, 272.5, -76.5, -40.5,
|
||||
108.5, 328.5, -132.5, 55.5, 228.5, 232.5, -80.5, 15.5, 176.5, 272.5, -44.5, -40.5, 140.5, 328.5,
|
||||
-100.5, 55.5, 260.5, 232.5, -48.5, 15.5, 208.5, 272.5, -12.5, -40.5, 172.5, 328.5}};
|
||||
|
||||
std::size_t i = 0;
|
||||
for (const auto& s : shapes_and_attrs) {
|
||||
auto priors = std::make_shared<op::Parameter>(element::f32, s.priors_shape);
|
||||
auto feature_map = std::make_shared<op::Parameter>(element::f32, s.feature_map_shape);
|
||||
auto im_data = std::make_shared<op::Parameter>(element::f32, s.im_data_shape);
|
||||
|
||||
auto grid_gen = std::make_shared<GridGenerator>(priors, feature_map, im_data, s.attrs);
|
||||
|
||||
auto f = make_shared<Function>(grid_gen, ParameterVector{priors, feature_map, im_data});
|
||||
|
||||
auto backend = runtime::Backend::create("${BACKEND_NAME}");
|
||||
|
||||
auto priors_data = priors_value[i];
|
||||
|
||||
auto& ref_results = expected_results[i];
|
||||
|
||||
std::vector<float> feature_map_data(shape_size(s.feature_map_shape));
|
||||
std::iota(feature_map_data.begin(), feature_map_data.end(), 0);
|
||||
std::vector<float> image_data(shape_size(s.im_data_shape));
|
||||
std::iota(image_data.begin(), image_data.end(), 0);
|
||||
|
||||
auto output_priors = backend->create_tensor(element::f32, s.ref_out_shape);
|
||||
|
||||
auto backend_priors = backend->create_tensor(element::f32, s.priors_shape);
|
||||
auto backend_feature_map = backend->create_tensor(element::f32, s.feature_map_shape);
|
||||
auto backend_im_data = backend->create_tensor(element::f32, s.im_data_shape);
|
||||
copy_data(backend_priors, priors_data);
|
||||
copy_data(backend_feature_map, feature_map_data);
|
||||
copy_data(backend_im_data, image_data);
|
||||
|
||||
auto handle = backend->compile(f);
|
||||
|
||||
handle->call({output_priors}, {backend_priors, backend_feature_map, backend_im_data});
|
||||
|
||||
auto output_priors_value = read_vector<float>(output_priors);
|
||||
|
||||
std::vector<float> actual_results(output_priors_value.begin(),
|
||||
output_priors_value.begin() + ref_results.size());
|
||||
EXPECT_EQ(ref_results, actual_results);
|
||||
++i;
|
||||
}
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/opsets/opset6.hpp"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
using ExperimentalGenerator = opset6::ExperimentalDetectronPriorGridGenerator;
|
||||
using Attrs = opset6::ExperimentalDetectronPriorGridGenerator::Attributes;
|
||||
|
||||
TEST(attributes, detectron_prior_grid_generator) {
|
||||
NodeBuilder::get_ops().register_factory<ExperimentalGenerator>();
|
||||
|
||||
Attrs attrs;
|
||||
attrs.flatten = true;
|
||||
attrs.h = 3;
|
||||
attrs.w = 6;
|
||||
attrs.stride_x = 64;
|
||||
attrs.stride_y = 64;
|
||||
|
||||
auto priors = std::make_shared<op::Parameter>(element::f32, Shape{3, 4});
|
||||
auto feature_map = std::make_shared<op::Parameter>(element::f32, Shape{1, 16, 100, 100});
|
||||
auto im_data = std::make_shared<op::Parameter>(element::f32, Shape{1, 3, 100, 200});
|
||||
|
||||
auto proposals = std::make_shared<ExperimentalGenerator>(priors, feature_map, im_data, attrs);
|
||||
|
||||
NodeBuilder builder(proposals);
|
||||
|
||||
auto g_proposals = ov::as_type_ptr<ExperimentalGenerator>(builder.create());
|
||||
|
||||
const auto expected_attr_count = 5;
|
||||
EXPECT_EQ(builder.get_value_map_size(), expected_attr_count);
|
||||
|
||||
EXPECT_EQ(g_proposals->get_attrs().flatten, proposals->get_attrs().flatten);
|
||||
EXPECT_EQ(g_proposals->get_attrs().h, proposals->get_attrs().h);
|
||||
EXPECT_EQ(g_proposals->get_attrs().w, proposals->get_attrs().w);
|
||||
EXPECT_EQ(g_proposals->get_attrs().stride_x, proposals->get_attrs().stride_x);
|
||||
EXPECT_EQ(g_proposals->get_attrs().stride_y, proposals->get_attrs().stride_y);
|
||||
}
|
Loading…
Reference in New Issue
Block a user