Enable MatMul 1D multiplication by transformation to GEMM (#4509)

* MatMul tests

* Tests update

* More tests

* IE test manifest cleanup

* Enable MatMul 1D multiplication by transformation to GEMM

* Add unsqueeze to transform tests

* Fix runtime info

* Update transformation condition

* Remove Xfail from python tests

* Add more tests for ignoring 1D transpose true

* Ignore transpose for 1D in MatMul transformations

* Resolve python api xfails

* Use ngraph::opset namespace

* Style apply

* Add MatMul single layer tests for 1D
Katarzyna Mitrus 2021-03-15 11:44:57 +01:00 committed by GitHub
parent 635ffc760a
commit 95a13e05d5
9 changed files with 545 additions and 55 deletions


@ -39,6 +39,11 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
auto shape_b = input_b.get_shape();
auto output_shape = matmul->get_shape();
// Transformation to FC is not supported for 1D second input
if (shape_b.size() == 1) {
return false;
}
/*
* The get_aligned_shapes function aligns two input shapes to have the same size and
* the same batch dimensions (last two dimensions are not comparable).
@ -54,7 +59,7 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
for (size_t i = 0, cnt = max_size - shape_b_aligned.size(); i < cnt; ++i)
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
if (matmul->get_transpose_a()) {
if (matmul->get_transpose_a() && shape_a.size() != 1) {
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
}
if (matmul->get_transpose_b()) {
@ -138,7 +143,7 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
}
// Input normalization
if (matmul->get_transpose_a()) {
if (matmul->get_transpose_a() && shape_a.size() != 1) {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
new_ops.push_back(fc_input_a.get_node_shared_ptr());
}
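
Taken together, the ConvertMatMulToFC changes above bail out when the second input is 1D and skip transposing a 1D first input. One way to see the shape mismatch with fully-connected semantics, sketched in numpy with illustrative shapes only (this is not the transformation code):

import numpy as np

x = np.ones((2, 3, 4), dtype=np.float32)   # activations
w = np.ones((4,), dtype=np.float32)        # 1D second input

# With a 1D second input, MatMul contracts the last axis of `x` and drops it,
# so there is no trailing output-channel dimension for an FC layer to produce.
print(np.matmul(x, w).shape)                  # (2, 3)
print(np.matmul(x, w.reshape(4, 1)).shape)    # (2, 3, 1) -- the 2D-weight case FC can express

# Transposing a 1D array is a no-op, which is why the transpose_a path is skipped
# when shape_a.size() == 1.
a = np.array([1., 2., 3.])
print(np.array_equal(a, a.T))                 # True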
@ -185,7 +190,33 @@ ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
auto fc_input_a = input_a, fc_input_b = input_b;
NodeVector new_ops;
if (shape_a.size() == 1) {
// If the first input is a 1D tensor, it is unsqueezed to a 2D tensor (row vector)
// by adding an axis of size 1 at ROW_INDEX_DIM, to the left of the shape.
// For example {S} will be reshaped to {1, S}.
fc_input_a = std::make_shared<ngraph::opset1::Unsqueeze>(fc_input_a,
ngraph::opset1::Constant::create(element::i64, Shape{1}, {0}));
shape_a = fc_input_a.get_shape();
new_ops.push_back(fc_input_a.get_node_shared_ptr());
// For 1D inputs transpose flag is expected to always act like `false`
matmul->set_transpose_a(false);
}
if (shape_b.size() == 1) {
// If the second input is a 1D tensor, it is unsqueezed to a 2D tensor (column vector)
// by adding an axis of size 1 at COL_INDEX_DIM, to the right of the shape.
// For example {S} will be reshaped to {S, 1}.
fc_input_b = std::make_shared<ngraph::opset1::Unsqueeze>(fc_input_b,
ngraph::opset1::Constant::create(element::i64, Shape{1}, {1}));
shape_b = fc_input_b.get_shape();
new_ops.push_back(fc_input_b.get_node_shared_ptr());
// For 1D inputs transpose flag is expected to always act like `false`
matmul->set_transpose_b(false);
}
// Workaround for IE: Gemm requires both inputs to have the same rank.
// If the ranks of the inputs are still different, the lower-rank tensor is
// unsqueezed from the left side of the shape by the necessary number of axes
// so that both shapes have the same rank.
if (shape_a.size() < shape_b.size()) {
// Reshape first input (fc_input_a)
Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
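
The additions above promote a 1D first input to a row vector and a 1D second input to a column vector, with the transpose flags treated as false, which mirrors numpy/ONNX MatMul semantics. An illustrative numpy check of that equivalence (a sketch, not the transformation itself):

import numpy as np

a = np.array([1., 2., 3.])                       # 1D first input, shape (3,)
m = np.arange(12., dtype=np.float32).reshape(3, 4)

row = a[np.newaxis, :]                           # Unsqueeze at ROW_INDEX_DIM -> shape (1, 3)
assert np.allclose(np.matmul(a, m), np.matmul(row, m)[0])

b = np.array([1., 2., 3., 4.])                   # 1D second input, shape (4,)
col = b[:, np.newaxis]                           # Unsqueeze at COL_INDEX_DIM -> shape (4, 1)
assert np.allclose(np.matmul(m, b), np.matmul(m, col)[:, 0])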
@ -194,17 +225,8 @@ ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
new_ops.push_back(fc_input_a.get_node_shared_ptr());
} else if (shape_b.size() < shape_a.size()) {
// Reshape second input (fc_input_b)
Shape reshape_shape;
if (shape_b.size() == 1) {
// In case if shape_b has only one dimension we reshape it to [...,1,X,1]
reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
reshape_shape.push_back(shape_b[0]); // add X dimension
reshape_shape.push_back(1); // add last 1 dimension
} else {
// In this case we reshape shape_b to [...,1,1,X]
reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
}
Shape reshape_shape(shape_a.size() - shape_b.size(), 1);
reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
new_ops.push_back(fc_input_b.get_node_shared_ptr());
}
@ -213,10 +235,18 @@ ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
new_ops.push_back(gemm);
if (gemm->get_shape() != output_shape) {
// This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
// and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
// This case is possible when one of the inputs has exactly 1 dimension (which is not supported by the GEMM operation),
// so to preserve the output shape we insert an additional reshape operation.
auto reshape_output = op::util::reshapeTo(gemm, output_shape);
std::shared_ptr<ngraph::Node> reshape_output;
if (output_shape.size() == 0) {
std::vector<int64_t> dim_indices(gemm->get_shape().size());
std::iota(dim_indices.begin(), dim_indices.end(), 0);
reshape_output = std::make_shared<ngraph::opset1::Squeeze>(gemm,
ngraph::opset1::Constant::create(element::i64, Shape{dim_indices.size()}, dim_indices));
} else {
reshape_output = op::util::reshapeTo(gemm, output_shape);
}
new_ops.push_back(reshape_output);
gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
reshape_output->set_friendly_name(matmul->get_friendly_name());
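
When both inputs were 1D, the GEMM result has shape {1, 1} while the original MatMul output is a scalar, so every dimension has to be squeezed away rather than reshaped. A numpy sketch of that restore step (illustrative only):

import numpy as np

a = np.array([1., 2., 3.])
b = np.array([1., 2., 3.])

gemm_out = np.matmul(a[np.newaxis, :], b[:, np.newaxis])   # shape (1, 1) after both unsqueezes
restored = np.squeeze(gemm_out, axis=(0, 1))               # shape () -- matches MatMul's scalar output
assert restored.shape == ()
assert np.isclose(restored, np.matmul(a, b))                # both equal 14.0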


@ -80,7 +80,9 @@ TEST(TransformationTests, ConvertMatMulTest2) {
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2});
auto reshape = ngraph::op::util::reshapeTo(input2, {1, 2, 1});
auto unsqueeze_input2 = std::make_shared<ngraph::opset1::Unsqueeze>(input2,
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}));
auto reshape = ngraph::op::util::reshapeTo(unsqueeze_input2, {1, 2, 1});
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, reshape, false, false);
auto reshape_output = ngraph::op::util::reshapeTo(matmul, {3, 1});
@ -111,7 +113,9 @@ TEST(TransformationTests, ConvertMatMulTest3) {
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2});
auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});
auto reshape = ngraph::op::util::reshapeTo(input1, {1, 1, 2});
auto unsqueeze_input1 = std::make_shared<ngraph::opset1::Unsqueeze>(input1,
ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
auto reshape = ngraph::op::util::reshapeTo(unsqueeze_input1, {1, 1, 2});
auto matmul = std::make_shared<ngraph::opset1::MatMul>(reshape, input2, false, false);
auto reshape_output = ngraph::op::util::reshapeTo(matmul, {3, 1});


@ -17,7 +17,11 @@ const std::vector<InferenceEngine::Precision> inputPrecisions = {
const std::vector<ShapeRelatedParams> shapeRelatedParams = {
{ { {1, 4, 5, 6}, false }, { {1, 4, 6, 4}, false } },
{ { {4, 5, 6}, false }, { {6, 3}, false } },
{ { {9, 9, 9}, false }, { {9, 9}, false } }
{ { {9, 9, 9}, false }, { {9, 9}, false } },
{ { {1, 5}, false }, { {5}, false } },
{ { {5}, false }, { {5, 1}, false } },
{ { {5}, false }, { {5}, false } },
{ { {5}, true }, { {5}, true } }
};
std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {


@ -53,6 +53,8 @@ namespace ngraph
bool get_transpose_a() const { return m_transpose_a; }
bool get_transpose_b() const { return m_transpose_b; }
void set_transpose_a(bool transpose_a) { m_transpose_a = transpose_a; }
void set_transpose_b(bool transpose_b) { m_transpose_b = transpose_b; }
private:
bool m_transpose_a;
bool m_transpose_b;


@ -67,12 +67,6 @@ xfail_issue_35911 = xfail_test(reason="Assertion error: Pad model mismatch error
xfail_issue_35912 = xfail_test(reason="RuntimeError: Error of validate layer: B with type: "
"Pad. Cannot parse parameter pads_end from IR for layer B. "
"Value -1,0 cannot be casted to int.")
xfail_issue_35916 = xfail_test(reason="RuntimeError: Unsupported input dims count for layer Z")
xfail_issue_35917 = xfail_test(reason="RuntimeError: Unsupported input dims count for "
"layer MatMul")
xfail_issue_35918 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError: "
"Mismatched attribute type in 'test_node : alpha'")
xfail_issue_35921 = xfail_test(reason="ValueError - shapes mismatch in gemm")
xfail_issue_35923 = xfail_test(reason="RuntimeError: PReLU without weights is not supported")
xfail_issue_35925 = xfail_test(reason="Assertion error - reduction ops results mismatch")
xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that is not allowable")


@ -20,7 +20,6 @@ import pytest
from tests.runtime import get_runtime
from tests.test_onnx.utils import import_onnx_model
from tests import xfail_issue_35916, xfail_issue_35917, xfail_issue_35918, xfail_issue_35921
def make_onnx_model_for_matmul_op(input_left, input_right):
@ -104,7 +103,7 @@ def import_and_compute_gemm(input_a, input_b, input_c, **kwargs):
@pytest.mark.parametrize(
"data, description",
[
pytest.param(([1, 2], [1, 3]), "vector and vector 1", marks=xfail_issue_35916),
pytest.param(([1, 2], [1, 3]), "vector and vector 1"),
(([1, 2, 3], [[4], [5], [6]]), "vector and vector 2"),
(([[1, 2, 3]], [1, 2, 3]), "vector and vector 3"),
(([1, 2, 3], [[4, 5], [6, 7], [8, 9]]), "vector and matrix"),
@ -115,7 +114,7 @@ def import_and_compute_gemm(input_a, input_b, input_c, **kwargs):
],
)
def test_op_matmul(data, description):
assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
assert np.allclose(import_and_compute_matmul(*data), np.matmul(*data))
def test_op_matmul_3d():
@ -130,40 +129,39 @@ def test_op_matmul_3d():
@pytest.mark.parametrize(
"data, kwargs, description",
[
pytest.param(([1, 2], [1, 3], [1, 4]), {}, "vectors", marks=xfail_issue_35917),
pytest.param(([1, 2], [1, 3], 1), {}, "vectors and scalar", marks=xfail_issue_35917),
pytest.param(([1, 2], [1, 3], [1]), {}, "vectors and identity vector", marks=xfail_issue_35917),
pytest.param(([1, 2], [1, 3], [1, 4]), {"alpha": 7, "beta": 9},
"vectors with alpha and beta", marks=xfail_issue_35918),
pytest.param(([1, 2, 3, 4], [1, 3, 5, 7], [1, 4]), {"alpha": 7, "beta": 9},
"longer vectors with alpha and beta", marks=xfail_issue_35918)
pytest.param(([1, 2], [1, 3], [1, 4]), {}, "vectors"),
pytest.param(([1, 2], [1, 3], 1), {}, "vectors and scalar"),
pytest.param(([1, 2], [1, 3], [1]), {}, "vectors and identity vector"),
pytest.param(([1, 2], [1, 3], [1, 4]), {"alpha": 7.0, "beta": 9.0},
"vectors with alpha and beta"),
pytest.param(([1, 2, 3, 4], [1, 3, 5, 7], [1, 4]), {"alpha": 7.0, "beta": 9.0},
"longer vectors with alpha and beta")
],
)
def test_gemm(data, kwargs, description):
assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data))
assert np.allclose(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
@pytest.mark.parametrize(
"data, kwargs, description",
[
pytest.param(([1, 2], [1, 3], [1, 4]), {"trans_a": True, "trans_b": True},
"vectors with trans_a/trans_b", marks=xfail_issue_35917),
"vectors with trans_a/trans_b"),
pytest.param(([[1, 2], [1, 2]], [[1, 3], [1, 3]], [4, 1]),
{"trans_a": True, "trans_b": True, "alpha": 7, "beta": 9},
"matrices and vector with trans_b and alpha/beta", marks=xfail_issue_35918),
pytest.param(([[1, 2]], [[1, 3]], 1), {"trans_b": True, "alpha": 7, "beta": 9},
"matrices and scalar with trans_b and alpha/beta", marks=xfail_issue_35918),
pytest.param(([[1], [2]], [[1], [3]], 1), {"trans_a": True, "alpha": 7, "beta": 9},
"matrices and scalar with trans_a and alpha/beta", marks=xfail_issue_35918),
{"trans_a": True, "trans_b": True, "alpha": 7.0, "beta": 9.0},
"matrices and vector with trans_b and alpha/beta"),
pytest.param(([[1, 2]], [[1, 3]], 1), {"trans_b": True, "alpha": 7.0, "beta": 9.0},
"matrices and scalar with trans_b and alpha/beta"),
pytest.param(([[1], [2]], [[1], [3]], 1), {"trans_a": True, "alpha": 7.0, "beta": 9.0},
"matrices and scalar with trans_a and alpha/beta"),
],
)
def test_gemm_transpositions(data, kwargs, description):
assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
@xfail_issue_35921
def test_gemm_flatten():
# input_a.shape is (4,1,1)
data = ([[[1]], [[2]], [[3]], [[4]]], [1, 3, 5, 7], [1, 4])
kwargs = {"alpha": 7, "beta": 9}
# input_a.shape is (4,1)
data = ([[1], [2], [3], [4]], [1, 3, 5, 7], [1, 4])
kwargs = {"alpha": 7.0, "beta": 9.0, "trans_a": True}
assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))
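
For context, the cases above compare the runtime result against a numpy reference of ONNX Gemm. A minimal sketch of those semantics follows; reference_gemm is an assumed stand-in for the test suite's numpy_gemm helper, not its actual implementation:

import numpy as np

def reference_gemm(a, b, c=0.0, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    # Assumed ONNX Gemm reference: Y = alpha * A' @ B' + beta * C
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    if trans_a:
        a = a.T          # no-op for 1D vectors
    if trans_b:
        b = b.T
    return alpha * np.matmul(a, b) + beta * np.asarray(c, dtype=np.float32)

# The previously xfailed 1D case: transpose flags are no-ops on vectors.
print(reference_gemm([1, 2], [1, 3], [1, 4], trans_a=True, trans_b=True))   # [ 8. 11.]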


@ -18,7 +18,6 @@ import onnx
import pytest
from tests.test_onnx.utils import run_node
from tests import xfail_issue_35918
def import_and_compute(op_type, input_data, **node_attrs):
@ -107,10 +106,10 @@ def test_selu():
@pytest.mark.parametrize(
"data, alpha_value",
[
pytest.param([-2, -1.0, 0.0, 1.0, 2.0], 1, marks=xfail_issue_35918),
pytest.param([0.0], 1, marks=xfail_issue_35918),
pytest.param([-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1], 1, marks=xfail_issue_35918),
pytest.param([[1, 2, 3], [4, 5, 6]], 1, marks=xfail_issue_35918),
pytest.param([-2, -1.0, 0.0, 1.0, 2.0], 1.0),
pytest.param([0.0], 1.0),
pytest.param([-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1], 1.0),
pytest.param([[1, 2, 3], [4, 5, 6]], 1.0),
pytest.param([-2, -1.0, 0.0, 1.0, 2.0], 0.5)
]
)


@ -441,6 +441,111 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_1x2x3_1x4x3x2)
256.f}));
}
// 1D x 1D
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_false_false_param)
{
Shape shape_in1{3};
Shape shape_in2{3};
Shape shape_out{};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_true_true_param)
{
Shape shape_in1{3};
Shape shape_in2{3};
Shape shape_out{};
// For 1D inputs transpose is expected to be ignored
bool transpose_a = true;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_false_false_const)
{
Shape shape_in1{3};
Shape shape_in2{3};
Shape shape_out{};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_true_true_const)
{
Shape shape_in1{3};
Shape shape_in2{3};
Shape shape_out{};
// For 1D inputs transpose is expected to be ignored
bool transpose_a = true;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
// 2D x 1D
NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_param)
{
@ -468,6 +573,31 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_param)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_const)
{
Shape shape_in1{1, 3};
Shape shape_in2{3};
Shape shape_out{1};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_param)
{
Shape shape_in1{3, 1};
@ -494,6 +624,31 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_param)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_const)
{
Shape shape_in1{3, 1};
Shape shape_in2{3};
Shape shape_out{1};
bool transpose_a = true;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
// 1D x 2D
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_param)
{
@ -521,6 +676,31 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_param)
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_const)
{
Shape shape_in1{3};
Shape shape_in2{3, 1};
Shape shape_out{1};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_param)
{
Shape shape_in1{3};
@ -546,3 +726,285 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_param)
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_const)
{
Shape shape_in1{3};
Shape shape_in2{1, 3};
Shape shape_out{1};
bool transpose_a = false;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_true_true_param)
{
Shape shape_in1{3};
Shape shape_in2{1, 3};
Shape shape_out{1};
bool transpose_a = true;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_true_true_const)
{
Shape shape_in1{3};
Shape shape_in2{1, 3};
Shape shape_out{1};
bool transpose_a = true;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
// 3D x 1D
NGRAPH_TEST(${BACKEND_NAME}, matmul_1_1_3_x_3_false_false_param)
{
Shape shape_in1{1, 1, 3};
Shape shape_in2{3};
Shape shape_out{1, 1};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_1_1_3_x_3_false_false_const)
{
Shape shape_in1{1, 1, 3};
Shape shape_in2{3};
Shape shape_out{1, 1};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_1_x_3_true_false_param)
{
Shape shape_in1{1, 3, 1};
Shape shape_in2{3};
Shape shape_out{1, 1};
bool transpose_a = true;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_1_x_3_true_false_const)
{
Shape shape_in1{1, 3, 1};
Shape shape_in2{3};
Shape shape_out{1, 1};
bool transpose_a = true;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
// 1D x 3D
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_1_false_false_param)
{
Shape shape_in1{3};
Shape shape_in2{1, 3, 1};
Shape shape_out{1, 1};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_1_3_false_true_param)
{
Shape shape_in1{3};
Shape shape_in2{1, 1, 3};
Shape shape_out{1, 1};
bool transpose_a = false;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Parameter>(element::f32, shape_in2);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A, B});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_input<float>(inputs_b);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_1_false_false_const)
{
Shape shape_in1{3};
Shape shape_in2{1, 3, 1};
Shape shape_out{1, 1};
bool transpose_a = false;
bool transpose_b = false;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_1_3_false_true_const)
{
Shape shape_in1{3};
Shape shape_in2{1, 1, 3};
Shape shape_out{1, 1};
bool transpose_a = false;
bool transpose_b = true;
std::vector<float> inputs_a{1, 2, 3};
std::vector<float> inputs_b{1, 2, 3};
std::vector<float> expected_result{14.};
auto A = make_shared<op::Parameter>(element::f32, shape_in1);
auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
auto f = make_shared<Function>(matmul, ParameterVector{A});
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input<float>(inputs_a);
test_case.add_expected_output<float>(shape_out, expected_result);
test_case.run();
}
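
The expected value 14 and the output shapes in the new backend cases follow numpy's MatMul rules for 1D operands; a quick cross-check (illustrative only):

import numpy as np

a = np.array([1., 2., 3.], dtype=np.float32)
b = np.array([1., 2., 3.], dtype=np.float32)

print(float(np.matmul(a, b)))                    # 14.0, scalar -> shape_out {}
print(np.matmul(a.reshape(1, 3), b).shape)       # (1,)   -> 2D x 1D, shape_out {1}
print(np.matmul(a, b.reshape(3, 1)).shape)       # (1,)   -> 1D x 2D, shape_out {1}
print(np.matmul(a.reshape(1, 1, 3), b).shape)    # (1, 1) -> 3D x 1D, shape_out {1, 1}
print(np.matmul(a, b.reshape(1, 3, 1)).shape)    # (1, 1) -> 1D x 3D, shape_out {1, 1}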


@ -456,9 +456,6 @@ dynamic_reverse_shape
tile_3d_small_data_rank
tile_3d_few_repeats
# Error of validate layer: MatMul_683292 with type: Gemm. Gemm input shapes must have at least 2 dimensions
matmul_2_2
# Result mismatch
sum_large_1d_to_scalar
sum_stable_acc