[PT FE]: support aten::einsum (#15844)
This commit is contained in:
parent
a9efe5bd8d
commit
288a750bc6
@ -19,6 +19,7 @@
|
||||
#include "transforms/aten_cat_replacer.hpp"
|
||||
#include "transforms/aten_getitem_replacer.hpp"
|
||||
#include "transforms/aten_stack_list_construct_replacer.hpp"
|
||||
#include "transforms/einsum_list_construct.hpp"
|
||||
#include "transforms/listconstruct_replacer.hpp"
|
||||
#include "transforms/min_max_prim_list_construct_replacer.hpp"
|
||||
#include "transforms/prim_list_construct_pad.hpp"
|
||||
@ -97,6 +98,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenGetItemReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::ListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimListConstructPadReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::DecomposeListTupleResults>();
|
||||
manager.register_pass<ov::pass::RemoveMultiSubGraphOpDanglingParams>();
|
||||
|
@ -0,0 +1,68 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "einsum_list_construct.hpp"
|
||||
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/op/einsum.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov::pass::pattern;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
using namespace ov::pass;
|
||||
using namespace ov::op;
|
||||
|
||||
AtenEinsumListConstructReplacer::AtenEinsumListConstructReplacer() {
|
||||
auto einsum_op = pattern::wrap_type<ov::op::util::FrameworkNode>();
|
||||
ov::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto einsum_op = cast_fw_node(m.get_match_root(), "aten::einsum");
|
||||
if (!einsum_op) {
|
||||
return false;
|
||||
}
|
||||
auto equation_input = einsum_op->input_value(0).get_node_shared_ptr();
|
||||
auto tensor_list = einsum_op->input_value(1).get_node_shared_ptr();
|
||||
std::string equation;
|
||||
// equation should be string constant
|
||||
if (const auto& fw_node_mode = cast_fw_node(equation_input, "prim::Constant")) {
|
||||
const auto& attrs = fw_node_mode->get_attrs();
|
||||
if (attrs.find("string_value") != attrs.end()) {
|
||||
equation = attrs.at("string_value");
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
// Check if ListConstruct is an input
|
||||
if (auto list_construct_node = cast_fw_node(tensor_list, "prim::ListConstruct")) {
|
||||
const auto& list_inputs = list_construct_node->input_values();
|
||||
OutputVector node_vector;
|
||||
// Iterate over values in ListConstruct
|
||||
for (const auto& list_input : list_inputs) {
|
||||
node_vector.push_back(list_input);
|
||||
}
|
||||
|
||||
auto einsum = std::make_shared<v7::Einsum>(node_vector, equation);
|
||||
copy_runtime_info({einsum_op, equation_input, tensor_list}, einsum);
|
||||
replace_node(einsum_op, einsum);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m =
|
||||
std::make_shared<pattern::Matcher>(einsum_op, "ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer");
|
||||
this->register_matcher(m, callback);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
class AtenEinsumListConstructReplacer : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer");
|
||||
AtenEinsumListConstructReplacer();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
103
tests/layer_tests/pytorch_tests/test_einsum.py
Normal file
103
tests/layer_tests/pytorch_tests/test_einsum.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestEinsumBatchMatMul(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
|
||||
return (np.random.randn(5, 2, 3).astype(np.float32), np.random.randn(5, 3, 4).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class EinsumModelBatchMatmul(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
eqn = "bij, bjk -> bik"
|
||||
return torch.einsum(eqn, x, y)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return EinsumModelBatchMatmul(), ref_net, "aten::einsum"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_einsum_batch_matmul(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestEinsumBatchDiagonal(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
|
||||
return (np.random.randn(3, 5, 5).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class EinsumModelBatchDiagonal(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
eqn = "kii -> ki"
|
||||
return torch.einsum(eqn, x)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return EinsumModelBatchDiagonal(), ref_net, "aten::einsum"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.xfail(reason='OpenVINO CPU plugin does not support einsum diagonal')
|
||||
def test_einsum_batch_diagonal(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, dynamic_shapes=False)
|
||||
|
||||
|
||||
class TestEinsumInnerProd(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
|
||||
return (np.random.randn(5).astype(np.float32), np.random.randn(5).astype(np.float32))
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class EinsumModelInnerProd(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
eqn = "i,i"
|
||||
return torch.einsum(eqn, x, y)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return EinsumModelInnerProd(), ref_net, "aten::einsum"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_einsum_inner_prod(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestEinsumTranspose(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
|
||||
return (np.random.randn(3, 5).astype(np.float32),)
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
|
||||
class EinsumModelTranspose(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
eqn = "ij->ji"
|
||||
return torch.einsum(eqn, x)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return EinsumModelTranspose(), ref_net, "aten::einsum"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_einsum_transpose(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user