Enable einsum shape inferenxe test (#10603)

This commit is contained in:
Evgenya Stepyreva 2022-02-23 18:29:12 +03:00 committed by GitHub
parent 9dec8db964
commit e544f5e66f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -120,7 +120,7 @@ TEST(type_prop, einsum_dynamicshape_dotproduct) {
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape); auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32); ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_dynamicshape_diagextraction) { TEST(type_prop, einsum_dynamicshape_diagextraction) {
@ -130,13 +130,10 @@ TEST(type_prop, einsum_dynamicshape_diagextraction) {
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape); auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::i32); ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, DISABLED_einsum_dynamicshape_ellipsis1) { TEST(type_prop, einsum_dynamicshape_ellipsis1) {
// TODO: fix bug #53518 - PartialShape::broadcast_merge_into or Dimension::broadcast_merge
// to support broadcasting between Dimension(3, 5) and Dimension(1, 3)
// for which the result must be Dimension(3, 5)
std::string equation = "a...b,b...->a..."; std::string equation = "a...b,b...->a...";
const auto input1_shape = PartialShape{11, 1, Dimension(3, 5), 3}; const auto input1_shape = PartialShape{11, 1, Dimension(3, 5), 3};
const auto input2_shape = PartialShape{3, 11, 7, Dimension(1, 3)}; const auto input2_shape = PartialShape{3, 11, 7, Dimension(1, 3)};
@ -145,7 +142,7 @@ TEST(type_prop, DISABLED_einsum_dynamicshape_ellipsis1) {
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape); auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32); ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_implicitmode_mixedcaseletters) { TEST(type_prop, einsum_implicitmode_mixedcaseletters) {
@ -156,7 +153,7 @@ TEST(type_prop, einsum_implicitmode_mixedcaseletters) {
const auto out_shape = PartialShape{1, Dimension(4, 5), Dimension(2, 3)}; const auto out_shape = PartialShape{1, Dimension(4, 5), Dimension(2, 3)};
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::f32); ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_implicitmode_mixedcaseletters2) { TEST(type_prop, einsum_implicitmode_mixedcaseletters2) {
@ -169,7 +166,7 @@ TEST(type_prop, einsum_implicitmode_mixedcaseletters2) {
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape); auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32); ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_implicitmode_repeatedlabels) { TEST(type_prop, einsum_implicitmode_repeatedlabels) {
@ -182,7 +179,7 @@ TEST(type_prop, einsum_implicitmode_repeatedlabels) {
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape); auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32); ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_implicitmode_innerprod) { TEST(type_prop, einsum_implicitmode_innerprod) {
@ -195,7 +192,7 @@ TEST(type_prop, einsum_implicitmode_innerprod) {
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape); auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32); ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_dynamicrank_multimatmul) { TEST(type_prop, einsum_dynamicrank_multimatmul) {
@ -223,7 +220,7 @@ TEST(type_prop, einsum_dynamicrank_multimatmul2) {
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape); auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation); auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
ASSERT_EQ(O->get_element_type(), element::i32); ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_EQ(O->get_output_partial_shape(0), out_shape);
} }
TEST(type_prop, einsum_incorrectequation_subscriptnumber) { TEST(type_prop, einsum_incorrectequation_subscriptnumber) {