enable eltwise skip test

Signed-off-by: HU Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
HU Yuan2 2023-08-07 14:31:20 +08:00
parent d2947e2385
commit c9f0a6f225
5 changed files with 42 additions and 10 deletions

View File

@ -1804,7 +1804,7 @@ bool Eltwise::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op
}
Eltwise::Eltwise(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) :
Node(op, context, EltwiseShapeInferFactory()), broadcastingPolicy(Undefined) {
Node(op, context, EltwiseShapeInferFactory(op)), broadcastingPolicy(Undefined) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
IE_THROW(NotImplemented) << errorMessage;

View File

@ -52,6 +52,26 @@ Result EltwiseShapeInfer::infer(
return { { std::move(output_shape) }, ShapeInferStatus::success };
}
Result NoBroadCastEltwiseShapeInfer::infer(
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) {
if (input_shapes.size() == 1) {
return { { std::move(input_shapes[0].get()) }, ShapeInferStatus::success };
} else {
auto input_shape = input_shapes[0].get();
auto output_shape = input_shapes[1].get();
if (input_shape.size() != output_shape.size()) {
OPENVINO_THROW("Eltwise shape infer input and output shapes rank mismatch");
}
for (size_t j = 0; j < input_shapes.size(); ++j) {
if (input_shape[j] != output_shape[j]) {
OPENVINO_THROW("Eltwise shape infer input shapes dim index: ", j, " mismatch");
}
}
return { { std::move(output_shape) }, ShapeInferStatus::success };
}
}
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@ -26,11 +26,30 @@ public:
}
};
class NoBroadCastEltwiseShapeInfer : public ShapeInferEmptyPads {
public:
Result infer(
const std::vector<std::reference_wrapper<const VectorDims>>& input_shapes,
const std::unordered_map<size_t, MemoryPtr>& data_dependency) override;
port_mask_t get_port_mask() const override {
return EMPTY_PORT_MASK;
}
};
class EltwiseShapeInferFactory : public ShapeInferFactory {
public:
EltwiseShapeInferFactory(std::shared_ptr<ov::Node> op) : m_op(op) {}
ShapeInferPtr makeShapeInfer() const override {
return std::make_shared<EltwiseShapeInfer>();
const auto& autob = m_op->get_autob();
if (autob.m_type == ov::op::AutoBroadcastType::NONE) {
return std::make_shared<NoBroadCastEltwiseShapeInfer>();
} else {
return std::make_shared<EltwiseShapeInfer>();
}
}
private:
std::shared_ptr<ov::Node> m_op;
};
} // namespace node
} // namespace intel_cpu

View File

@ -83,7 +83,6 @@ TYPED_TEST_P(CpuShapeInferenceTest_BEA, shape_inference_aubtob_none) {
}
TYPED_TEST_P(CpuShapeInferenceTest_BEA, shape_inference_aubtob_none_incompatible_shapes) {
GTEST_SKIP() << "Skipping test, please check CVS-108946";
auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
@ -92,7 +91,6 @@ TYPED_TEST_P(CpuShapeInferenceTest_BEA, shape_inference_aubtob_none_incompatible
std::vector<StaticShape> static_input_shapes = {StaticShape{3, 4, 6, 5}, StaticShape{3, 1, 6, 1}},
static_output_shapes = {StaticShape{}};
//TODO cvs-108946, below test can't pass.
OV_EXPECT_THROW(unit_test::cpu_test_shape_infer(node.get(), static_input_shapes, static_output_shapes),
ov::Exception,
testing::HasSubstr("Eltwise shape infer input shapes dim index:"));

View File

@ -31,11 +31,6 @@ namespace {
#define INTEL_CPU_CUSTOM_SHAPE_INFER(__prim, __type) \
registerNodeIfRequired(intel_cpu, __prim, __type, __prim)
class EltwiseShapeInferTestFactory : public node::EltwiseShapeInferFactory {
public:
EltwiseShapeInferTestFactory(std::shared_ptr<ov::Node> op) : EltwiseShapeInferFactory() {}
};
class ShapeOfShapeInferTestFactory : public node::ShapeOfShapeInferFactory {
public:
ShapeOfShapeInferTestFactory(std::shared_ptr<ov::Node> op) : ShapeOfShapeInferFactory() {}
@ -45,7 +40,7 @@ class CustomShapeInferFF : public openvino::cc::Factory<Type, ShapeInferFactory*
public:
CustomShapeInferFF():Factory("CpuCustomShapeInferTestFactory") {
INTEL_CPU_CUSTOM_SHAPE_INFER(node::AdaptivePoolingShapeInferFactory, Type::AdaptivePooling);
INTEL_CPU_CUSTOM_SHAPE_INFER(EltwiseShapeInferTestFactory, Type::Eltwise);
INTEL_CPU_CUSTOM_SHAPE_INFER(node::EltwiseShapeInferFactory, Type::Eltwise);
INTEL_CPU_CUSTOM_SHAPE_INFER(node::FCShapeInferFactory, Type::FullyConnected);
INTEL_CPU_CUSTOM_SHAPE_INFER(node::TransposeShapeInferFactory, Type::Transpose);
INTEL_CPU_CUSTOM_SHAPE_INFER(ShapeOfShapeInferTestFactory, Type::ShapeOf);