Enable MatMul 1D multiplication by transformation to GEMM (#4509)
* MatMul tests
* Tests update
* More tests
* IE test manifest cleanup
* Enable MatMul 1D multiplication by transformation to GEMM
* Add unsqueeze to transform tests
* Fix runtime info
* Update transformation condition
* Remove Xfail from python tests
* Add more tests for ignoring 1D transpose true
* Ignore transpose for 1D in MatMul transformations
* Resolve python api xfails
* Use ngraph::opset namespace
* Style apply
* Add MatMul single layer tests for 1D
This commit is contained in:
parent 635ffc760a
commit 95a13e05d5
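For context on the changes below: the transformations make MatMul with 1D operands follow NumPy-style semantics, where a 1D first input acts as a row vector, a 1D second input as a column vector, and the temporary axes are dropped from the result. A minimal NumPy sketch of those semantics (illustrative only, not part of the commit):

```python
import numpy as np

# NumPy matmul semantics for 1D operands, which the transformations below mirror:
a = np.array([1.0, 2.0, 3.0])            # shape (3,)
b = np.array([1.0, 2.0, 3.0])            # shape (3,)
print(np.matmul(a, b))                   # 14.0 -> scalar, both 1D axes dropped
print(np.matmul(a.reshape(1, 3), b))     # [14.] -> shape (1,), only b's axis dropped
print(np.matmul(a, b.reshape(3, 1)))     # [14.] -> shape (1,), only a's axis dropped
```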
@@ -39,6 +39,11 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
         auto shape_b = input_b.get_shape();
         auto output_shape = matmul->get_shape();

+        // Transformation to FC is not supported for 1D second input
+        if (shape_b.size() == 1) {
+            return false;
+        }
+
         /*
          * get_aligned_shapes function align two input shapes to have the same size and
          * the same batch dimensions (last two dimensions are not comparable).
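The early return above reflects that a 1D second input removes the output's channel axis entirely, which a FullyConnected layer (always producing a [batch, out] tensor) cannot express. A NumPy illustration (not part of the commit):

```python
import numpy as np

# {3, 4} x {4} -> {3}: the last axis disappears, so no FC layout fits.
print(np.matmul(np.ones((3, 4)), np.ones(4)).shape)  # (3,)
```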
@@ -54,7 +59,7 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
            for (size_t i = 0, cnt = max_size - shape_b_aligned.size(); i < cnt; ++i)
                shape_b_aligned.insert(shape_b_aligned.begin(), 1);

-           if (matmul->get_transpose_a()) {
+           if (matmul->get_transpose_a() && shape_a.size() != 1) {
                std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
            }
            if (matmul->get_transpose_b()) {
@@ -138,7 +143,7 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
        }

        // Input normalization
-       if (matmul->get_transpose_a()) {
+       if (matmul->get_transpose_a() && shape_a.size() != 1) {
            fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
            new_ops.push_back(fc_input_a.get_node_shared_ptr());
        }
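Both guards above skip the transpose handling when the input rank is 1, because transposing a 1D tensor is a no-op. In NumPy terms (illustrative):

```python
import numpy as np

v = np.array([1.0, 2.0, 3.0])
print(v.T.shape == v.shape)  # True: 1D transpose changes nothing,
                             # hence the added `shape_a.size() != 1` guard
```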
@@ -185,7 +190,33 @@ ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
        auto fc_input_a = input_a, fc_input_b = input_b;
        NodeVector new_ops;

+       if (shape_a.size() == 1) {
+           // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
+           // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
+           // For example {S} will be reshaped to {1, S}.
+           fc_input_a = std::make_shared<ngraph::opset1::Unsqueeze>(fc_input_a,
+               ngraph::opset1::Constant::create(element::i64, Shape{1}, {0}));
+           shape_a = fc_input_a.get_shape();
+           new_ops.push_back(fc_input_a.get_node_shared_ptr());
+           // For 1D inputs transpose flag is expected to always act like `false`
+           matmul->set_transpose_a(false);
+       }
+       if (shape_b.size() == 1) {
+           // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
+           // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
+           // For example {S} will be reshaped to {S, 1}.
+           fc_input_b = std::make_shared<ngraph::opset1::Unsqueeze>(fc_input_b,
+               ngraph::opset1::Constant::create(element::i64, Shape{1}, {1}));
+           shape_b = fc_input_b.get_shape();
+           new_ops.push_back(fc_input_b.get_node_shared_ptr());
+           // For 1D inputs transpose flag is expected to always act like `false`
+           matmul->set_transpose_b(false);
+       }
+
        // WA for IE that Gemm must have inputs with the same length.
+       // If ranks of input arguments are still different,
+       // the smaller tensor is unsqueezed from the left side of the shape
+       // by necessary number of axes to make both shapes of the same rank.
        if (shape_a.size() < shape_b.size()) {
            // Reshape first input (fc_input_a)
            Shape reshape_shape(shape_b.size() - shape_a.size(), 1);
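The normalization above, restated with NumPy (illustrative): the temporary axis is inserted on the left for the first input and on the right for the second.

```python
import numpy as np

a = np.array([1.0, 2.0, 3.0])        # {3}
b = np.array([1.0, 2.0, 3.0])        # {3}
a2 = np.expand_dims(a, axis=0)       # {1, 3}, row vector (axis 0, i.e. ROW_INDEX_DIM)
b2 = np.expand_dims(b, axis=1)       # {3, 1}, column vector (axis 1, i.e. COL_INDEX_DIM)
print(np.matmul(a2, b2))             # [[14.]] -- squeezed back to a scalar later
```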
@@ -194,17 +225,8 @@ ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
            new_ops.push_back(fc_input_a.get_node_shared_ptr());
        } else if (shape_b.size() < shape_a.size()) {
            // Reshape second input (fc_input_b)
-           Shape reshape_shape;
-           if (shape_b.size() == 1) {
-               // In case if shape_b has only one dimension we reshape it to [...,1,X,1]
-               reshape_shape = Shape(shape_a.size() - (shape_b.size() + 1), 1);
-               reshape_shape.push_back(shape_b[0]); // add X dimension
-               reshape_shape.push_back(1); // add last 1 dimension
-           } else {
-               // In this case we reshape shape_b to [...,1,1,X]
-               reshape_shape = Shape(shape_a.size() - shape_b.size(), 1);
-               reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
-           }
+           Shape reshape_shape(shape_a.size() - shape_b.size(), 1);
+           reshape_shape.insert(reshape_shape.end(), shape_b.begin(), shape_b.end());
            fc_input_b = op::util::reshapeTo(fc_input_b, reshape_shape);
            new_ops.push_back(fc_input_b.get_node_shared_ptr());
        }
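The rank alignment above prepends size-1 axes to the lower-rank shape; a compact check of the rule (illustrative):

```python
# Prepend ones until both ranks match, e.g. {3, 1} aligned against {5, 3, 1}:
shape_a, shape_b = (5, 3, 1), (3, 1)
reshape_shape = (1,) * (len(shape_a) - len(shape_b)) + shape_b
print(reshape_shape)  # (1, 3, 1)
```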
@@ -213,10 +235,18 @@ ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
        new_ops.push_back(gemm);

        if (gemm->get_shape() != output_shape) {
-           // This case is possible only when second input had exactly 1 dimension (that is not supported by GEMM operation)
-           // and for this case we have to reshape second input to first but this affects output shape (additional dimensions)
+           // This case is possible when one of the inputs has exactly 1 dimension (that is not supported by GEMM operation)
            // So to preserve output shape we insert additional reshape operation
-           auto reshape_output = op::util::reshapeTo(gemm, output_shape);
+           std::shared_ptr<ngraph::Node> reshape_output;
+           if (output_shape.size() == 0) {
+               std::vector<int64_t> dim_indices(gemm->get_shape().size());
+               std::iota(dim_indices.begin(), dim_indices.end(), 0);
+               reshape_output = std::make_shared<ngraph::opset1::Squeeze>(gemm,
+                   ngraph::opset1::Constant::create(element::i64, Shape{dim_indices.size()}, dim_indices));
+           } else {
+               reshape_output = op::util::reshapeTo(gemm, output_shape);
+           }
+
            new_ops.push_back(reshape_output);
            gemm->set_friendly_name(matmul->get_friendly_name() + "/gemm");
            reshape_output->set_friendly_name(matmul->get_friendly_name());

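The new Squeeze branch handles the rank-0 case: when both MatMul inputs were 1D the true output is a scalar, so every axis of the GEMM result must be dropped. Illustrative NumPy equivalent:

```python
import numpy as np

gemm_out = np.array([[14.0]])                    # {1, 1} produced by GEMM
print(np.squeeze(gemm_out, axis=(0, 1)).shape)   # () -- rank-0 output restored
```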
@@ -80,7 +80,9 @@ TEST(TransformationTests, ConvertMatMulTest2) {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
        auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2});

-       auto reshape = ngraph::op::util::reshapeTo(input2, {1, 2, 1});
+       auto usnqueeze_input2 = std::make_shared<ngraph::opset1::Unsqueeze>(input2,
+           ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1}));
+       auto reshape = ngraph::op::util::reshapeTo(usnqueeze_input2, {1, 2, 1});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, reshape, false, false);
        auto reshape_output = ngraph::op::util::reshapeTo(matmul, {3, 1});

@@ -111,7 +113,9 @@ TEST(TransformationTests, ConvertMatMulTest3) {
        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{2});
        auto input2 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 1});

-       auto reshape = ngraph::op::util::reshapeTo(input1, {1, 1, 2});
+       auto usnqueeze_input1 = std::make_shared<ngraph::opset1::Unsqueeze>(input1,
+           ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0}));
+       auto reshape = ngraph::op::util::reshapeTo(usnqueeze_input1, {1, 1, 2});
        auto matmul = std::make_shared<ngraph::opset1::MatMul>(reshape, input2, false, false);
        auto reshape_output = ngraph::op::util::reshapeTo(matmul, {3, 1});

@@ -17,7 +17,11 @@ const std::vector<InferenceEngine::Precision> inputPrecisions = {
 const std::vector<ShapeRelatedParams> shapeRelatedParams = {
        { { {1, 4, 5, 6}, false }, { {1, 4, 6, 4}, false } },
        { { {4, 5, 6}, false }, { {6, 3}, false } },
-       { { {9, 9, 9}, false }, { {9, 9}, false } }
+       { { {9, 9, 9}, false }, { {9, 9}, false } },
+       { { {1, 5}, false }, { {5}, false } },
+       { { {5}, false }, { {5, 1}, false } },
+       { { {5}, false }, { {5}, false } },
+       { { {5}, true }, { {5}, true } }
 };

 std::vector<ngraph::helpers::InputLayerType> secondaryInputTypes = {
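The new 1D entries cover row-vector, column-vector, pure 1D, and transpose-ignored cases; their expected output shapes under NumPy semantics (illustrative):

```python
import numpy as np

print(np.matmul(np.ones((1, 5)), np.ones(5)).shape)  # (1,)  -- {1, 5} x {5}
print(np.matmul(np.ones(5), np.ones((5, 1))).shape)  # (1,)  -- {5} x {5, 1}
print(np.matmul(np.ones(5), np.ones(5)).shape)       # ()    -- {5} x {5}, scalar
# {5} x {5} with the transpose flags set behaves the same: 1D transpose is ignored.
```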
@@ -53,6 +53,8 @@ namespace ngraph

            bool get_transpose_a() const { return m_transpose_a; }
            bool get_transpose_b() const { return m_transpose_b; }
+           void set_transpose_a(bool transpose_a) { m_transpose_a = transpose_a; }
+           void set_transpose_b(bool transpose_b) { m_transpose_b = transpose_b; }
        private:
            bool m_transpose_a;
            bool m_transpose_b;

@@ -67,12 +67,6 @@ xfail_issue_35911 = xfail_test(reason="Assertion error: Pad model mismatch error
 xfail_issue_35912 = xfail_test(reason="RuntimeError: Error of validate layer: B with type: "
                                      "Pad. Cannot parse parameter pads_end from IR for layer B. "
                                      "Value -1,0 cannot be casted to int.")
-xfail_issue_35916 = xfail_test(reason="RuntimeError: Unsupported input dims count for layer Z")
-xfail_issue_35917 = xfail_test(reason="RuntimeError: Unsupported input dims count for "
-                               "layer MatMul")
-xfail_issue_35918 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError: "
-                               "Mismatched attribute type in 'test_node : alpha'")
-xfail_issue_35921 = xfail_test(reason="ValueError - shapes mismatch in gemm")
 xfail_issue_35923 = xfail_test(reason="RuntimeError: PReLU without weights is not supported")
 xfail_issue_35925 = xfail_test(reason="Assertion error - reduction ops results mismatch")
 xfail_issue_35927 = xfail_test(reason="RuntimeError: B has zero dimension that is not allowable")

@@ -20,7 +20,6 @@ import pytest

 from tests.runtime import get_runtime
 from tests.test_onnx.utils import import_onnx_model
-from tests import xfail_issue_35916, xfail_issue_35917, xfail_issue_35918, xfail_issue_35921


 def make_onnx_model_for_matmul_op(input_left, input_right):
@@ -104,7 +103,7 @@ def import_and_compute_gemm(input_a, input_b, input_c, **kwargs):
 @pytest.mark.parametrize(
     "data, description",
     [
-        pytest.param(([1, 2], [1, 3]), "vector and vector 1", marks=xfail_issue_35916),
+        pytest.param(([1, 2], [1, 3]), "vector and vector 1"),
         (([1, 2, 3], [[4], [5], [6]]), "vector and vector 2"),
         (([[1, 2, 3]], [1, 2, 3]), "vector and vector 3"),
         (([1, 2, 3], [[4, 5], [6, 7], [8, 9]]), "vector and matrix"),
@@ -115,7 +114,7 @@ def import_and_compute_gemm(input_a, input_b, input_c, **kwargs):
     ],
 )
 def test_op_matmul(data, description):
-    assert np.array_equal(import_and_compute_matmul(*data), np.matmul(*data))
+    assert np.allclose(import_and_compute_matmul(*data), np.matmul(*data))


 def test_op_matmul_3d():
@@ -130,40 +129,39 @@ def test_op_matmul_3d():
 @pytest.mark.parametrize(
     "data, kwargs, description",
     [
-        pytest.param(([1, 2], [1, 3], [1, 4]), {}, "vectors", marks=xfail_issue_35917),
-        pytest.param(([1, 2], [1, 3], 1), {}, "vectors and scalar", marks=xfail_issue_35917),
-        pytest.param(([1, 2], [1, 3], [1]), {}, "vectors and identity vector", marks=xfail_issue_35917),
-        pytest.param(([1, 2], [1, 3], [1, 4]), {"alpha": 7, "beta": 9},
-                     "vectors with alpha and beta", marks=xfail_issue_35918),
-        pytest.param(([1, 2, 3, 4], [1, 3, 5, 7], [1, 4]), {"alpha": 7, "beta": 9},
-                     "longer vectors with alpha and beta", marks=xfail_issue_35918)
+        pytest.param(([1, 2], [1, 3], [1, 4]), {}, "vectors"),
+        pytest.param(([1, 2], [1, 3], 1), {}, "vectors and scalar"),
+        pytest.param(([1, 2], [1, 3], [1]), {}, "vectors and identity vector"),
+        pytest.param(([1, 2], [1, 3], [1, 4]), {"alpha": 7.0, "beta": 9.0},
+                     "vectors with alpha and beta"),
+        pytest.param(([1, 2, 3, 4], [1, 3, 5, 7], [1, 4]), {"alpha": 7.0, "beta": 9.0},
+                     "longer vectors with alpha and beta")
     ],
 )
 def test_gemm(data, kwargs, description):
-    assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data))
+    assert np.allclose(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))


 @pytest.mark.parametrize(
     "data, kwargs, description",
     [
         pytest.param(([1, 2], [1, 3], [1, 4]), {"trans_a": True, "trans_b": True},
-                     "vectors with trans_a/trans_b", marks=xfail_issue_35917),
+                     "vectors with trans_a/trans_b"),
         pytest.param(([[1, 2], [1, 2]], [[1, 3], [1, 3]], [4, 1]),
-                     {"trans_a": True, "trans_b": True, "alpha": 7, "beta": 9},
-                     "matrices and vector with trans_b and alpha/beta", marks=xfail_issue_35918),
-        pytest.param(([[1, 2]], [[1, 3]], 1), {"trans_b": True, "alpha": 7, "beta": 9},
-                     "matrices and scalar with trans_b and alpha/beta", marks=xfail_issue_35918),
-        pytest.param(([[1], [2]], [[1], [3]], 1), {"trans_a": True, "alpha": 7, "beta": 9},
-                     "matrices and scalar with trans_a and alpha/beta", marks=xfail_issue_35918),
+                     {"trans_a": True, "trans_b": True, "alpha": 7.0, "beta": 9.0},
+                     "matrices and vector with trans_b and alpha/beta"),
+        pytest.param(([[1, 2]], [[1, 3]], 1), {"trans_b": True, "alpha": 7.0, "beta": 9.0},
+                     "matrices and scalar with trans_b and alpha/beta"),
+        pytest.param(([[1], [2]], [[1], [3]], 1), {"trans_a": True, "alpha": 7.0, "beta": 9.0},
+                     "matrices and scalar with trans_a and alpha/beta"),
     ],
 )
 def test_gemm_transpositions(data, kwargs, description):
     assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))


-@xfail_issue_35921
 def test_gemm_flatten():
-    # input_a.shape is (4,1,1)
-    data = ([[[1]], [[2]], [[3]], [[4]]], [1, 3, 5, 7], [1, 4])
-    kwargs = {"alpha": 7, "beta": 9}
+    # input_a.shape is (4,1)
+    data = ([[1], [2], [3], [4]], [1, 3, 5, 7], [1, 4])
+    kwargs = {"alpha": 7.0, "beta": 9.0, "trans_a": True}
     assert np.array_equal(import_and_compute_gemm(*data, **kwargs), numpy_gemm(*data, **kwargs))

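These tests compare against the file's numpy_gemm helper. A minimal sketch of what such a helper computes under ONNX Gemm semantics (Y = alpha * A' * B' + beta * C); the reimplementation below is hypothetical, for illustration only:

```python
import numpy as np

def numpy_gemm_sketch(a, b, c, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    # Hypothetical stand-in for the tests' numpy_gemm helper (ONNX Gemm semantics).
    a, b = np.atleast_2d(a), np.atleast_2d(b)
    if trans_a:
        a = a.T
    if trans_b:
        b = b.T
    return alpha * np.matmul(a, b) + beta * np.array(c)

# Matches the "matrices and scalar with trans_b and alpha/beta" case above:
print(numpy_gemm_sketch([[1, 2]], [[1, 3]], 1, alpha=7.0, beta=9.0, trans_b=True))  # [[58.]]
```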
@@ -18,7 +18,6 @@ import onnx
 import pytest

 from tests.test_onnx.utils import run_node
-from tests import xfail_issue_35918


 def import_and_compute(op_type, input_data, **node_attrs):
@@ -107,10 +106,10 @@ def test_selu():
 @pytest.mark.parametrize(
     "data, alpha_value",
     [
-        pytest.param([-2, -1.0, 0.0, 1.0, 2.0], 1, marks=xfail_issue_35918),
-        pytest.param([0.0], 1, marks=xfail_issue_35918),
-        pytest.param([-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1], 1, marks=xfail_issue_35918),
-        pytest.param([[1, 2, 3], [4, 5, 6]], 1, marks=xfail_issue_35918),
+        pytest.param([-2, -1.0, 0.0, 1.0, 2.0], 1.0),
+        pytest.param([0.0], 1.0),
+        pytest.param([-0.9, -0.8, -0.7, -0.4, -0.3, -0.2, -0.1], 1.0),
+        pytest.param([[1, 2, 3], [4, 5, 6]], 1.0),
         pytest.param([-2, -1.0, 0.0, 1.0, 2.0], 0.5)
     ]
 )

@@ -441,6 +441,111 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_1x2x3_1x4x3x2)
                                           256.f}));
 }

+// 1D x 1D
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_false_false_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{3};
+    Shape shape_out{};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_true_true_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{3};
+    Shape shape_out{};
+
+    // For 1D inputs transpose is expected to be ignored
+    bool transpose_a = true;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_false_false_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{3};
+    Shape shape_out{};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_true_true_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{3};
+    Shape shape_out{};
+
+    // For 1D inputs transpose is expected to be ignored
+    bool transpose_a = true;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
 // 2D x 1D
 NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_param)
 {
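All of the new 1D backend tests use {1, 2, 3} for both operands, so every expected_result reduces to the same dot product; a quick check (illustrative):

```python
import numpy as np

print(np.dot([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # 14.0 == 1*1 + 2*2 + 3*3
```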
@@ -468,6 +573,31 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_param)
     test_case.run();
 }

+NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_x_3_false_false_const)
+{
+    Shape shape_in1{1, 3};
+    Shape shape_in2{3};
+    Shape shape_out{1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_param)
 {
     Shape shape_in1{3, 1};
@@ -494,6 +624,31 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_param)
     test_case.run();
 }

+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_1_x_3_true_false_const)
+{
+    Shape shape_in1{3, 1};
+    Shape shape_in2{3};
+    Shape shape_out{1};
+
+    bool transpose_a = true;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
 // 1D x 2D
 NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_param)
 {
@@ -521,6 +676,31 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_param)
     test_case.run();
 }

+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_3_1_false_false_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{3, 1};
+    Shape shape_out{1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
 NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_param)
 {
     Shape shape_in1{3};
@@ -546,3 +726,285 @@ NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_param)
     test_case.add_expected_output<float>(shape_out, expected_result);
     test_case.run();
 }
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_false_true_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 3};
+    Shape shape_out{1};
+
+    bool transpose_a = false;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_true_true_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 3};
+    Shape shape_out{1};
+
+    bool transpose_a = true;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_true_true_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 3};
+    Shape shape_out{1};
+
+    bool transpose_a = true;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+// 3D x 1D
+NGRAPH_TEST(${BACKEND_NAME}, matmul_1_1_3_x_3_false_false_param)
+{
+    Shape shape_in1{1, 1, 3};
+    Shape shape_in2{3};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_1_1_3_x_3_false_false_const)
+{
+    Shape shape_in1{1, 1, 3};
+    Shape shape_in2{3};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_1_x_3_true_false_param)
+{
+    Shape shape_in1{1, 3, 1};
+    Shape shape_in2{3};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = true;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_1_3_1_x_3_true_false_const)
+{
+    Shape shape_in1{1, 3, 1};
+    Shape shape_in2{3};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = true;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+// 1D x 3D
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_1_false_false_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 3, 1};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_1_3_false_true_param)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 1, 3};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = false;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Parameter>(element::f32, shape_in2);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A, B});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+    test_case.add_input<float>(inputs_b);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_3_1_false_false_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 3, 1};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = false;
+    bool transpose_b = false;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}
+
+NGRAPH_TEST(${BACKEND_NAME}, matmul_3_x_1_1_3_false_true_const)
+{
+    Shape shape_in1{3};
+    Shape shape_in2{1, 1, 3};
+    Shape shape_out{1, 1};
+
+    bool transpose_a = false;
+    bool transpose_b = true;
+
+    std::vector<float> inputs_a{1, 2, 3};
+    std::vector<float> inputs_b{1, 2, 3};
+    std::vector<float> expected_result{14.};
+
+    auto A = make_shared<op::Parameter>(element::f32, shape_in1);
+    auto B = make_shared<op::Constant>(element::f32, shape_in2, inputs_b);
+    auto matmul = make_shared<op::MatMul>(A, B, transpose_a, transpose_b);
+    auto f = make_shared<Function>(matmul, ParameterVector{A});
+
+    auto test_case = test::TestCase<TestEngine>(f);
+    test_case.add_input<float>(inputs_a);
+
+    test_case.add_expected_output<float>(shape_out, expected_result);
+    test_case.run();
+}

@@ -456,9 +456,6 @@ dynamic_reverse_shape
 tile_3d_small_data_rank
 tile_3d_few_repeats

-# Error of validate layer: MatMul_683292 with type: Gemm. Gemm input shapes must have at least 2 dimensions
-matmul_2_2
-
 # Result mismatch
 sum_large_1d_to_scalar
 sum_stable_acc