Implement nGraph shell for Einsum-7 (#5282)
* Implement nGraph shell for Einsum-7 Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Correct doxygen formats Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Apply clang format change Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Support implicit mode and capital letters Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Correct and optimize the code based on review Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Correct private methods and its API, add more tests Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Make equation aux methods public and remove regex usage Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Make is_subscript_correct function local Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Correct check for missed ellipsis and add test for it Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
7d0cae8bb5
commit
fcea3f8a0c
69
ngraph/core/include/ngraph/op/einsum.hpp
Normal file
69
ngraph/core/include/ngraph/op/einsum.hpp
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v7
|
||||
{
|
||||
/// \brief Einsum operation.
|
||||
class NGRAPH_API Einsum : public Op
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
Einsum() = default;
|
||||
|
||||
///
|
||||
/// \brief Constructs Einsum operation.
|
||||
///
|
||||
/// \param inputs Input nodes on which Einsum operation performs
|
||||
/// contraction
|
||||
///
|
||||
/// \param equation Einstein summation convention
|
||||
///
|
||||
Einsum(const OutputVector& inputs, const std::string& equation);
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
/// \brief Check correctness of equation format and extract input subscripts
|
||||
/// and output subscript
|
||||
///
|
||||
/// \param equation Equation to be parsed and checked
|
||||
///
|
||||
/// \param input_subscripts A vector of extracted input subscripts
|
||||
///
|
||||
/// \param output_subscript An output subscript
|
||||
///
|
||||
static void parse_equation(const std::string& equation,
|
||||
std::vector<std::string>& input_subscripts,
|
||||
std::string& output_subscript);
|
||||
|
||||
/// \brief Extract labels (from subscript) that can be alphabetic letters or
|
||||
/// ellipsis
|
||||
///
|
||||
/// \param subscript Subscript
|
||||
///
|
||||
/// \return A vector of extracted labels from the input subscript in the order
|
||||
/// of appearence
|
||||
///
|
||||
static std::vector<std::string> extract_labels(const std::string& subscript);
|
||||
|
||||
private:
|
||||
std::string m_equation;
|
||||
};
|
||||
} // namespace v7
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
@ -41,6 +41,7 @@
|
||||
#include "ngraph/op/detection_output.hpp"
|
||||
#include "ngraph/op/dft.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/einsum.hpp"
|
||||
#include "ngraph/op/elu.hpp"
|
||||
#include "ngraph/op/embedding_segments_sum.hpp"
|
||||
#include "ngraph/op/embeddingbag_offsets_sum.hpp"
|
||||
|
@ -171,6 +171,7 @@ NGRAPH_OP(ReadValue, ngraph::op::v6) // new version
|
||||
|
||||
// New operations added in opset7
|
||||
NGRAPH_OP(DFT, ngraph::op::v7)
|
||||
NGRAPH_OP(Einsum, ngraph::op::v7)
|
||||
NGRAPH_OP(Gelu, ngraph::op::v7)
|
||||
NGRAPH_OP(IDFT, ngraph::op::v7)
|
||||
NGRAPH_OP(Roll, ngraph::op::v7)
|
||||
|
308
ngraph/core/src/op/einsum.cpp
Normal file
308
ngraph/core/src/op/einsum.cpp
Normal file
@ -0,0 +1,308 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <algorithm>
|
||||
#include <cctype>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/einsum.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v7::Einsum, "Einsum", 7);
|
||||
|
||||
op::v7::Einsum::Einsum(const OutputVector& inputs, const std::string& equation)
|
||||
: Op(inputs)
|
||||
, m_equation(equation)
|
||||
{
|
||||
// normalize input equation by removing extra white-spaces from the equation
|
||||
m_equation.erase(std::remove_if(m_equation.begin(), m_equation.end(), ::isspace),
|
||||
m_equation.end());
|
||||
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
/// \brief Check that a subscript contains only alphabetic letters or
|
||||
/// alphabetic letters with one ellipsis
|
||||
///
|
||||
/// \param subscripts A subscript to check its format
|
||||
///
|
||||
/// \param is_ellipsis_met Marker if ellipsis is met in the subscript
|
||||
///
|
||||
/// \return true - correct subscript, false - otherwise
|
||||
///
|
||||
bool is_subscript_correct(const std::string& subscript, bool& is_ellipsis_met)
|
||||
{
|
||||
is_ellipsis_met = false;
|
||||
auto subscript_length = subscript.length();
|
||||
for (size_t ch_idx = 0; ch_idx < subscript_length; ++ch_idx)
|
||||
{
|
||||
if (is_ellipsis_met == false && ((subscript_length - ch_idx) > 2) &&
|
||||
(subscript.substr(ch_idx, 3).compare("...") == 0))
|
||||
{
|
||||
// mark that ellipsis is met once
|
||||
is_ellipsis_met = true;
|
||||
|
||||
// make additional increment since ellipsis consists of three dots.
|
||||
ch_idx += 2;
|
||||
}
|
||||
else if (std::isalpha(subscript[ch_idx]) == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v7::Einsum::parse_equation(const std::string& equation,
|
||||
std::vector<std::string>& input_subscripts,
|
||||
std::string& output_subscript)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Einsum_parse_equation);
|
||||
|
||||
// split equation to input subscripts and an output subscript
|
||||
auto pos_output_delimeter = equation.find("->");
|
||||
auto input_subscripts_str = equation.substr(0, pos_output_delimeter);
|
||||
|
||||
// split the input subscripts into a vector of input subscripts
|
||||
bool is_ellipsis_met = false;
|
||||
input_subscripts.clear();
|
||||
std::istringstream input;
|
||||
input.str(input_subscripts_str);
|
||||
for (std::string input_subscript; std::getline(input, input_subscript, ',');)
|
||||
{
|
||||
bool local_is_ellipsis_met = false;
|
||||
// check that input subscript contains only alphabetic letter or ellipsis
|
||||
NGRAPH_CHECK(is_subscript_correct(input_subscript, local_is_ellipsis_met),
|
||||
"Input subscript of Einsum equation must consist of either only "
|
||||
"alphabetic letters or alphabetic letters with one ellipsis.");
|
||||
|
||||
// mark that ellipsis is met at least in one input subscript
|
||||
if (local_is_ellipsis_met)
|
||||
{
|
||||
is_ellipsis_met = true;
|
||||
}
|
||||
input_subscripts.push_back(input_subscript);
|
||||
}
|
||||
|
||||
if (pos_output_delimeter == std::string::npos)
|
||||
{
|
||||
// recover output subscript
|
||||
output_subscript = "";
|
||||
for (auto const& input_subscript : input_subscripts)
|
||||
{
|
||||
for (auto const& label : input_subscript)
|
||||
{
|
||||
if (std::isalpha(label) && output_subscript.find(label) == std::string::npos)
|
||||
{
|
||||
output_subscript += label;
|
||||
}
|
||||
}
|
||||
}
|
||||
std::sort(output_subscript.begin(), output_subscript.end());
|
||||
if (is_ellipsis_met)
|
||||
{
|
||||
output_subscript = "..." + output_subscript;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
output_subscript = equation.substr(pos_output_delimeter + 2);
|
||||
bool output_is_ellipsis_met = false;
|
||||
|
||||
// check that the output subscript has the correct format
|
||||
NGRAPH_CHECK(is_subscript_correct(output_subscript, output_is_ellipsis_met),
|
||||
"Output subscript of Einsum equation must consist of either only "
|
||||
"alphabetic letters or alphabetic letters with one ellipsis.");
|
||||
|
||||
// if the ellipsis is met in input subscripts, one ellipsis must be in the output subscript
|
||||
NGRAPH_CHECK(is_ellipsis_met == output_is_ellipsis_met,
|
||||
"Output subscript of Einsum equation must contain one ellipsis if "
|
||||
"ellipsis is met in any input subscript.");
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> op::v7::Einsum::extract_labels(const std::string& subscript)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Einsum_extract_labels);
|
||||
|
||||
std::vector<std::string> labels;
|
||||
labels.clear();
|
||||
auto subscript_length = subscript.length();
|
||||
for (size_t ch_idx = 0; ch_idx < subscript_length; ++ch_idx)
|
||||
{
|
||||
if (std::isalpha(subscript[ch_idx]))
|
||||
{
|
||||
labels.push_back(subscript.substr(ch_idx, 1));
|
||||
}
|
||||
else if (((subscript_length - ch_idx) > 2) &&
|
||||
(subscript.substr(ch_idx, 3).compare("...") == 0))
|
||||
{
|
||||
labels.push_back("...");
|
||||
// make additional increment since ellipsis consists of three dots.
|
||||
ch_idx += 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(false, "Einsum equation has invalid label.");
|
||||
}
|
||||
}
|
||||
|
||||
return labels;
|
||||
}
|
||||
|
||||
void op::v7::Einsum::validate_and_infer_types()
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Einsum_validate_and_infer_types);
|
||||
|
||||
// check that Einsum operation has at least one input
|
||||
auto num_inputs = get_input_size();
|
||||
NODE_VALIDATION_CHECK(this, num_inputs > 0, "Einsum must have at least one input.");
|
||||
|
||||
// check that all inputs have the same type and the type is numeric
|
||||
const auto& input_type_0 = get_input_element_type(0);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_type_0.is_real() || input_type_0.is_integral_number(),
|
||||
"The input type for Einsum operation must be numeric.");
|
||||
for (size_t input_idx = 1; input_idx < num_inputs; ++input_idx)
|
||||
{
|
||||
const auto& input_type_i = get_input_element_type(input_idx);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_type_0 == input_type_i,
|
||||
"Inputs to Einsum operation must have the same type.");
|
||||
}
|
||||
|
||||
// check that equation has correct format and extract input and output subscripts
|
||||
std::vector<std::string> input_subscripts;
|
||||
std::string output_subscript;
|
||||
parse_equation(m_equation, input_subscripts, output_subscript);
|
||||
|
||||
// a number of input subscripts must match with a number of input tensors
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
input_subscripts.size() == num_inputs,
|
||||
"Equation must contain a number of subscripts equal to a number of Einsum inputs.");
|
||||
|
||||
// create a dictionary with dimension sizes (or ranges in case dynamic shapes) for each label
|
||||
// and check their compatibility in case repeating labels
|
||||
unordered_map<string, PartialShape> label_to_shape;
|
||||
label_to_shape.clear();
|
||||
|
||||
for (size_t input_idx = 0; input_idx < num_inputs; ++input_idx)
|
||||
{
|
||||
const auto& pshape = get_input_partial_shape(input_idx);
|
||||
std::vector<std::string> labels;
|
||||
labels = extract_labels(input_subscripts[input_idx]);
|
||||
|
||||
if (pshape.rank().is_static())
|
||||
{
|
||||
size_t input_rank = pshape.rank().get_length();
|
||||
// check that a rank is greater or equal to a number of labels
|
||||
// these numbers are always equal if there is no ellipsis in the subscript
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
input_rank >= labels.size(),
|
||||
"Input rank must be greater or equal to a number of labels in the "
|
||||
"corresponding input subscript.");
|
||||
|
||||
for (size_t label_ind = 0, dim_ind = 0;
|
||||
label_ind < labels.size() && dim_ind < input_rank;
|
||||
++label_ind)
|
||||
{
|
||||
auto const& label = labels[label_ind];
|
||||
if (label.compare("...") == 0)
|
||||
{
|
||||
size_t num_broadcasted_dims = input_rank - labels.size() + 1;
|
||||
auto current_sub_pshape = PartialShape(std::vector<Dimension>(
|
||||
pshape.begin() + dim_ind, pshape.begin() + dim_ind + num_broadcasted_dims));
|
||||
if (label_to_shape.find(label) == label_to_shape.end())
|
||||
{
|
||||
label_to_shape[label] = current_sub_pshape;
|
||||
}
|
||||
else
|
||||
{
|
||||
bool is_broadcast_success =
|
||||
PartialShape::broadcast_merge_into(label_to_shape[label],
|
||||
current_sub_pshape,
|
||||
op::AutoBroadcastType::NUMPY);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
is_broadcast_success,
|
||||
"Input dimensions labeled with ellipsis for Einsum "
|
||||
"must be broadcastable.");
|
||||
}
|
||||
dim_ind += num_broadcasted_dims;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (label_to_shape.find(label) == label_to_shape.end())
|
||||
{
|
||||
label_to_shape[label] = PartialShape{pshape[dim_ind]};
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
label_to_shape[label].compatible(PartialShape{pshape[label_ind]}),
|
||||
"Different input dimensions indicated by the same labels for Einsum "
|
||||
"must be compatible.");
|
||||
PartialShape::merge_into(label_to_shape[label],
|
||||
PartialShape{pshape[dim_ind]});
|
||||
}
|
||||
++dim_ind;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto const& label : labels)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
label != "...",
|
||||
"The subscript corresponding to a dynamic rank input must "
|
||||
"not contain ellipsis.");
|
||||
|
||||
if (label_to_shape.find(label) == label_to_shape.end())
|
||||
{
|
||||
label_to_shape[label] = PartialShape{Dimension::dynamic()};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// compute the output shape
|
||||
std::vector<std::string> output_labels;
|
||||
output_labels = extract_labels(output_subscript);
|
||||
std::vector<Dimension> output_pshape_vector;
|
||||
|
||||
for (auto const& output_label : output_labels)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
label_to_shape.find(output_label) != label_to_shape.end(),
|
||||
"Label in output subscript of Einsum equation must enter at least "
|
||||
"one input subscript.");
|
||||
output_pshape_vector.insert(output_pshape_vector.end(),
|
||||
label_to_shape[output_label].begin(),
|
||||
label_to_shape[output_label].end());
|
||||
}
|
||||
set_output_type(0, input_type_0, PartialShape(output_pshape_vector));
|
||||
}
|
||||
|
||||
bool op::v7::Einsum::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Einsum_visit_attributes);
|
||||
visitor.on_attribute("equation", m_equation);
|
||||
return true;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v7::Einsum::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
NGRAPH_OP_SCOPE(v7_Einsum_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v7::Einsum>(new_args, m_equation);
|
||||
}
|
@ -115,6 +115,7 @@ set(SRC
|
||||
type_prop/depth_to_space.cpp
|
||||
type_prop/dft.cpp
|
||||
type_prop/dyn_reshape.cpp
|
||||
type_prop/einsum.cpp
|
||||
type_prop/experimental_detectron_generate_proposals.cpp
|
||||
type_prop/experimental_detectron_roi_feature_extractor.cpp
|
||||
type_prop/experimental_detectron_topkrois.cpp
|
||||
|
349
ngraph/test/type_prop/einsum.cpp
Normal file
349
ngraph/test/type_prop/einsum.cpp
Normal file
@ -0,0 +1,349 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, einsum_staticshape_dotproduct)
|
||||
{
|
||||
std::string equation = "i,i->";
|
||||
Shape input1_shape{3};
|
||||
Shape input2_shape{3};
|
||||
Shape out_shape{};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_matmul)
|
||||
{
|
||||
std::string equation = "ab,bc->ac";
|
||||
Shape input1_shape{2, 3};
|
||||
Shape input2_shape{3, 4};
|
||||
Shape out_shape{2, 4};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_trace)
|
||||
{
|
||||
std::string equation = "kii->k";
|
||||
Shape input1_shape{2, 3, 3};
|
||||
Shape out_shape{2};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_diagextraction)
|
||||
{
|
||||
std::string equation = "kii->ki";
|
||||
Shape input1_shape{2, 3, 3};
|
||||
Shape out_shape{2, 3};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_transpose)
|
||||
{
|
||||
std::string equation = "ijk->kij";
|
||||
Shape input1_shape{1, 2, 3};
|
||||
Shape out_shape{3, 1, 2};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_multimatmul)
|
||||
{
|
||||
std::string equation = "ab,bcd,bc->ca";
|
||||
Shape input1_shape{2, 5};
|
||||
Shape input2_shape{5, 3, 6};
|
||||
Shape input3_shape{5, 3};
|
||||
Shape out_shape{3, 2};
|
||||
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
|
||||
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::i32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_ellipsis)
|
||||
{
|
||||
std::string equation = "a...->...";
|
||||
Shape input1_shape{5, 3};
|
||||
Shape out_shape{3};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_ellipsis2)
|
||||
{
|
||||
std::string equation = "a...,...->a...";
|
||||
Shape input1_shape{3, 5};
|
||||
Shape input2_shape{1};
|
||||
Shape out_shape{3, 5};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_staticshape_ellipsis3)
|
||||
{
|
||||
std::string equation = "a...b,b...->a...";
|
||||
Shape input1_shape{11, 1, 4, 3};
|
||||
Shape input2_shape{3, 11, 7, 1};
|
||||
Shape out_shape{11, 11, 7, 4};
|
||||
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::i32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_dynamicshape_dotproduct)
|
||||
{
|
||||
std::string equation = "a,ab->ab";
|
||||
const auto input1_shape = PartialShape{Dimension(2, 7)};
|
||||
const auto input2_shape = PartialShape{Dimension(3, 10), 3};
|
||||
const auto out_shape = PartialShape{Dimension(3, 7), 3};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_dynamicshape_diagextraction)
|
||||
{
|
||||
std::string equation = "xyzxy->xyz";
|
||||
const auto input1_shape = PartialShape{Dimension(2, 7), Dimension(1, 5), 4, Dimension(3, 5), 3};
|
||||
const auto out_shape = PartialShape{Dimension(3, 5), 3, 4};
|
||||
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::i32);
|
||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, DISABLED_einsum_dynamicshape_ellipsis1)
|
||||
{
|
||||
// TODO: fix bug #53518 - PartialShape::broadcast_merge_into or Dimension::broadcast_merge
|
||||
// to support broadcasting between Dimension(3, 5) and Dimension(1, 3)
|
||||
// for which the result must be Dimension(3, 5)
|
||||
std::string equation = "a...b,b...->a...";
|
||||
const auto input1_shape = PartialShape{11, 1, Dimension(3, 5), 3};
|
||||
const auto input2_shape = PartialShape{3, 11, 7, Dimension(1, 3)};
|
||||
const auto out_shape = PartialShape{11, 11, 7, Dimension(3, 5)};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_implicitmode_mixedcaseletters)
|
||||
{
|
||||
// the following equation is equivalent to "AbC->ACb"
|
||||
std::string equation = "AbC";
|
||||
const auto input1_shape = PartialShape{1, Dimension(2, 3), Dimension(4, 5)};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
const auto out_shape = PartialShape{1, Dimension(4, 5), Dimension(2, 3)};
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_implicitmode_mixedcaseletters2)
|
||||
{
|
||||
// the following equation is equivalent to "a...b,B...->...Bab"
|
||||
std::string equation = "a...b,B...";
|
||||
const auto input1_shape = PartialShape{Dimension(3, 5), 11, 1, 3};
|
||||
const auto input2_shape = PartialShape{Dimension(1, 3), 3, 1, 7};
|
||||
const auto out_shape = PartialShape{3, 11, 7, Dimension(1, 3), Dimension(3, 5), 3};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_dynamicrank_multimatmul)
|
||||
{
|
||||
std::string equation = "ab,bcd,bc->ca";
|
||||
Shape input1_shape{2, 5};
|
||||
PartialShape input2_shape = PartialShape::dynamic();
|
||||
Shape input3_shape{5, 3};
|
||||
Shape out_shape{3, 2};
|
||||
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
|
||||
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::i32);
|
||||
ASSERT_EQ(O->get_shape(), out_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_dynamicrank_multimatmul2)
|
||||
{
|
||||
std::string equation = "ab,bcd,bc->ca";
|
||||
PartialShape input1_shape = PartialShape::dynamic();
|
||||
PartialShape input2_shape = PartialShape::dynamic();
|
||||
PartialShape input3_shape = PartialShape::dynamic();
|
||||
PartialShape out_shape{Dimension(), Dimension()};
|
||||
auto I1 = make_shared<op::Parameter>(element::i32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::i32, input2_shape);
|
||||
auto I3 = make_shared<op::Parameter>(element::i32, input3_shape);
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2, I3}, equation);
|
||||
ASSERT_EQ(O->get_element_type(), element::i32);
|
||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_incorrectequation_subscriptnumber)
|
||||
{
|
||||
std::string equation = "ab,bc,cd->ac";
|
||||
Shape input1_shape{2, 3};
|
||||
Shape input2_shape{3, 4};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect number of input subscripts";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Equation must contain a number of subscripts equal to a "
|
||||
"number of Einsum inputs."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Equation format check failed";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_incorrectequation_invalidlabels)
|
||||
{
|
||||
std::string equation = "a$,Bc->ac";
|
||||
Shape input1_shape{2, 3};
|
||||
Shape input2_shape{3, 4};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incorrect number of input subscripts";
|
||||
}
|
||||
catch (const CheckFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Input subscript of Einsum equation must consist of either only alphabetic "
|
||||
"letters or alphabetic letters with one ellipsis."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Equation format check failed";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_incorrectequation_incompatibleshapes)
|
||||
{
|
||||
std::string equation = "ab,bc->ac";
|
||||
Shape input1_shape{2, 10};
|
||||
Shape input2_shape{3, 4};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Incompatible dimension indicated by the same labels";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Different input dimensions indicated by the same labels "
|
||||
"for Einsum must be compatible."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Equation format check failed";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_incorrectequation_notbroadcastableshapes)
|
||||
{
|
||||
std::string equation = "a...b,b...->a...";
|
||||
Shape input1_shape{11, 1, 4, 3};
|
||||
Shape input2_shape{3, 11, 7, 5};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Non-broadcastable shapes covered by ellipsis";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string(
|
||||
"Input dimensions labeled with ellipsis for Einsum must be broadcastable."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Equation format check failed";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, einsum_incorrectequation_missedellipsis)
|
||||
{
|
||||
std::string equation = "a...b,b...->a";
|
||||
Shape input1_shape{11, 1, 4, 3};
|
||||
Shape input2_shape{3, 11, 7, 5};
|
||||
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||
|
||||
try
|
||||
{
|
||||
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||
// Should have thrown, so fail if it didn't
|
||||
FAIL() << "Non-broadcastable shapes covered by ellipsis";
|
||||
}
|
||||
catch (const CheckFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Output subscript of Einsum equation must contain one "
|
||||
"ellipsis if ellipsis is met in any input subscript."));
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
FAIL() << "Equation format check failed";
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user