diff --git a/ngraph/core/include/ngraph/op/einsum.hpp b/ngraph/core/include/ngraph/op/einsum.hpp new file mode 100644 index 00000000000..08f066823e9 --- /dev/null +++ b/ngraph/core/include/ngraph/op/einsum.hpp @@ -0,0 +1,69 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/node.hpp" +#include "ngraph/op/op.hpp" + +namespace ngraph +{ + namespace op + { + namespace v7 + { + /// \brief Einsum operation. + class NGRAPH_API Einsum : public Op + { + public: + NGRAPH_RTTI_DECLARATION; + + Einsum() = default; + + /// + /// \brief Constructs Einsum operation. + /// + /// \param inputs Input nodes on which Einsum operation performs + /// contraction + /// + /// \param equation Einstein summation convention + /// + Einsum(const OutputVector& inputs, const std::string& equation); + + void validate_and_infer_types() override; + + bool visit_attributes(AttributeVisitor& visitor) override; + + std::shared_ptr + clone_with_new_inputs(const OutputVector& new_args) const override; + + /// \brief Check correctness of equation format and extract input subscripts + /// and output subscript + /// + /// \param equation Equation to be parsed and checked + /// + /// \param input_subscripts A vector of extracted input subscripts + /// + /// \param output_subscript An output subscript + /// + static void parse_equation(const std::string& equation, + std::vector& input_subscripts, + std::string& output_subscript); + + /// \brief Extract labels (from subscript) that can be alphabetic letters or + /// ellipsis + /// + /// \param subscript Subscript + /// + /// \return A vector of extracted labels from the input subscript in the order + /// of appearence + /// + static std::vector extract_labels(const std::string& subscript); + + private: + std::string m_equation; + }; + } // namespace v7 + } // namespace op +} // namespace ngraph diff --git a/ngraph/core/include/ngraph/ops.hpp b/ngraph/core/include/ngraph/ops.hpp index b5a9016c402..5a4ad078616 100644 --- a/ngraph/core/include/ngraph/ops.hpp +++ b/ngraph/core/include/ngraph/ops.hpp @@ -41,6 +41,7 @@ #include "ngraph/op/detection_output.hpp" #include "ngraph/op/dft.hpp" #include "ngraph/op/divide.hpp" +#include "ngraph/op/einsum.hpp" #include "ngraph/op/elu.hpp" #include "ngraph/op/embedding_segments_sum.hpp" #include "ngraph/op/embeddingbag_offsets_sum.hpp" diff --git a/ngraph/core/include/ngraph/opsets/opset7_tbl.hpp b/ngraph/core/include/ngraph/opsets/opset7_tbl.hpp index 8a3d0d6ef9b..b35dba55a1d 100644 --- a/ngraph/core/include/ngraph/opsets/opset7_tbl.hpp +++ b/ngraph/core/include/ngraph/opsets/opset7_tbl.hpp @@ -171,6 +171,7 @@ NGRAPH_OP(ReadValue, ngraph::op::v6) // new version // New operations added in opset7 NGRAPH_OP(DFT, ngraph::op::v7) +NGRAPH_OP(Einsum, ngraph::op::v7) NGRAPH_OP(Gelu, ngraph::op::v7) NGRAPH_OP(IDFT, ngraph::op::v7) NGRAPH_OP(Roll, ngraph::op::v7) diff --git a/ngraph/core/src/op/einsum.cpp b/ngraph/core/src/op/einsum.cpp new file mode 100644 index 00000000000..fbf52ef888b --- /dev/null +++ b/ngraph/core/src/op/einsum.cpp @@ -0,0 +1,308 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include + +#include "itt.hpp" +#include "ngraph/op/einsum.hpp" + +using namespace std; +using namespace ngraph; + +NGRAPH_RTTI_DEFINITION(op::v7::Einsum, "Einsum", 7); + +op::v7::Einsum::Einsum(const OutputVector& inputs, const std::string& equation) + : Op(inputs) + , m_equation(equation) +{ + // normalize input equation by removing extra white-spaces from the equation + m_equation.erase(std::remove_if(m_equation.begin(), m_equation.end(), ::isspace), + m_equation.end()); + + constructor_validate_and_infer_types(); +} + +/// \brief Check that a subscript contains only alphabetic letters or +/// alphabetic letters with one ellipsis +/// +/// \param subscripts A subscript to check its format +/// +/// \param is_ellipsis_met Marker if ellipsis is met in the subscript +/// +/// \return true - correct subscript, false - otherwise +/// +bool is_subscript_correct(const std::string& subscript, bool& is_ellipsis_met) +{ + is_ellipsis_met = false; + auto subscript_length = subscript.length(); + for (size_t ch_idx = 0; ch_idx < subscript_length; ++ch_idx) + { + if (is_ellipsis_met == false && ((subscript_length - ch_idx) > 2) && + (subscript.substr(ch_idx, 3).compare("...") == 0)) + { + // mark that ellipsis is met once + is_ellipsis_met = true; + + // make additional increment since ellipsis consists of three dots. + ch_idx += 2; + } + else if (std::isalpha(subscript[ch_idx]) == 0) + { + return false; + } + } + + return true; +} + +void op::v7::Einsum::parse_equation(const std::string& equation, + std::vector& input_subscripts, + std::string& output_subscript) +{ + NGRAPH_OP_SCOPE(v7_Einsum_parse_equation); + + // split equation to input subscripts and an output subscript + auto pos_output_delimeter = equation.find("->"); + auto input_subscripts_str = equation.substr(0, pos_output_delimeter); + + // split the input subscripts into a vector of input subscripts + bool is_ellipsis_met = false; + input_subscripts.clear(); + std::istringstream input; + input.str(input_subscripts_str); + for (std::string input_subscript; std::getline(input, input_subscript, ',');) + { + bool local_is_ellipsis_met = false; + // check that input subscript contains only alphabetic letter or ellipsis + NGRAPH_CHECK(is_subscript_correct(input_subscript, local_is_ellipsis_met), + "Input subscript of Einsum equation must consist of either only " + "alphabetic letters or alphabetic letters with one ellipsis."); + + // mark that ellipsis is met at least in one input subscript + if (local_is_ellipsis_met) + { + is_ellipsis_met = true; + } + input_subscripts.push_back(input_subscript); + } + + if (pos_output_delimeter == std::string::npos) + { + // recover output subscript + output_subscript = ""; + for (auto const& input_subscript : input_subscripts) + { + for (auto const& label : input_subscript) + { + if (std::isalpha(label) && output_subscript.find(label) == std::string::npos) + { + output_subscript += label; + } + } + } + std::sort(output_subscript.begin(), output_subscript.end()); + if (is_ellipsis_met) + { + output_subscript = "..." + output_subscript; + } + } + else + { + output_subscript = equation.substr(pos_output_delimeter + 2); + bool output_is_ellipsis_met = false; + + // check that the output subscript has the correct format + NGRAPH_CHECK(is_subscript_correct(output_subscript, output_is_ellipsis_met), + "Output subscript of Einsum equation must consist of either only " + "alphabetic letters or alphabetic letters with one ellipsis."); + + // if the ellipsis is met in input subscripts, one ellipsis must be in the output subscript + NGRAPH_CHECK(is_ellipsis_met == output_is_ellipsis_met, + "Output subscript of Einsum equation must contain one ellipsis if " + "ellipsis is met in any input subscript."); + } +} + +std::vector op::v7::Einsum::extract_labels(const std::string& subscript) +{ + NGRAPH_OP_SCOPE(v7_Einsum_extract_labels); + + std::vector labels; + labels.clear(); + auto subscript_length = subscript.length(); + for (size_t ch_idx = 0; ch_idx < subscript_length; ++ch_idx) + { + if (std::isalpha(subscript[ch_idx])) + { + labels.push_back(subscript.substr(ch_idx, 1)); + } + else if (((subscript_length - ch_idx) > 2) && + (subscript.substr(ch_idx, 3).compare("...") == 0)) + { + labels.push_back("..."); + // make additional increment since ellipsis consists of three dots. + ch_idx += 2; + } + else + { + NGRAPH_CHECK(false, "Einsum equation has invalid label."); + } + } + + return labels; +} + +void op::v7::Einsum::validate_and_infer_types() +{ + NGRAPH_OP_SCOPE(v7_Einsum_validate_and_infer_types); + + // check that Einsum operation has at least one input + auto num_inputs = get_input_size(); + NODE_VALIDATION_CHECK(this, num_inputs > 0, "Einsum must have at least one input."); + + // check that all inputs have the same type and the type is numeric + const auto& input_type_0 = get_input_element_type(0); + NODE_VALIDATION_CHECK(this, + input_type_0.is_real() || input_type_0.is_integral_number(), + "The input type for Einsum operation must be numeric."); + for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx) + { + const auto& input_type_i = get_input_element_type(input_idx); + NODE_VALIDATION_CHECK(this, + input_type_0 == input_type_i, + "Inputs to Einsum operation must have the same type."); + } + + // check that equation has correct format and extract input and output subscripts + std::vector input_subscripts; + std::string output_subscript; + parse_equation(m_equation, input_subscripts, output_subscript); + + // a number of input subscripts must match with a number of input tensors + NODE_VALIDATION_CHECK( + this, + input_subscripts.size() == num_inputs, + "Equation must contain a number of subscripts equal to a number of Einsum inputs."); + + // create a dictionary with dimension sizes (or ranges in case dynamic shapes) for each label + // and check their compatibility in case repeating labels + unordered_map label_to_shape; + label_to_shape.clear(); + + for (size_t input_idx = 0; input_idx < num_inputs; ++input_idx) + { + const auto& pshape = get_input_partial_shape(input_idx); + std::vector labels; + labels = extract_labels(input_subscripts[input_idx]); + + if (pshape.rank().is_static()) + { + size_t input_rank = pshape.rank().get_length(); + // check that a rank is greater or equal to a number of labels + // these numbers are always equal if there is no ellipsis in the subscript + NODE_VALIDATION_CHECK( + this, + input_rank >= labels.size(), + "Input rank must be greater or equal to a number of labels in the " + "corresponding input subscript."); + + for (size_t label_ind = 0, dim_ind = 0; + label_ind < labels.size() && dim_ind < input_rank; + ++label_ind) + { + auto const& label = labels[label_ind]; + if (label.compare("...") == 0) + { + size_t num_broadcasted_dims = input_rank - labels.size() + 1; + auto current_sub_pshape = PartialShape(std::vector( + pshape.begin() + dim_ind, pshape.begin() + dim_ind + num_broadcasted_dims)); + if (label_to_shape.find(label) == label_to_shape.end()) + { + label_to_shape[label] = current_sub_pshape; + } + else + { + bool is_broadcast_success = + PartialShape::broadcast_merge_into(label_to_shape[label], + current_sub_pshape, + op::AutoBroadcastType::NUMPY); + NODE_VALIDATION_CHECK(this, + is_broadcast_success, + "Input dimensions labeled with ellipsis for Einsum " + "must be broadcastable."); + } + dim_ind += num_broadcasted_dims; + } + else + { + if (label_to_shape.find(label) == label_to_shape.end()) + { + label_to_shape[label] = PartialShape{pshape[dim_ind]}; + } + else + { + NODE_VALIDATION_CHECK( + this, + label_to_shape[label].compatible(PartialShape{pshape[label_ind]}), + "Different input dimensions indicated by the same labels for Einsum " + "must be compatible."); + PartialShape::merge_into(label_to_shape[label], + PartialShape{pshape[dim_ind]}); + } + ++dim_ind; + } + } + } + else + { + for (auto const& label : labels) + { + NODE_VALIDATION_CHECK(this, + label != "...", + "The subscript corresponding to a dynamic rank input must " + "not contain ellipsis."); + + if (label_to_shape.find(label) == label_to_shape.end()) + { + label_to_shape[label] = PartialShape{Dimension::dynamic()}; + } + } + } + } + + // compute the output shape + std::vector output_labels; + output_labels = extract_labels(output_subscript); + std::vector output_pshape_vector; + + for (auto const& output_label : output_labels) + { + NODE_VALIDATION_CHECK(this, + label_to_shape.find(output_label) != label_to_shape.end(), + "Label in output subscript of Einsum equation must enter at least " + "one input subscript."); + output_pshape_vector.insert(output_pshape_vector.end(), + label_to_shape[output_label].begin(), + label_to_shape[output_label].end()); + } + set_output_type(0, input_type_0, PartialShape(output_pshape_vector)); +} + +bool op::v7::Einsum::visit_attributes(AttributeVisitor& visitor) +{ + NGRAPH_OP_SCOPE(v7_Einsum_visit_attributes); + visitor.on_attribute("equation", m_equation); + return true; +} + +shared_ptr op::v7::Einsum::clone_with_new_inputs(const OutputVector& new_args) const +{ + NGRAPH_OP_SCOPE(v7_Einsum_clone_with_new_inputs); + check_new_args_count(this, new_args); + return make_shared(new_args, m_equation); +} diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index de311ee3261..30246008c0d 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -115,6 +115,7 @@ set(SRC type_prop/depth_to_space.cpp type_prop/dft.cpp type_prop/dyn_reshape.cpp + type_prop/einsum.cpp type_prop/experimental_detectron_generate_proposals.cpp type_prop/experimental_detectron_roi_feature_extractor.cpp type_prop/experimental_detectron_topkrois.cpp diff --git a/ngraph/test/type_prop/einsum.cpp b/ngraph/test/type_prop/einsum.cpp new file mode 100644 index 00000000000..a65fb0677f4 --- /dev/null +++ b/ngraph/test/type_prop/einsum.cpp @@ -0,0 +1,349 @@ +// Copyright (C) 2018-2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "gtest/gtest.h" +#include "ngraph/ngraph.hpp" +#include "util/type_prop.hpp" + +using namespace std; +using namespace ngraph; + +TEST(type_prop, einsum_staticshape_dotproduct) +{ + std::string equation = "i,i->"; + Shape input1_shape{3}; + Shape input2_shape{3}; + Shape out_shape{}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_matmul) +{ + std::string equation = "ab,bc->ac"; + Shape input1_shape{2, 3}; + Shape input2_shape{3, 4}; + Shape out_shape{2, 4}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_trace) +{ + std::string equation = "kii->k"; + Shape input1_shape{2, 3, 3}; + Shape out_shape{2}; + auto I1 = make_shared(element::f32, input1_shape); + auto O = make_shared(OutputVector{I1}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_diagextraction) +{ + std::string equation = "kii->ki"; + Shape input1_shape{2, 3, 3}; + Shape out_shape{2, 3}; + auto I1 = make_shared(element::f32, input1_shape); + auto O = make_shared(OutputVector{I1}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_transpose) +{ + std::string equation = "ijk->kij"; + Shape input1_shape{1, 2, 3}; + Shape out_shape{3, 1, 2}; + auto I1 = make_shared(element::f32, input1_shape); + auto O = make_shared(OutputVector{I1}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_multimatmul) +{ + std::string equation = "ab,bcd,bc->ca"; + Shape input1_shape{2, 5}; + Shape input2_shape{5, 3, 6}; + Shape input3_shape{5, 3}; + Shape out_shape{3, 2}; + auto I1 = make_shared(element::i32, input1_shape); + auto I2 = make_shared(element::i32, input2_shape); + auto I3 = make_shared(element::i32, input3_shape); + auto O = make_shared(OutputVector{I1, I2, I3}, equation); + ASSERT_EQ(O->get_element_type(), element::i32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_ellipsis) +{ + std::string equation = "a...->..."; + Shape input1_shape{5, 3}; + Shape out_shape{3}; + auto I1 = make_shared(element::f32, input1_shape); + auto O = make_shared(OutputVector{I1}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_ellipsis2) +{ + std::string equation = "a...,...->a..."; + Shape input1_shape{3, 5}; + Shape input2_shape{1}; + Shape out_shape{3, 5}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_staticshape_ellipsis3) +{ + std::string equation = "a...b,b...->a..."; + Shape input1_shape{11, 1, 4, 3}; + Shape input2_shape{3, 11, 7, 1}; + Shape out_shape{11, 11, 7, 4}; + auto I1 = make_shared(element::i32, input1_shape); + auto I2 = make_shared(element::i32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::i32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_dynamicshape_dotproduct) +{ + std::string equation = "a,ab->ab"; + const auto input1_shape = PartialShape{Dimension(2, 7)}; + const auto input2_shape = PartialShape{Dimension(3, 10), 3}; + const auto out_shape = PartialShape{Dimension(3, 7), 3}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, einsum_dynamicshape_diagextraction) +{ + std::string equation = "xyzxy->xyz"; + const auto input1_shape = PartialShape{Dimension(2, 7), Dimension(1, 5), 4, Dimension(3, 5), 3}; + const auto out_shape = PartialShape{Dimension(3, 5), 3, 4}; + auto I1 = make_shared(element::i32, input1_shape); + auto O = make_shared(OutputVector{I1}, equation); + ASSERT_EQ(O->get_element_type(), element::i32); + ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, DISABLED_einsum_dynamicshape_ellipsis1) +{ + // TODO: fix bug #53518 - PartialShape::broadcast_merge_into or Dimension::broadcast_merge + // to support broadcasting between Dimension(3, 5) and Dimension(1, 3) + // for which the result must be Dimension(3, 5) + std::string equation = "a...b,b...->a..."; + const auto input1_shape = PartialShape{11, 1, Dimension(3, 5), 3}; + const auto input2_shape = PartialShape{3, 11, 7, Dimension(1, 3)}; + const auto out_shape = PartialShape{11, 11, 7, Dimension(3, 5)}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, einsum_implicitmode_mixedcaseletters) +{ + // the following equation is equivalent to "AbC->ACb" + std::string equation = "AbC"; + const auto input1_shape = PartialShape{1, Dimension(2, 3), Dimension(4, 5)}; + auto I1 = make_shared(element::f32, input1_shape); + const auto out_shape = PartialShape{1, Dimension(4, 5), Dimension(2, 3)}; + auto O = make_shared(OutputVector{I1}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, einsum_implicitmode_mixedcaseletters2) +{ + // the following equation is equivalent to "a...b,B...->...Bab" + std::string equation = "a...b,B..."; + const auto input1_shape = PartialShape{Dimension(3, 5), 11, 1, 3}; + const auto input2_shape = PartialShape{Dimension(1, 3), 3, 1, 7}; + const auto out_shape = PartialShape{3, 11, 7, Dimension(1, 3), Dimension(3, 5), 3}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + auto O = make_shared(OutputVector{I1, I2}, equation); + ASSERT_EQ(O->get_element_type(), element::f32); + ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, einsum_dynamicrank_multimatmul) +{ + std::string equation = "ab,bcd,bc->ca"; + Shape input1_shape{2, 5}; + PartialShape input2_shape = PartialShape::dynamic(); + Shape input3_shape{5, 3}; + Shape out_shape{3, 2}; + auto I1 = make_shared(element::i32, input1_shape); + auto I2 = make_shared(element::i32, input2_shape); + auto I3 = make_shared(element::i32, input3_shape); + auto O = make_shared(OutputVector{I1, I2, I3}, equation); + ASSERT_EQ(O->get_element_type(), element::i32); + ASSERT_EQ(O->get_shape(), out_shape); +} + +TEST(type_prop, einsum_dynamicrank_multimatmul2) +{ + std::string equation = "ab,bcd,bc->ca"; + PartialShape input1_shape = PartialShape::dynamic(); + PartialShape input2_shape = PartialShape::dynamic(); + PartialShape input3_shape = PartialShape::dynamic(); + PartialShape out_shape{Dimension(), Dimension()}; + auto I1 = make_shared(element::i32, input1_shape); + auto I2 = make_shared(element::i32, input2_shape); + auto I3 = make_shared(element::i32, input3_shape); + auto O = make_shared(OutputVector{I1, I2, I3}, equation); + ASSERT_EQ(O->get_element_type(), element::i32); + ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); +} + +TEST(type_prop, einsum_incorrectequation_subscriptnumber) +{ + std::string equation = "ab,bc,cd->ac"; + Shape input1_shape{2, 3}; + Shape input2_shape{3, 4}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + + try + { + auto O = make_shared(OutputVector{I1, I2}, equation); + // Should have thrown, so fail if it didn't + FAIL() << "Incorrect number of input subscripts"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Equation must contain a number of subscripts equal to a " + "number of Einsum inputs.")); + } + catch (...) + { + FAIL() << "Equation format check failed"; + } +} + +TEST(type_prop, einsum_incorrectequation_invalidlabels) +{ + std::string equation = "a$,Bc->ac"; + Shape input1_shape{2, 3}; + Shape input2_shape{3, 4}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + + try + { + auto O = make_shared(OutputVector{I1, I2}, equation); + // Should have thrown, so fail if it didn't + FAIL() << "Incorrect number of input subscripts"; + } + catch (const CheckFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string("Input subscript of Einsum equation must consist of either only alphabetic " + "letters or alphabetic letters with one ellipsis.")); + } + catch (...) + { + FAIL() << "Equation format check failed"; + } +} + +TEST(type_prop, einsum_incorrectequation_incompatibleshapes) +{ + std::string equation = "ab,bc->ac"; + Shape input1_shape{2, 10}; + Shape input2_shape{3, 4}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + + try + { + auto O = make_shared(OutputVector{I1, I2}, equation); + // Should have thrown, so fail if it didn't + FAIL() << "Incompatible dimension indicated by the same labels"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Different input dimensions indicated by the same labels " + "for Einsum must be compatible.")); + } + catch (...) + { + FAIL() << "Equation format check failed"; + } +} + +TEST(type_prop, einsum_incorrectequation_notbroadcastableshapes) +{ + std::string equation = "a...b,b...->a..."; + Shape input1_shape{11, 1, 4, 3}; + Shape input2_shape{3, 11, 7, 5}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + + try + { + auto O = make_shared(OutputVector{I1, I2}, equation); + // Should have thrown, so fail if it didn't + FAIL() << "Non-broadcastable shapes covered by ellipsis"; + } + catch (const NodeValidationFailure& error) + { + EXPECT_HAS_SUBSTRING( + error.what(), + std::string( + "Input dimensions labeled with ellipsis for Einsum must be broadcastable.")); + } + catch (...) + { + FAIL() << "Equation format check failed"; + } +} + +TEST(type_prop, einsum_incorrectequation_missedellipsis) +{ + std::string equation = "a...b,b...->a"; + Shape input1_shape{11, 1, 4, 3}; + Shape input2_shape{3, 11, 7, 5}; + auto I1 = make_shared(element::f32, input1_shape); + auto I2 = make_shared(element::f32, input2_shape); + + try + { + auto O = make_shared(OutputVector{I1, I2}, equation); + // Should have thrown, so fail if it didn't + FAIL() << "Non-broadcastable shapes covered by ellipsis"; + } + catch (const CheckFailure& error) + { + EXPECT_HAS_SUBSTRING(error.what(), + std::string("Output subscript of Einsum equation must contain one " + "ellipsis if ellipsis is met in any input subscript.")); + } + catch (...) + { + FAIL() << "Equation format check failed"; + } +}