Fix recovery of output subscript in Einsum implicit mode (#6131)
* Fix recovery of output subscript in Einsum implicit mode Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fix code style Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
ea3ed8af21
commit
3292543252
@ -28,6 +28,22 @@ class Einsum(Op):
|
|||||||
def backend_attrs(self):
|
def backend_attrs(self):
|
||||||
return ['equation']
|
return ['equation']
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_label_elsewhere(input_subscripts: list, label_to_check: str, excluded_subscript_inds: list) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the given label is met in input subscripts excluding ones specified by a list of indices
|
||||||
|
excluded_subscript_inds
|
||||||
|
|
||||||
|
:param input_subscripts: input subscripts among which to check if the label is met
|
||||||
|
:param label_to_check: a label to check
|
||||||
|
:param excluded_subscript_inds: indices of input subscripts to be excluded for this check
|
||||||
|
:return: True - met, False - otherwise
|
||||||
|
"""
|
||||||
|
for ind, input_subscript in enumerate(input_subscripts):
|
||||||
|
if ind not in excluded_subscript_inds and label_to_check in input_subscript:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_equation(node_name: str, equation: str) -> (list, str):
|
def parse_equation(node_name: str, equation: str) -> (list, str):
|
||||||
"""
|
"""
|
||||||
@ -70,7 +86,12 @@ class Einsum(Op):
|
|||||||
"The output subscript of Einsum node {} must contain ellipsis".format(node_name)
|
"The output subscript of Einsum node {} must contain ellipsis".format(node_name)
|
||||||
elif len(splitted_equation) == 1:
|
elif len(splitted_equation) == 1:
|
||||||
# recover output subscript in case implicit mode
|
# recover output subscript in case implicit mode
|
||||||
output_subscript = ''.join(input_subscripts_list)
|
output_subscript = ""
|
||||||
|
for ind, input_subscript in enumerate(input_subscripts_list):
|
||||||
|
labels = Einsum.extract_subscript_labels(node_name, input_subscript)
|
||||||
|
for label in labels:
|
||||||
|
if Einsum.is_label_elsewhere(input_subscripts_list, label, [ind]) is False:
|
||||||
|
output_subscript += label
|
||||||
output_subscript = ''.join(sorted(list(set(output_subscript) - {'.'})))
|
output_subscript = ''.join(sorted(list(set(output_subscript) - {'.'})))
|
||||||
if is_ellipsis_met:
|
if is_ellipsis_met:
|
||||||
output_subscript = "..." + output_subscript
|
output_subscript = "..." + output_subscript
|
||||||
|
@ -60,6 +60,11 @@ class TestEinsum(unittest.TestCase):
|
|||||||
([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])),
|
([int64_array([1, 3, 5])], "AbC", int64_array([1, 5, 3])),
|
||||||
# mixed case letters and equation in implicit mode
|
# mixed case letters and equation in implicit mode
|
||||||
([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])),
|
([int64_array([3, 11, 1, 5]), int64_array([1, 3, 1, 7])], "a...b,B...", int64_array([3, 11, 7, 1, 3, 5])),
|
||||||
|
# inner product in implicit mode
|
||||||
|
([int64_array([3]), int64_array([3])], "i,i", int64_array([])),
|
||||||
|
# equation with ellipsis and repeated labels in implicit mode
|
||||||
|
# "a...b,b..." is equivalent to "a...b,b...->...a"
|
||||||
|
([int64_array([9, 1, 4, 3]), int64_array([3, 11, 7, 1])], "a...b,b...", int64_array([11, 7, 4, 9])),
|
||||||
])
|
])
|
||||||
def test_einsum(self, input_shapes, equation, ref_output_shape):
|
def test_einsum(self, input_shapes, equation, ref_output_shape):
|
||||||
graph = create_einsum_graph(input_shapes, equation)
|
graph = create_einsum_graph(input_shapes, equation)
|
||||||
|
@ -60,11 +60,40 @@ bool is_subscript_correct(const std::string& subscript, bool& is_ellipsis_met)
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// \brief Check if the given label is met in input subscripts excluding ones
|
||||||
|
/// specified by a vector excluded_indices
|
||||||
|
///
|
||||||
|
/// \param input_subscripts The vector of the input subscripts
|
||||||
|
/// \param label_to_check A label to check
|
||||||
|
/// \param excluded_indices A vector of input subscript indices to be excluded
|
||||||
|
///
|
||||||
|
/// \return true - met, false - otherwise
|
||||||
|
///
|
||||||
|
bool is_label_elsewhere(const std::vector<std::string>& input_subscripts,
|
||||||
|
const std::string& label_to_check,
|
||||||
|
const std::vector<size_t>& excluded_indices)
|
||||||
|
{
|
||||||
|
for (size_t input_ind = 0; input_ind < input_subscripts.size(); ++input_ind)
|
||||||
|
{
|
||||||
|
const auto& input_subscript = input_subscripts[input_ind];
|
||||||
|
// the subscript is checked only if its index is not in excluded indices list
|
||||||
|
bool check_subscript =
|
||||||
|
(std::find(excluded_indices.begin(), excluded_indices.end(), input_ind) ==
|
||||||
|
excluded_indices.end());
|
||||||
|
if (check_subscript && input_subscript.find(label_to_check) != std::string::npos)
|
||||||
|
{
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
void op::v7::Einsum::parse_equation(const std::string& equation,
|
void op::v7::Einsum::parse_equation(const std::string& equation,
|
||||||
std::vector<std::string>& input_subscripts,
|
std::vector<std::string>& input_subscripts,
|
||||||
std::string& output_subscript)
|
std::string& output_subscript)
|
||||||
{
|
{
|
||||||
NGRAPH_OP_SCOPE(v7_Einsum_parse_equation);
|
NGRAPH_OP_SCOPE(v7_Einsum_parse_equation);
|
||||||
|
constexpr char ellipsis[] = "...";
|
||||||
|
|
||||||
// split equation to input subscripts and an output subscript
|
// split equation to input subscripts and an output subscript
|
||||||
auto pos_output_delimeter = equation.find("->");
|
auto pos_output_delimeter = equation.find("->");
|
||||||
@ -93,13 +122,15 @@ void op::v7::Einsum::parse_equation(const std::string& equation,
|
|||||||
|
|
||||||
if (pos_output_delimeter == std::string::npos)
|
if (pos_output_delimeter == std::string::npos)
|
||||||
{
|
{
|
||||||
// recover output subscript
|
// equation is in implicit mode so recover output subscript
|
||||||
output_subscript = "";
|
output_subscript = "";
|
||||||
for (auto const& input_subscript : input_subscripts)
|
for (size_t ind = 0; ind < input_subscripts.size(); ++ind)
|
||||||
{
|
{
|
||||||
for (auto const& label : input_subscript)
|
auto const& input_subscript = input_subscripts[ind];
|
||||||
|
for (auto const& label : extract_labels(input_subscript))
|
||||||
{
|
{
|
||||||
if (std::isalpha(label) && output_subscript.find(label) == std::string::npos)
|
if (label != ellipsis &&
|
||||||
|
(is_label_elsewhere(input_subscripts, label, {ind}) == false))
|
||||||
{
|
{
|
||||||
output_subscript += label;
|
output_subscript += label;
|
||||||
}
|
}
|
||||||
|
@ -186,6 +186,34 @@ TEST(type_prop, einsum_implicitmode_mixedcaseletters2)
|
|||||||
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, einsum_implicitmode_repeatedlabels)
|
||||||
|
{
|
||||||
|
// the following equation is equivalent to "a...b,b...->...a"
|
||||||
|
std::string equation = "a...b,b...";
|
||||||
|
const auto input1_shape = PartialShape{Dimension(3, 5), 11, 1, 3};
|
||||||
|
const auto input2_shape = PartialShape{Dimension(1, 3), 3, 1, 7};
|
||||||
|
const auto out_shape = PartialShape{3, 11, 7, Dimension(3, 5)};
|
||||||
|
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||||
|
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||||
|
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||||
|
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||||
|
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(type_prop, einsum_implicitmode_innerprod)
|
||||||
|
{
|
||||||
|
// the following equation is equivalent to "i,i->"
|
||||||
|
std::string equation = "i,i";
|
||||||
|
const auto input1_shape = PartialShape{11};
|
||||||
|
const auto input2_shape = PartialShape{Dimension(1, 20)};
|
||||||
|
const auto out_shape = PartialShape{};
|
||||||
|
auto I1 = make_shared<op::Parameter>(element::f32, input1_shape);
|
||||||
|
auto I2 = make_shared<op::Parameter>(element::f32, input2_shape);
|
||||||
|
auto O = make_shared<op::v7::Einsum>(OutputVector{I1, I2}, equation);
|
||||||
|
ASSERT_EQ(O->get_element_type(), element::f32);
|
||||||
|
ASSERT_TRUE(O->get_output_partial_shape(0).same_scheme(out_shape));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(type_prop, einsum_dynamicrank_multimatmul)
|
TEST(type_prop, einsum_dynamicrank_multimatmul)
|
||||||
{
|
{
|
||||||
std::string equation = "ab,bcd,bc->ca";
|
std::string equation = "ab,bcd,bc->ca";
|
||||||
|
Loading…
Reference in New Issue
Block a user