Implement nGraph shell for Einsum-7 (#5282)

* Implement nGraph shell for Einsum-7

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Correct doxygen formats

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Apply clang format change

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Support implicit mode and capital letters

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Correct and optimize the code based on review

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Correct private methods and its API, add more tests

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Make equation aux methods public and remove regex usage

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Make is_subscript_correct function local

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Correct check for missed ellipsis and add test for it

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2021-04-24 09:41:48 +03:00 committed by GitHub
parent 7d0cae8bb5
commit fcea3f8a0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 729 additions and 0 deletions

View File

@ -0,0 +1,69 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
{
namespace op
{
namespace v7
{
/// \brief Einsum operation.
class NGRAPH_API Einsum : public Op
{
public:
NGRAPH_RTTI_DECLARATION;
Einsum() = default;
///
/// \brief Constructs Einsum operation.
///
/// \param inputs Input nodes on which Einsum operation performs
/// contraction
///
/// \param equation Einstein summation convention
///
Einsum(const OutputVector& inputs, const std::string& equation);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
/// \brief Check correctness of equation format and extract input subscripts
/// and output subscript
///
/// \param equation Equation to be parsed and checked
///
/// \param input_subscripts A vector of extracted input subscripts
///
/// \param output_subscript An output subscript
///
static void parse_equation(const std::string& equation,
std::vector<std::string>& input_subscripts,
std::string& output_subscript);
/// \brief Extract labels (from subscript) that can be alphabetic letters or
/// ellipsis
///
/// \param subscript Subscript
///
/// \return A vector of extracted labels from the input subscript in the order
/// of appearence
///
static std::vector<std::string> extract_labels(const std::string& subscript);
private:
std::string m_equation;
};
} // namespace v7
} // namespace op
} // namespace ngraph

View File

@ -41,6 +41,7 @@
#include "ngraph/op/detection_output.hpp" #include "ngraph/op/detection_output.hpp"
#include "ngraph/op/dft.hpp" #include "ngraph/op/dft.hpp"
#include "ngraph/op/divide.hpp" #include "ngraph/op/divide.hpp"
#include "ngraph/op/einsum.hpp"
#include "ngraph/op/elu.hpp" #include "ngraph/op/elu.hpp"
#include "ngraph/op/embedding_segments_sum.hpp" #include "ngraph/op/embedding_segments_sum.hpp"
#include "ngraph/op/embeddingbag_offsets_sum.hpp" #include "ngraph/op/embeddingbag_offsets_sum.hpp"

View File

@ -171,6 +171,7 @@ NGRAPH_OP(ReadValue, ngraph::op::v6) // new version
// New operations added in opset7 // New operations added in opset7
NGRAPH_OP(DFT, ngraph::op::v7) NGRAPH_OP(DFT, ngraph::op::v7)
NGRAPH_OP(Einsum, ngraph::op::v7)
NGRAPH_OP(Gelu, ngraph::op::v7) NGRAPH_OP(Gelu, ngraph::op::v7)
NGRAPH_OP(IDFT, ngraph::op::v7) NGRAPH_OP(IDFT, ngraph::op::v7)
NGRAPH_OP(Roll, ngraph::op::v7) NGRAPH_OP(Roll, ngraph::op::v7)

View File

@ -0,0 +1,308 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <cctype>
#include <ngraph/validation_util.hpp>
#include <string>
#include <unordered_map>
#include "itt.hpp"
#include "ngraph/op/einsum.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_RTTI_DEFINITION(op::v7::Einsum, "Einsum", 7);
op::v7::Einsum::Einsum(const OutputVector& inputs, const std::string& equation)
: Op(inputs)
, m_equation(equation)
{
// normalize input equation by removing extra white-spaces from the equation
m_equation.erase(std::remove_if(m_equation.begin(), m_equation.end(), ::isspace),
m_equation.end());
constructor_validate_and_infer_types();
}
/// \brief Check that a subscript contains only alphabetic letters or
/// alphabetic letters with one ellipsis
///
/// \param subscripts A subscript to check its format
///
/// \param is_ellipsis_met Marker if ellipsis is met in the subscript
///
/// \return true - correct subscript, false - otherwise
///
bool is_subscript_correct(const std::string& subscript, bool& is_ellipsis_met)
{
is_ellipsis_met = false;
auto subscript_length = subscript.length();
for (size_t ch_idx = 0; ch_idx < subscript_length; ++ch_idx)
{
if (is_ellipsis_met == false && ((subscript_length - ch_idx) > 2) &&
(subscript.substr(ch_idx, 3).compare("...") == 0))
{
// mark that ellipsis is met once
is_ellipsis_met = true;
// make additional increment since ellipsis consists of three dots.
ch_idx += 2;
}
else if (std::isalpha(subscript[ch_idx]) == 0)
{
return false;
}
}
return true;
}
void op::v7::Einsum::parse_equation(const std::string& equation,
std::vector<std::string>& input_subscripts,
std::string& output_subscript)
{
NGRAPH_OP_SCOPE(v7_Einsum_parse_equation);
// split equation to input subscripts and an output subscript
auto pos_output_delimeter = equation.find("->");
auto input_subscripts_str = equation.substr(0, pos_output_delimeter);
// split the input subscripts into a vector of input subscripts
bool is_ellipsis_met = false;
input_subscripts.clear();
std::istringstream input;
input.str(input_subscripts_str);
for (std::string input_subscript; std::getline(input, input_subscript, ',');)
{
bool local_is_ellipsis_met = false;
// check that input subscript contains only alphabetic letter or ellipsis
NGRAPH_CHECK(is_subscript_correct(input_subscript, local_is_ellipsis_met),
"Input subscript of Einsum equation must consist of either only "
"alphabetic letters or alphabetic letters with one ellipsis.");
// mark that ellipsis is met at least in one input subscript
if (local_is_ellipsis_met)
{
is_ellipsis_met = true;
}
input_subscripts.push_back(input_subscript);
}
if (pos_output_delimeter == std::string::npos)
{
// recover output subscript
output_subscript = "";
for (auto const& input_subscript : input_subscripts)
{
for (auto const& label : input_subscript)
{
if (std::isalpha(label) && output_subscript.find(label) == std::string::npos)
{
output_subscript += label;
}
}
}
std::sort(output_subscript.begin(), output_subscript.end());
if (is_ellipsis_met)
{
output_subscript = "..." + output_subscript;
}
}
else
{
output_subscript = equation.substr(pos_output_delimeter + 2);
bool output_is_ellipsis_met = false;
// check that the output subscript has the correct format
NGRAPH_CHECK(is_subscript_correct(output_subscript, output_is_ellipsis_met),
"Output subscript of Einsum equation must consist of either only "
"alphabetic letters or alphabetic letters with one ellipsis.");
// if the ellipsis is met in input subscripts, one ellipsis must be in the output subscript
NGRAPH_CHECK(is_ellipsis_met == output_is_ellipsis_met,
"Output subscript of Einsum equation must contain one ellipsis if "
"ellipsis is met in any input subscript.");
}
}
std::vector<std::string> op::v7::Einsum::extract_labels(const std::string& subscript)
{
NGRAPH_OP_SCOPE(v7_Einsum_extract_labels);
std::vector<std::string> labels;
labels.clear();
auto subscript_length = subscript.length();
for (size_t ch_idx = 0; ch_idx < subscript_length; ++ch_idx)
{
if (std::isalpha(subscript[ch_idx]))
{
labels.push_back(subscript.substr(ch_idx, 1));
}
else if (((subscript_length - ch_idx) > 2) &&
(subscript.substr(ch_idx, 3).compare("...") == 0))
{
labels.push_back("...");
// make additional increment since ellipsis consists of three dots.
ch_idx += 2;
}
else
{
NGRAPH_CHECK(false, "Einsum equation has invalid label.");
}
}
return labels;
}
void op::v7::Einsum::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v7_Einsum_validate_and_infer_types);
// check that Einsum operation has at least one input
auto num_inputs = get_input_size();
NODE_VALIDATION_CHECK(this, num_inputs > 0, "Einsum must have at least one input.");
// check that all inputs have the same type and the type is numeric
const auto& input_type_0 = get_input_element_type(0);
NODE_VALIDATION_CHECK(this,
input_type_0.is_real() || input_type_0.is_integral_number(),
"The input type for Einsum operation must be numeric.");
for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx)
{
const auto& input_type_i = get_input_element_type(input_idx);
NODE_VALIDATION_CHECK(this,
input_type_0 == input_type_i,
"Inputs to Einsum operation must have the same type.");
}
// check that equation has correct format and extract input and output subscripts
std::vector<std::string> input_subscripts;
std::string output_subscript;
parse_equation(m_equation, input_subscripts, output_subscript);
// a number of input subscripts must match with a number of input tensors
NODE_VALIDATION_CHECK(
this,
input_subscripts.size() == num_inputs,
"Equation must contain a number of subscripts equal to a number of Einsum inputs.");
// create a dictionary with dimension sizes (or ranges in case dynamic shapes) for each label
// and check their compatibility in case repeating labels
unordered_map<string, PartialShape> label_to_shape;
label_to_shape.clear();
for (size_t input_idx = 0; input_idx < num_inputs; ++input_idx)
{
const auto& pshape = get_input_partial_shape(input_idx);
std::vector<std::string> labels;
labels = extract_labels(input_subscripts[input_idx]);
if (pshape.rank().is_static())
{
size_t input_rank = pshape.rank().get_length();
// check that a rank is greater or equal to a number of labels
// these numbers are always equal if there is no ellipsis in the subscript
NODE_VALIDATION_CHECK(
this,
input_rank >= labels.size(),
"Input rank must be greater or equal to a number of labels in the "
"corresponding input subscript.");
for (size_t label_ind = 0, dim_ind = 0;
label_ind < labels.size() && dim_ind < input_rank;
++label_ind)
{
auto const& label = labels[label_ind];
if (label.compare("...") == 0)
{
size_t num_broadcasted_dims = input_rank - labels.size() + 1;
auto current_sub_pshape = PartialShape(std::vector<Dimension>(
pshape.begin() + dim_ind, pshape.begin() + dim_ind + num_broadcasted_dims));
if (label_to_shape.find(label) == label_to_shape.end())
{
label_to_shape[label] = current_sub_pshape;
}
else
{
bool is_broadcast_success =
PartialShape::broadcast_merge_into(label_to_shape[label],
current_sub_pshape,
op::AutoBroadcastType::NUMPY);
NODE_VALIDATION_CHECK(this,
is_broadcast_success,
"Input dimensions labeled with ellipsis for Einsum "
"must be broadcastable.");
}
dim_ind += num_broadcasted_dims;
}
else
{
if (label_to_shape.find(label) == label_to_shape.end())
{
label_to_shape[label] = PartialShape{pshape[dim_ind]};
}
else
{
NODE_VALIDATION_CHECK(
this,
label_to_shape[label].compatible(PartialShape{pshape[label_ind]}),
"Different input dimensions indicated by the same labels for Einsum "
"must be compatible.");
PartialShape::merge_into(label_to_shape[label],
PartialShape{pshape[dim_ind]});
}
++dim_ind;
}
}
}
else
{
for (auto const& label : labels)
{
NODE_VALIDATION_CHECK(this,
label != "...",
"The subscript corresponding to a dynamic rank input must "
"not contain ellipsis.");
if (label_to_shape.find(label) == label_to_shape.end())
{
label_to_shape[label] = PartialShape{Dimension::dynamic()};
}
}
}
}
// compute the output shape
std::vector<std::string> output_labels;
output_labels = extract_labels(output_subscript);
std::vector<Dimension> output_pshape_vector;
for (auto const& output_label : output_labels)
{
NODE_VALIDATION_CHECK(this,
label_to_shape.find(output_label) != label_to_shape.end(),
"Label in output subscript of Einsum equation must enter at least "
"one input subscript.");
output_pshape_vector.insert(output_pshape_vector.end(),
label_to_shape[output_label].begin(),
label_to_shape[output_label].end());
}
set_output_type(0, input_type_0, PartialShape(output_pshape_vector));
}
bool op::v7::Einsum::visit_attributes(AttributeVisitor& visitor)
{
NGRAPH_OP_SCOPE(v7_Einsum_visit_attributes);
visitor.on_attribute("equation", m_equation);
return true;
}
shared_ptr<Node> op::v7::Einsum::clone_with_new_inputs(const OutputVector& new_args) const
{
NGRAPH_OP_SCOPE(v7_Einsum_clone_with_new_inputs);
check_new_args_count(this, new_args);
return make_shared<v7::Einsum>(new_args, m_equation);
}

View File

@ -115,6 +115,7 @@ set(SRC
type_prop/depth_to_space.cpp type_prop/depth_to_space.cpp
type_prop/dft.cpp type_prop/dft.cpp
type_prop/dyn_reshape.cpp type_prop/dyn_reshape.cpp
type_prop/einsum.cpp
type_prop/experimental_detectron_generate_proposals.cpp type_prop/experimental_detectron_generate_proposals.cpp
type_prop/experimental_detectron_roi_feature_extractor.cpp type_prop/experimental_detectron_roi_feature_extractor.cpp
type_prop/experimental_detectron_topkrois.cpp type_prop/experimental_detectron_topkrois.cpp

View File

@ -0,0 +1,349 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
TEST(type_prop, einsum_staticshape_dotproduct)
{
std::string equation = "i,i->";
Shape input1_shape{3};
Shape input2_shape{3};
Shape out_shape{};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_matmul)
{
std::string equation = "ab,bc->ac";
Shape input1_shape{2, 3};
Shape input2_shape{3, 4};
Shape out_shape{2, 4};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_trace)
{
std::string equation = "kii->k";
Shape input1_shape{2, 3, 3};
Shape out_shape{2};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_diagextraction)
{
std::string equation = "kii->ki";
Shape input1_shape{2, 3, 3};
Shape out_shape{2, 3};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_transpose)
{
std::string equation = "ijk->kij";
Shape input1_shape{1, 2, 3};
Shape out_shape{3, 1, 2};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_multimatmul)
{
std::string equation = "ab,bcd,bc->ca";
Shape input1_shape{2, 5};
Shape input2_shape{5, 3, 6};
Shape input3_shape{5, 3};
Shape out_shape{3, 2};
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_ellipsis)
{
std::string equation = "a...->...";
Shape input1_shape{5, 3};
Shape out_shape{3};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_ellipsis2)
{
std::string equation = "a...,...->a...";
Shape input1_shape{3, 5};
Shape input2_shape{1};
Shape out_shape{3, 5};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_staticshape_ellipsis3)
{
std::string equation = "a...b,b...->a...";
Shape input1_shape{11, 1, 4, 3};
Shape input2_shape{3, 11, 7, 1};
Shape out_shape{11, 11, 7, 4};
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_dynamicshape_dotproduct)
{
std::string equation = "a,ab->ab";
const auto input1_shape = PartialShape{Dimension(2, 7)};
const auto input2_shape = PartialShape{Dimension(3, 10), 3};
const auto out_shape = PartialShape{Dimension(3, 7), 3};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, einsum_dynamicshape_diagextraction)
{
std::string equation = "xyzxy->xyz";
const auto input1_shape = PartialShape{Dimension(2, 7), Dimension(1, 5), 4, Dimension(3, 5), 3};
const auto out_shape = PartialShape{Dimension(3, 5), 3, 4};
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, DISABLED_einsum_dynamicshape_ellipsis1)
{
// TODO: fix bug #53518 - PartialShape::broadcast_merge_into or Dimension::broadcast_merge
// to support broadcasting between Dimension(3, 5) and Dimension(1, 3)
// for which the result must be Dimension(3, 5)
std::string equation = "a...b,b...->a...";
const auto input1_shape = PartialShape{11, 1, Dimension(3, 5), 3};
const auto input2_shape = PartialShape{3, 11, 7, Dimension(1, 3)};
const auto out_shape = PartialShape{11, 11, 7, Dimension(3, 5)};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, einsum_implicitmode_mixedcaseletters)
{
// the following equation is equivalent to "AbC->ACb"
std::string equation = "AbC";
const auto input1_shape = PartialShape{1, Dimension(2, 3), Dimension(4, 5)};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
const auto out_shape = PartialShape{1, Dimension(4, 5), Dimension(2, 3)};
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, einsum_implicitmode_mixedcaseletters2)
{
// the following equation is equivalent to "a...b,B...->...Bab"
std::string equation = "a...b,B...";
const auto input1_shape = PartialShape{Dimension(3, 5), 11, 1, 3};
const auto input2_shape = PartialShape{Dimension(1, 3), 3, 1, 7};
const auto out_shape = PartialShape{3, 11, 7, Dimension(1, 3), Dimension(3, 5), 3};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
ASSERT_EQ(O->get_element_type(), element::f32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, einsum_dynamicrank_multimatmul)
{
std::string equation = "ab,bcd,bc->ca";
Shape input1_shape{2, 5};
PartialShape input2_shape = PartialShape::dynamic();
Shape input3_shape{5, 3};
Shape out_shape{3, 2};
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_EQ(O->get_shape(), out_shape);
}
TEST(type_prop, einsum_dynamicrank_multimatmul2)
{
std::string equation = "ab,bcd,bc->ca";
PartialShape input1_shape = PartialShape::dynamic();
PartialShape input2_shape = PartialShape::dynamic();
PartialShape input3_shape = PartialShape::dynamic();
PartialShape out_shape{Dimension(), Dimension()};
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
ASSERT_EQ(O->get_element_type(), element::i32);
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
}
TEST(type_prop, einsum_incorrectequation_subscriptnumber)
{
std::string equation = "ab,bc,cd->ac";
Shape input1_shape{2, 3};
Shape input2_shape{3, 4};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
try
{
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect number of input subscripts";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Equation must contain a number of subscripts equal to a "
"number of Einsum inputs."));
}
catch (...)
{
FAIL() << "Equation format check failed";
}
}
TEST(type_prop, einsum_incorrectequation_invalidlabels)
{
std::string equation = "a$,Bc->ac";
Shape input1_shape{2, 3};
Shape input2_shape{3, 4};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
try
{
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
// Should have thrown, so fail if it didn't
FAIL() << "Incorrect number of input subscripts";
}
catch (const CheckFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Input subscript of Einsum equation must consist of either only alphabetic "
"letters or alphabetic letters with one ellipsis."));
}
catch (...)
{
FAIL() << "Equation format check failed";
}
}
TEST(type_prop, einsum_incorrectequation_incompatibleshapes)
{
std::string equation = "ab,bc->ac";
Shape input1_shape{2, 10};
Shape input2_shape{3, 4};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
try
{
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
// Should have thrown, so fail if it didn't
FAIL() << "Incompatible dimension indicated by the same labels";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Different input dimensions indicated by the same labels "
"for Einsum must be compatible."));
}
catch (...)
{
FAIL() << "Equation format check failed";
}
}
TEST(type_prop, einsum_incorrectequation_notbroadcastableshapes)
{
std::string equation = "a...b,b...->a...";
Shape input1_shape{11, 1, 4, 3};
Shape input2_shape{3, 11, 7, 5};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
try
{
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
// Should have thrown, so fail if it didn't
FAIL() << "Non-broadcastable shapes covered by ellipsis";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(
error.what(),
std::string(
"Input dimensions labeled with ellipsis for Einsum must be broadcastable."));
}
catch (...)
{
FAIL() << "Equation format check failed";
}
}
TEST(type_prop, einsum_incorrectequation_missedellipsis)
{
std::string equation = "a...b,b...->a";
Shape input1_shape{11, 1, 4, 3};
Shape input2_shape{3, 11, 7, 5};
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
try
{
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
// Should have thrown, so fail if it didn't
FAIL() << "Non-broadcastable shapes covered by ellipsis";
}
catch (const CheckFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Output subscript of Einsum equation must contain one "
"ellipsis if ellipsis is met in any input subscript."));
}
catch (...)
{
FAIL() << "Equation format check failed";
}
}