Fix recovery of output subscript in Einsum implicit mode (#6131)

* Fix recovery of output subscript in Einsum implicit mode

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>

* Fix code style

Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Roman Kazantsev 2021-06-16 06:43:12 +03:00 committed by GitHub
parent ea3ed8af21
commit 3292543252
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 5 deletions

View File

@ -28,6 +28,22 @@ class Einsum(Op):
def backend_attrs(self): def backend_attrs(self):
return ['equation'] return ['equation']
@staticmethod
def is_label_elsewhere(input_subscripts: list, label_to_check: str, excluded_subscript_inds: list) -> bool:
    """
    Tell whether the label occurs in at least one input subscript whose index is not excluded

    :param input_subscripts: input subscripts to look through
    :param label_to_check: the label searched for inside each considered subscript
    :param excluded_subscript_inds: indices of the input subscripts that are skipped by the search
    :return: True if a non-excluded subscript contains the label, False otherwise
    """
    return any(
        label_to_check in subscript
        for position, subscript in enumerate(input_subscripts)
        if position not in excluded_subscript_inds
    )
@staticmethod @staticmethod
def parse_equation(node_name: str, equation: str) -> (list, str): def parse_equation(node_name: str, equation: str) -> (list, str):
""" """
@ -70,7 +86,12 @@ class Einsum(Op):
"The output subscript of Einsum node {} must contain ellipsis".format(node_name) "The output subscript of Einsum node {} must contain ellipsis".format(node_name)
elif len(splitted_equation) == 1: elif len(splitted_equation) == 1:
# recover output subscript in case implicit mode # recover output subscript in case implicit mode
output_subscript = ''.join(input_subscripts_list) output_subscript = ""
for ind, input_subscript in enumerate(input_subscripts_list):
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
for label in labels:
if Einsum.is_label_elsewhere(input_subscripts_list, label, [ind]) is False:
output_subscript += label
output_subscript = ''.join(sorted(list(set(output_subscript) - {'.'}))) output_subscript = ''.join(sorted(list(set(output_subscript) - {'.'})))
if is_ellipsis_met: if is_ellipsis_met:
output_subscript = "..." + output_subscript output_subscript = "..." + output_subscript

View File

@ -60,6 +60,11 @@ class TestEinsum(unittest.TestCase):
([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])), ([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])),
# mixed case letters and equation in implicit mode # mixed case letters and equation in implicit mode
([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])), ([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])),
# inner product in implicit mode
([int64_array([3]), int64_array([3])], "i,i", int64_array([])),
# equation with ellipsis and repeated labels in implicit mode
# "a...b,b..." is equivalent to "a...b,b...->...a"
([int64_array([9, 1, 4, 3]), int64_array([3, 11, 7, 1])], "a...b,b...", int64_array([11, 7, 4, 9])),
]) ])
def test_einsum(self, input_shapes, equation, ref_output_shape): def test_einsum(self, input_shapes, equation, ref_output_shape):
graph = create_einsum_graph(input_shapes, equation) graph = create_einsum_graph(input_shapes, equation)

View File

@ -60,11 +60,40 @@ bool is_subscript_correct(const std::string& subscript, bool& is_ellipsis_met)
return true; return true;
} }
/// \brief      Tell whether a label occurs in at least one input subscript
///             whose index is not listed in excluded_indices
///
/// \param      input_subscripts  The vector of the input subscripts to search
/// \param      label_to_check    The label looked up inside each considered subscript
/// \param      excluded_indices  Indices of the input subscripts skipped by the search
///
/// \return     true if some non-excluded subscript contains the label,
///             false otherwise
///
bool is_label_elsewhere(const std::vector<std::string>& input_subscripts,
                        const std::string& label_to_check,
                        const std::vector<size_t>& excluded_indices)
{
    // a subscript takes part in the search only when its index is not excluded
    const auto is_excluded = [&excluded_indices](size_t subscript_ind) {
        return std::find(excluded_indices.begin(), excluded_indices.end(), subscript_ind) !=
               excluded_indices.end();
    };

    for (size_t subscript_ind = 0; subscript_ind < input_subscripts.size(); ++subscript_ind)
    {
        if (is_excluded(subscript_ind))
        {
            continue;
        }
        if (input_subscripts[subscript_ind].find(label_to_check) != std::string::npos)
        {
            return true;
        }
    }
    return false;
}
void op::v7::Einsum::parse_equation(const std::string& equation, void op::v7::Einsum::parse_equation(const std::string& equation,
std::vector<std::string>& input_subscripts, std::vector<std::string>& input_subscripts,
std::string& output_subscript) std::string& output_subscript)
{ {
NGRAPH_OP_SCOPE(v7_Einsum_parse_equation); NGRAPH_OP_SCOPE(v7_Einsum_parse_equation);
constexpr char ellipsis[] = "...";
// split equation to input subscripts and an output subscript // split equation to input subscripts and an output subscript
auto pos_output_delimeter = equation.find("->"); auto pos_output_delimeter = equation.find("->");
@ -93,13 +122,15 @@ void op::v7::Einsum::parse_equation(const std::string& equation,
if (pos_output_delimeter == std::string::npos) if (pos_output_delimeter == std::string::npos)
{ {
// recover output subscript // equation is in implicit mode so recover output subscript
output_subscript = ""; output_subscript = "";
for (auto const& input_subscript : input_subscripts) for (size_t ind = 0; ind < input_subscripts.size(); ++ind)
{ {
for (auto const& label : input_subscript) auto const& input_subscript = input_subscripts[ind];
for (auto const& label : extract_labels(input_subscript))
{ {
if (std::isalpha(label) && output_subscript.find(label) == std::string::npos) if (label != ellipsis &&
(is_label_elsewhere(input_subscripts, label, {ind}) == false))
{ {
output_subscript += label; output_subscript += label;
} }

View File

@ -186,6 +186,34 @@ TEST(type_prop, einsum_implicitmode_mixedcaseletters2)
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape)); ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
} }
TEST(type_prop, einsum_implicitmode_repeatedlabels)
{
    // implicit-mode equation in which label "b" is present in both inputs;
    // it is equivalent to the explicit form "a...b,b...->...a"
    const std::string equation = "a...b,b...";

    const PartialShape lhs_shape{Dimension(3, 5), 11, 1, 3};
    const PartialShape rhs_shape{Dimension(1, 3), 3, 1, 7};
    const PartialShape expected_shape{3, 11, 7, Dimension(3, 5)};

    auto lhs = make_shared<op::Parameter>(element::f32, lhs_shape);
    auto rhs = make_shared<op::Parameter>(element::f32, rhs_shape);
    auto einsum = make_shared<op::v7::Einsum>(OutputVector{lhs, rhs}, equation);

    // both the element type and the inferred output shape are verified
    ASSERT_EQ(einsum->get_element_type(), element::f32);
    ASSERT_TRUE(einsum->get_output_partial_shape(0).same_scheme(expected_shape));
}
TEST(type_prop, einsum_implicitmode_innerprod)
{
    // implicit-mode inner product; it is equivalent to the explicit form "i,i->"
    const std::string equation = "i,i";

    const PartialShape lhs_shape{11};
    const PartialShape rhs_shape{Dimension(1, 20)};
    // the expected output shape is a PartialShape built from an empty list
    const PartialShape expected_shape{};

    auto lhs = make_shared<op::Parameter>(element::f32, lhs_shape);
    auto rhs = make_shared<op::Parameter>(element::f32, rhs_shape);
    auto einsum = make_shared<op::v7::Einsum>(OutputVector{lhs, rhs}, equation);

    // both the element type and the inferred output shape are verified
    ASSERT_EQ(einsum->get_element_type(), element::f32);
    ASSERT_TRUE(einsum->get_output_partial_shape(0).same_scheme(expected_shape));
}
TEST(type_prop, einsum_dynamicrank_multimatmul) TEST(type_prop, einsum_dynamicrank_multimatmul)
{ {
std::string equation = "ab,bcd,bc->ca"; std::string equation = "ab,bcd,bc->ca";