Review opset1 matmul shape inference aspects (#13333)
* Add label propagate test for shape inference * Update MatMul tests for StaticShape * MatMul label propagation, reshape input with label * Apply clang-format to matmul shape inference tests
This commit is contained in:
parent
f6d6f5629f
commit
2b70158047
@ -2,12 +2,14 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "gmock/gmock.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
TEST(type_prop, matmul_2D_same) {
|
||||
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
|
||||
@ -472,3 +474,84 @@ TEST(type_prop, matmul_incompatible_batch_dim_bounds) {
|
||||
ASSERT_EQ(matmul->get_element_type(), element::f32);
|
||||
ASSERT_EQ(matmul->get_output_partial_shape(0), expected_output_shape);
|
||||
}
|
||||
|
||||
TEST(type_prop, matmul_propagate_labels) {
|
||||
const auto a_labels = std::vector<size_t>{1, 0, 2, 5, 6};
|
||||
const auto b_labels = std::vector<size_t>{0, 1, 3, 7, 9};
|
||||
|
||||
auto a_shape = PartialShape{4, 2, 3, 6, 4};
|
||||
auto b_shape = PartialShape{4, 2, 3, 4, 2};
|
||||
|
||||
set_shape_labels(a_shape, a_labels);
|
||||
set_shape_labels(b_shape, b_labels);
|
||||
|
||||
const auto a = make_shared<op::Parameter>(element::f32, a_shape);
|
||||
const auto b = make_shared<op::Parameter>(element::f32, b_shape);
|
||||
const auto matmul = make_shared<op::MatMul>(a, b, false, false);
|
||||
|
||||
const auto& output_shape = matmul->get_output_partial_shape(0);
|
||||
const auto labels = get_shape_labels(output_shape);
|
||||
|
||||
ASSERT_THAT(labels,
|
||||
ElementsAre(a_labels[0], // use a label, b is not set
|
||||
b_labels[1], // use b label, a is not set
|
||||
0, // not set label. a,b has different labels
|
||||
a_labels[3], // use label from a, b is lost
|
||||
b_labels[4] // use label from b, a is lost
|
||||
));
|
||||
}
|
||||
|
||||
TEST(type_prop, matmul_propagate_labels_on_interval_dims) {
|
||||
const auto a_labels = std::vector<size_t>{1, 0, 3, 4, 5};
|
||||
const auto b_labels = std::vector<size_t>{0, 1, 3, 4, 7};
|
||||
|
||||
auto a_shape = PartialShape{Dimension(1, 3), 1, Dimension(2, 3), Dimension(3, 4), 4};
|
||||
auto b_shape = PartialShape{1, Dimension(1, 5), Dimension(1, 3), 4, Dimension::dynamic()};
|
||||
|
||||
set_shape_labels(a_shape, a_labels);
|
||||
set_shape_labels(b_shape, b_labels);
|
||||
|
||||
const auto a = make_shared<op::Parameter>(element::f32, a_shape);
|
||||
const auto b = make_shared<op::Parameter>(element::f32, b_shape);
|
||||
const auto matmul = make_shared<op::MatMul>(a, b, false, false);
|
||||
|
||||
const auto& output_shape = matmul->get_output_partial_shape(0);
|
||||
const auto labels = get_shape_labels(output_shape);
|
||||
|
||||
ASSERT_THAT(labels,
|
||||
ElementsAre(a_labels[0], // use a label, b is not set
|
||||
b_labels[1], // use b label, a is not set
|
||||
a_labels[2], // use a label, b is same
|
||||
a_labels[3], // use label from a, b is lost
|
||||
b_labels[4] // use label from a, b is lost
|
||||
));
|
||||
}
|
||||
|
||||
TEST(type_prop, matmul_propagate_label_on_b_input_after_reshape) {
|
||||
constexpr size_t my_label = 2;
|
||||
auto marked_dim = Dimension(2, 3);
|
||||
ov::DimensionTracker::set_label(marked_dim, my_label);
|
||||
|
||||
const auto a_shape = PartialShape{Dimension::dynamic(), 5, 3};
|
||||
const auto b_shape = PartialShape{3, marked_dim, 2};
|
||||
|
||||
const auto b = make_shared<op::Parameter>(element::f32, b_shape);
|
||||
const auto shape_of_b = std::make_shared<op::ShapeOf>(b);
|
||||
const auto gather = std::make_shared<op::v7::Gather>(
|
||||
shape_of_b,
|
||||
std::make_shared<op::Constant>(element::i64, Shape{2}, std::vector<int64_t>{1, 0}),
|
||||
std::make_shared<op::Constant>(element::i64, Shape{}, 0));
|
||||
const auto concat =
|
||||
std::make_shared<op::Concat>(OutputVector{gather, std::make_shared<op::Constant>(element::i64, Shape{1}, 8)},
|
||||
0);
|
||||
const auto reshape_b = make_shared<op::v1::Reshape>(b, concat, false);
|
||||
|
||||
const auto a = make_shared<op::Parameter>(element::f32, a_shape);
|
||||
const auto matmul = make_shared<op::MatMul>(a, reshape_b, false, false);
|
||||
|
||||
const auto& output_shape = matmul->get_output_partial_shape(0);
|
||||
const auto labels = get_shape_labels(output_shape);
|
||||
|
||||
ASSERT_THAT(labels, ElementsAre(my_label, 0, 0));
|
||||
ASSERT_EQ(output_shape, (PartialShape{marked_dim, 5, 8}));
|
||||
}
|
||||
|
@ -12,15 +12,113 @@
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
TEST(StaticShapeInferenceTest, MatMulTest) {
|
||||
auto A_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1, -1, -1});
|
||||
auto B_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1, -1, -1});
|
||||
using matmul_test_params_t = std::tuple<StaticShape, // Input A shape
|
||||
StaticShape // Input B shape
|
||||
>;
|
||||
|
||||
class MatMulTest : public TestWithParam<matmul_test_params_t> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
std::tie(a_shape, b_shape) = GetParam();
|
||||
|
||||
set_exp_shape();
|
||||
}
|
||||
|
||||
std::shared_ptr<op::v0::MatMul> make_matmul(const size_t& a_dim_count,
|
||||
const size_t& b_dim_count,
|
||||
const bool transpose_a,
|
||||
const bool transpose_b) {
|
||||
auto a_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape::dynamic(a_dim_count));
|
||||
auto b_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape::dynamic(b_dim_count));
|
||||
|
||||
return std::make_shared<op::v0::MatMul>(a_input, b_input, transpose_a, transpose_b);
|
||||
}
|
||||
|
||||
void set_exp_shape() {
|
||||
if (a_shape.size() > 1 && b_shape.size() > 1) {
|
||||
std::transform(a_shape.cbegin(),
|
||||
a_shape.cend() - 2,
|
||||
b_shape.cbegin(),
|
||||
std::back_inserter(exp_shape),
|
||||
[](const StaticDimension& a, const StaticDimension& b) {
|
||||
return std::max(a.get_length(), b.get_length());
|
||||
});
|
||||
exp_shape.push_back(*std::next(a_shape.rbegin()));
|
||||
exp_shape.push_back(b_shape.back());
|
||||
} else if (a_shape.size() == 1 && b_shape.size() > 1) {
|
||||
exp_shape = b_shape;
|
||||
exp_shape.erase(std::prev(exp_shape.end(), 2));
|
||||
} else if (b_shape.size() == 1 && a_shape.size() > 1) {
|
||||
exp_shape = a_shape;
|
||||
exp_shape.erase(std::prev(exp_shape.end()));
|
||||
}
|
||||
}
|
||||
|
||||
static StaticShape make_transpose_input(const StaticShape& in) {
|
||||
StaticShape out(in);
|
||||
if (out.size() > 1) {
|
||||
std::iter_swap(out.rbegin(), std::next(out.rbegin()));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
StaticShape a_shape, b_shape, exp_shape;
|
||||
};
|
||||
|
||||
/** \brief Use transpose order -> output shape dimensions shall be as transpose order. */
|
||||
INSTANTIATE_TEST_SUITE_P(StaticShapeInference,
|
||||
MatMulTest,
|
||||
Values(make_tuple(StaticShape({1}), StaticShape({1})),
|
||||
make_tuple(StaticShape({1}), StaticShape({1, 3})),
|
||||
make_tuple(StaticShape({1}), StaticShape({1, 1, 3})),
|
||||
make_tuple(StaticShape({3, 1}), StaticShape({1})),
|
||||
make_tuple(StaticShape({3, 2, 1}), StaticShape({1})),
|
||||
make_tuple(StaticShape({3}), StaticShape({3})),
|
||||
make_tuple(StaticShape({5, 2}), StaticShape({2, 6})),
|
||||
make_tuple(StaticShape({2, 1, 2}), StaticShape({2, 6})),
|
||||
make_tuple(StaticShape({10, 8, 9, 2}), StaticShape({10, 8, 2, 8})),
|
||||
make_tuple(StaticShape({3, 1, 4, 3, 4}), StaticShape({3, 2, 1, 4, 1}))),
|
||||
PrintToStringParamName());
|
||||
|
||||
TEST_P(MatMulTest, no_input_transpose) {
|
||||
const auto matmul = make_matmul(a_shape.size(), b_shape.size(), false, false);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {a_shape, b_shape}, static_output_shapes = {StaticShape{}};
|
||||
|
||||
auto matmul = std::make_shared<op::v0::MatMul>(A_input, B_input, 0, 1);
|
||||
// Test StaticShape
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 5, 7}, StaticShape{3, 6, 7}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(matmul.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes[0], (StaticShape{3, 5, 6}));
|
||||
}
|
||||
ASSERT_EQ(static_output_shapes.front(), exp_shape);
|
||||
}
|
||||
|
||||
TEST_P(MatMulTest, transpose_input_a) {
|
||||
const auto matmul = make_matmul(a_shape.size(), b_shape.size(), true, false);
|
||||
|
||||
const auto a_transpose = make_transpose_input(a_shape);
|
||||
std::vector<StaticShape> static_input_shapes = {a_transpose, b_shape}, static_output_shapes = {StaticShape{}};
|
||||
|
||||
shape_inference(matmul.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes.front(), exp_shape);
|
||||
}
|
||||
|
||||
TEST_P(MatMulTest, transpose_input_b) {
|
||||
const auto matmul = make_matmul(a_shape.size(), b_shape.size(), false, true);
|
||||
|
||||
const auto b_transpose = make_transpose_input(b_shape);
|
||||
std::vector<StaticShape> static_input_shapes = {a_shape, b_transpose}, static_output_shapes = {StaticShape{}};
|
||||
|
||||
shape_inference(matmul.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes.front(), exp_shape);
|
||||
}
|
||||
|
||||
TEST_P(MatMulTest, transpose_inputs_a_b) {
|
||||
const auto matmul = make_matmul(a_shape.size(), b_shape.size(), true, true);
|
||||
|
||||
const auto a_transpose = make_transpose_input(a_shape);
|
||||
const auto b_transpose = make_transpose_input(b_shape);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes = {a_transpose, b_transpose}, static_output_shapes = {StaticShape{}};
|
||||
|
||||
shape_inference(matmul.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes.front(), exp_shape);
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ add_library(ngraph_test_util STATIC EXCLUDE_FROM_ALL ${UTIL_SRC})
|
||||
|
||||
ie_faster_build(ngraph_test_util UNITY)
|
||||
|
||||
target_link_libraries(ngraph_test_util PUBLIC openvino::runtime gtest gmock)
|
||||
target_link_libraries(ngraph_test_util PUBLIC openvino::runtime openvino::core::dev gtest gmock)
|
||||
target_include_directories(ngraph_test_util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
target_include_directories(ngraph_test_util PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/..")
|
||||
|
||||
|
31
src/tests/util/type_prop.cpp
Normal file
31
src/tests/util/type_prop.cpp
Normal file
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "type_prop.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "openvino/core/dimension.hpp"
|
||||
|
||||
std::vector<size_t> get_shape_labels(const ov::PartialShape& p_shape) {
|
||||
std::vector<size_t> labels;
|
||||
transform(p_shape.cbegin(), p_shape.cend(), back_inserter(labels), [](const ov::Dimension& dim) {
|
||||
return ov::DimensionTracker::get_label(dim);
|
||||
});
|
||||
return labels;
|
||||
}
|
||||
|
||||
void set_shape_labels(ov::PartialShape& p_shape, const std::vector<size_t>& labels) {
|
||||
ASSERT_EQ(labels.size(), p_shape.size());
|
||||
auto label_it = labels.begin();
|
||||
|
||||
std::for_each(p_shape.begin(), p_shape.end(), [&label_it](ov::Dimension& dim) {
|
||||
if (*label_it > 0) {
|
||||
ov::DimensionTracker::set_label(dim, *label_it);
|
||||
}
|
||||
++label_it;
|
||||
});
|
||||
}
|
@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/core/partial_shape.hpp"
|
||||
|
||||
#define EXPECT_HAS_SUBSTRING(haystack, needle) EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
|
||||
|
||||
@ -14,3 +15,7 @@ struct PrintToDummyParamName {
|
||||
return "dummy" + std::to_string(info.index);
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<size_t> get_shape_labels(const ov::PartialShape& p_shape);
|
||||
|
||||
void set_shape_labels(ov::PartialShape& p_shape, const std::vector<size_t>& labels);
|
||||
|
Loading…
Reference in New Issue
Block a user