Add interpolate from all opsets to cpu shape infer (#17875)

This commit is contained in:
Pawel Raasz 2023-06-07 11:28:45 +02:00 committed by GitHub
parent 13028397b7
commit f023f5d672
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 2 deletions

View File

@ -5,7 +5,9 @@
#include <openvino/core/node.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/opsets/opset11.hpp>
#include <openvino/opsets/opset12.hpp>
#include <openvino/opsets/opset4.hpp>
#include <openvino/opsets/opset5.hpp>
#include <openvino/opsets/opset7.hpp>
@ -634,7 +636,6 @@ template <>
const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
// Default opset
_OV_OP_SHAPE_INFER_MASK_REG(ExperimentalDetectronROIFeatureExtractor, ShapeInferTA, util::bit::mask()),
_OV_OP_SHAPE_INFER_MASK_REG(Interpolate, ShapeInferPaddingTA, util::bit::mask(1, 2, 3)),
_OV_OP_SHAPE_INFER_MASK_REG(Proposal, ShapeInferTA, util::bit::mask()),
_OV_OP_SHAPE_INFER_VA_REG(ReduceL1, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)),
_OV_OP_SHAPE_INFER_VA_REG(ReduceL2, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)),
@ -647,6 +648,10 @@ const IStaticShapeInferFactory::TRegistry IStaticShapeInferFactory::registry{
_OV_OP_SHAPE_INFER_VA_REG(ReduceSum, ShapeInferTA, op::util::ArithmeticReductionKeepDims, util::bit::mask(1)),
_OV_OP_SHAPE_INFER_MASK_REG(Tile, ShapeInferTA, util::bit::mask(1)),
// Operators shape inferences for specific opset version should be specified below
// opset11
_OV_OP_SHAPE_INFER_MASK_REG(opset11::Interpolate, ShapeInferPaddingTA, util::bit::mask(1, 2, 3)),
// opset4
_OV_OP_SHAPE_INFER_MASK_REG(opset4::Interpolate, ShapeInferPaddingTA, util::bit::mask(1, 2)),
// opset1
_OV_OP_SHAPE_INFER_MASK_REG(opset1::Interpolate, ShapeInferTA, util::bit::mask(1)),
_OV_OP_SHAPE_INFER_MASK_REG(opset1::Proposal, ShapeInferTA, util::bit::mask()),

View File

@ -211,7 +211,6 @@ protected:
};
TEST_F(InterpolateV11StaticShapeInferenceTest, default_ctor_no_attributes) {
GTEST_SKIP() << "Enable test when v11 opset will be added to shape inference factory.";
attrs.shape_calculation_mode = ShapeCalcMode::SCALES;
op = make_op();