Support NonConst pads_begin and pads_end in Pad op (#8697)

* Support pads_begin & pads_end as dynamic ops in Pad

* Extend Pad template test w/ NonConst PB & PE cases

* Remove xfails for 69443 after issue was fixed
This commit is contained in:
Vitaliy Urusovskij
2021-11-23 09:27:37 +03:00
committed by GitHub
parent 251883001c
commit 3b88682159
6 changed files with 86 additions and 18 deletions

View File

@@ -116,6 +116,64 @@ TEST_P(ReferencePadTestParamsOk, CompareWithRefs) {
EXPECT_NO_THROW(Exec());
}
class ReferencePadTestNonConstPadsBeginPadsEnd : public ReferencePadTest {
public:
void SetUp() override {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
auto params = GetParam();
function = CreateFunction(params);
inputData = {params.inputData.data, params.padsBegin.data, params.padsEnd.data};
refOutData = {params.expectedOutput.data};
}
private:
static std::shared_ptr<Function> CreateFunction(const PadParams& params) {
const auto data = std::make_shared<op::v0::Parameter>(params.inputData.type,
params.inputData.shape);
const auto padsBegin = std::make_shared<op::v0::Parameter>(params.padsBegin.type,
params.padsBegin.shape);
const auto padsEnd = std::make_shared<op::v0::Parameter>(params.padsEnd.type,
params.padsEnd.shape);
const auto f = [&] {
if (params.useConstValue) {
// pad_value should be used only in CONSTANT mode
const auto padVal = op::v0::Constant::create(params.constantValue.type,
params.constantValue.shape,
params.constantValue.data.data());
return std::make_shared<Function>(std::make_shared<op::v1::Pad>(data,
padsBegin,
padsEnd,
padVal,
params.padMode),
ParameterVector{data, padsBegin, padsEnd});
}
return std::make_shared<Function>(std::make_shared<op::v1::Pad>(data,
padsBegin,
padsEnd,
params.padMode),
ParameterVector{data, padsBegin, padsEnd});
}();
return f;
}
};
TEST_P(ReferencePadTestNonConstPadsBeginPadsEnd, CompareWithRefs) {
Exec();
}
class ReferencePadTestNonConstPadsBeginPadsEndTooLarge : public ReferencePadTestNonConstPadsBeginPadsEnd {};
TEST_P(ReferencePadTestNonConstPadsBeginPadsEndTooLarge, CompareWithRefs) {
EXPECT_ANY_THROW(Exec());
}
class ReferencePadTestNonConstPadsBeginPadsEndParamsOk : public ReferencePadTestNonConstPadsBeginPadsEnd {};
TEST_P(ReferencePadTestNonConstPadsBeginPadsEndParamsOk, CompareWithRefs) {
EXPECT_NO_THROW(Exec());
}
template <element::Type_t ET, element::Type_t ET_INT>
std::vector<PadParams> generateParams() {
using T = typename element_type_traits<ET>::value_type;
@@ -1005,6 +1063,9 @@ std::vector<PadParams> generateCombinedParams() {
INSTANTIATE_TEST_SUITE_P(smoke_Pad_With_Hardcoded_Refs, ReferencePadTest,
testing::ValuesIn(generateCombinedParams()), ReferencePadTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Pad_With_Hardcoded_Refs, ReferencePadTestNonConstPadsBeginPadsEnd,
testing::ValuesIn(generateCombinedParams()), ReferencePadTest::getTestCaseName);
template <element::Type_t ET, element::Type_t ET_INT>
std::vector<PadParams> generateParamsTooLarge() {
using T = typename element_type_traits<ET>::value_type;
@@ -1052,6 +1113,9 @@ std::vector<PadParams> generateCombinedParamsTooLarge() {
INSTANTIATE_TEST_SUITE_P(smoke_Pad_With_Hardcoded_Refs, ReferencePadTestParamsTooLarge,
testing::ValuesIn(generateCombinedParamsTooLarge()), ReferencePadTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Pad_With_Hardcoded_Refs, ReferencePadTestNonConstPadsBeginPadsEndTooLarge,
testing::ValuesIn(generateCombinedParamsTooLarge()), ReferencePadTest::getTestCaseName);
template <element::Type_t ET, element::Type_t ET_INT>
std::vector<PadParams> generateParamsOk() {
using T = typename element_type_traits<ET>::value_type;
@@ -1098,4 +1162,7 @@ std::vector<PadParams> generateCombinedParamsOk() {
INSTANTIATE_TEST_SUITE_P(smoke_Pad_With_Hardcoded_Refs, ReferencePadTestParamsOk,
testing::ValuesIn(generateCombinedParamsOk()), ReferencePadTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Pad_With_Hardcoded_Refs, ReferencePadTestNonConstPadsBeginPadsEndParamsOk,
testing::ValuesIn(generateCombinedParamsOk()), ReferencePadTest::getTestCaseName);
} // namespace