Review opset1 matmul shape inference aspects (#13333)

* Add label propagate test for shape inference

* Update MatMul tests for StaticShape

* MatMul label propagation, reshape input with label

* Apply clang-format to matmul shape inference tests
This commit is contained in:
Pawel Raasz 2022-10-05 15:21:36 +02:00 committed by GitHub
parent f6d6f5629f
commit 2b70158047
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 228 additions and 11 deletions

View File

@@ -2,12 +2,14 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "dimension_tracker.hpp"
#include "gmock/gmock.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
TEST(type_prop, matmul_2D_same) {
auto A = make_shared<op::Parameter>(element::f32, Shape{2, 2});
@@ -472,3 +474,84 @@ TEST(type_prop, matmul_incompatible_batch_dim_bounds) {
ASSERT_EQ(matmul->get_element_type(), element::f32);
ASSERT_EQ(matmul->get_output_partial_shape(0), expected_output_shape);
}
TEST(type_prop, matmul_propagate_labels) {
    // Fully static batched MatMul: dimension labels are merged per output
    // dimension; label value 0 means "no label set".
    const auto labels_a = std::vector<size_t>{1, 0, 2, 5, 6};
    const auto labels_b = std::vector<size_t>{0, 1, 3, 7, 9};
    auto shape_a = PartialShape{4, 2, 3, 6, 4};
    auto shape_b = PartialShape{4, 2, 3, 4, 2};
    set_shape_labels(shape_a, labels_a);
    set_shape_labels(shape_b, labels_b);
    const auto param_a = make_shared<op::Parameter>(element::f32, shape_a);
    const auto param_b = make_shared<op::Parameter>(element::f32, shape_b);
    const auto matmul_op = make_shared<op::MatMul>(param_a, param_b, false, false);
    const auto& result_shape = matmul_op->get_output_partial_shape(0);
    ASSERT_THAT(get_shape_labels(result_shape),
                ElementsAre(labels_a[0],  // only A is labeled -> A's label is used
                            labels_b[1],  // only B is labeled -> B's label is used
                            0,            // A and B carry different labels -> no label
                            labels_a[3],  // ROW dimension of the output comes from A
                            labels_b[4]   // COL dimension of the output comes from B
                            ));
}
TEST(type_prop, matmul_propagate_labels_on_interval_dims) {
    // Interval (bounded) dimensions: labels merge per output dimension,
    // label value 0 means "no label set".
    const auto a_labels = std::vector<size_t>{1, 0, 3, 4, 5};
    const auto b_labels = std::vector<size_t>{0, 1, 3, 4, 7};
    auto a_shape = PartialShape{Dimension(1, 3), 1, Dimension(2, 3), Dimension(3, 4), 4};
    auto b_shape = PartialShape{1, Dimension(1, 5), Dimension(1, 3), 4, Dimension::dynamic()};
    set_shape_labels(a_shape, a_labels);
    set_shape_labels(b_shape, b_labels);
    const auto a = make_shared<op::Parameter>(element::f32, a_shape);
    const auto b = make_shared<op::Parameter>(element::f32, b_shape);
    const auto matmul = make_shared<op::MatMul>(a, b, false, false);
    const auto& output_shape = matmul->get_output_partial_shape(0);
    const auto labels = get_shape_labels(output_shape);
    ASSERT_THAT(labels,
                ElementsAre(a_labels[0],  // use a label, b is not set
                            b_labels[1],  // use b label, a is not set
                            a_labels[2],  // use a label, b is same
                            a_labels[3],  // ROW dim taken from a, b's label is lost
                            b_labels[4]   // COL dim taken from b, a's label is lost
                            ));
}
TEST(type_prop, matmul_propagate_label_on_b_input_after_reshape) {
    // A labeled interval dimension on B must survive the data path
    // ShapeOf -> Gather -> Concat -> Reshape and still appear on the MatMul output.
    constexpr size_t my_label = 2;
    auto marked_dim = Dimension(2, 3);
    ov::DimensionTracker::set_label(marked_dim, my_label);
    const auto a_shape = PartialShape{Dimension::dynamic(), 5, 3};
    const auto b_shape = PartialShape{3, marked_dim, 2};
    const auto b = make_shared<op::Parameter>(element::f32, b_shape);
    // Build the reshape target {marked_dim, 3, 8}: gather dims {1, 0} of B's shape
    // (swapping them), then append a constant 8.
    const auto shape_of_b = std::make_shared<op::ShapeOf>(b);
    const auto gather = std::make_shared<op::v7::Gather>(
        shape_of_b,
        std::make_shared<op::Constant>(element::i64, Shape{2}, std::vector<int64_t>{1, 0}),
        std::make_shared<op::Constant>(element::i64, Shape{}, 0));
    const auto concat =
        std::make_shared<op::Concat>(OutputVector{gather, std::make_shared<op::Constant>(element::i64, Shape{1}, 8)},
                                     0);
    const auto reshape_b = make_shared<op::v1::Reshape>(b, concat, false);
    const auto a = make_shared<op::Parameter>(element::f32, a_shape);
    const auto matmul = make_shared<op::MatMul>(a, reshape_b, false, false);
    const auto& output_shape = matmul->get_output_partial_shape(0);
    const auto labels = get_shape_labels(output_shape);
    // Batch dim keeps B's label; the ROW (5) and COL (8) dims are unlabeled.
    ASSERT_THAT(labels, ElementsAre(my_label, 0, 0));
    ASSERT_EQ(output_shape, (PartialShape{marked_dim, 5, 8}));
}

View File

@@ -12,15 +12,113 @@
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
TEST(StaticShapeInferenceTest, MatMulTest) {
auto A_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1, -1, -1});
auto B_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape{-1, -1, -1});
using matmul_test_params_t = std::tuple<StaticShape, // Input A shape
StaticShape // Input B shape
>;
// Parametrized fixture: takes static shapes of inputs A and B and computes the
// expected MatMul output shape to compare against the shape-inference result.
class MatMulTest : public TestWithParam<matmul_test_params_t> {
protected:
    void SetUp() override {
        std::tie(a_shape, b_shape) = GetParam();
        set_exp_shape();
    }
    // Creates a MatMul whose inputs are fully dynamic with the given ranks, so
    // that all concrete dimensions must come from static shape inference.
    std::shared_ptr<op::v0::MatMul> make_matmul(const size_t& a_dim_count,
                                                const size_t& b_dim_count,
                                                const bool transpose_a,
                                                const bool transpose_b) {
        auto a_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape::dynamic(a_dim_count));
        auto b_input = std::make_shared<op::v0::Parameter>(element::i64, PartialShape::dynamic(b_dim_count));
        return std::make_shared<op::v0::MatMul>(a_input, b_input, transpose_a, transpose_b);
    }
    // Computes exp_shape per MatMul semantics (without transpositions):
    // - both >1-D: batch dims broadcast, then {ROW of A, COL of B} appended;
    // - 1-D A: treated as a row vector, its dim is dropped from the output;
    // - 1-D B: treated as a column vector, its dim is dropped from the output;
    // - both 1-D: output stays empty (scalar dot product).
    // NOTE(review): batch dims are paired positionally from the front and merged
    // with max() — correct only for the equal-rank / broadcast-by-1 parameter
    // sets instantiated below, not for general NumPy-style broadcasting.
    void set_exp_shape() {
        if (a_shape.size() > 1 && b_shape.size() > 1) {
            std::transform(a_shape.cbegin(),
                           a_shape.cend() - 2,
                           b_shape.cbegin(),
                           std::back_inserter(exp_shape),
                           [](const StaticDimension& a, const StaticDimension& b) {
                               return std::max(a.get_length(), b.get_length());
                           });
            exp_shape.push_back(*std::next(a_shape.rbegin()));
            exp_shape.push_back(b_shape.back());
        } else if (a_shape.size() == 1 && b_shape.size() > 1) {
            exp_shape = b_shape;
            exp_shape.erase(std::prev(exp_shape.end(), 2));
        } else if (b_shape.size() == 1 && a_shape.size() > 1) {
            exp_shape = a_shape;
            exp_shape.erase(std::prev(exp_shape.end()));
        }
    }
    // Returns a copy of `in` with its last two dimensions swapped (no-op for 1-D),
    // i.e. the shape a transposed input must have to yield the same exp_shape.
    static StaticShape make_transpose_input(const StaticShape& in) {
        StaticShape out(in);
        if (out.size() > 1) {
            std::iter_swap(out.rbegin(), std::next(out.rbegin()));
        }
        return out;
    }
    StaticShape a_shape, b_shape, exp_shape;
};
/** \brief Input shape pairs covering 1-D, 2-D and batched n-D cases (including batch broadcasting by 1). */
INSTANTIATE_TEST_SUITE_P(StaticShapeInference,
                         MatMulTest,
                         Values(make_tuple(StaticShape({1}), StaticShape({1})),
                                make_tuple(StaticShape({1}), StaticShape({1, 3})),
                                make_tuple(StaticShape({1}), StaticShape({1, 1, 3})),
                                make_tuple(StaticShape({3, 1}), StaticShape({1})),
                                make_tuple(StaticShape({3, 2, 1}), StaticShape({1})),
                                make_tuple(StaticShape({3}), StaticShape({3})),
                                make_tuple(StaticShape({5, 2}), StaticShape({2, 6})),
                                make_tuple(StaticShape({2, 1, 2}), StaticShape({2, 6})),
                                make_tuple(StaticShape({10, 8, 9, 2}), StaticShape({10, 8, 2, 8})),
                                make_tuple(StaticShape({3, 1, 4, 3, 4}), StaticShape({3, 2, 1, 4, 1}))),
                         PrintToStringParamName());
TEST_P(MatMulTest, no_input_transpose) {
const auto matmul = make_matmul(a_shape.size(), b_shape.size(), false, false);
std::vector<StaticShape> static_input_shapes = {a_shape, b_shape}, static_output_shapes = {StaticShape{}};
auto matmul = std::make_shared<op::v0::MatMul>(A_input, B_input, 0, 1);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 5, 7}, StaticShape{3, 6, 7}},
static_output_shapes = {StaticShape{}};
shape_inference(matmul.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{3, 5, 6}));
}
ASSERT_EQ(static_output_shapes.front(), exp_shape);
}
TEST_P(MatMulTest, transpose_input_a) {
    // Pre-transposing A's shape and enabling transpose_a must cancel out,
    // yielding the same expected output shape.
    const auto op = make_matmul(a_shape.size(), b_shape.size(), true, false);
    std::vector<StaticShape> in_shapes{make_transpose_input(a_shape), b_shape};
    std::vector<StaticShape> out_shapes{StaticShape{}};
    shape_inference(op.get(), in_shapes, out_shapes);
    ASSERT_EQ(out_shapes.front(), exp_shape);
}
TEST_P(MatMulTest, transpose_input_b) {
    // Pre-transposing B's shape and enabling transpose_b must cancel out,
    // yielding the same expected output shape.
    const auto op = make_matmul(a_shape.size(), b_shape.size(), false, true);
    std::vector<StaticShape> in_shapes{a_shape, make_transpose_input(b_shape)};
    std::vector<StaticShape> out_shapes{StaticShape{}};
    shape_inference(op.get(), in_shapes, out_shapes);
    ASSERT_EQ(out_shapes.front(), exp_shape);
}
TEST_P(MatMulTest, transpose_inputs_a_b) {
    // Pre-transposing both input shapes with both transpose flags enabled must
    // cancel out, yielding the same expected output shape.
    const auto op = make_matmul(a_shape.size(), b_shape.size(), true, true);
    std::vector<StaticShape> in_shapes{make_transpose_input(a_shape), make_transpose_input(b_shape)};
    std::vector<StaticShape> out_shapes{StaticShape{}};
    shape_inference(op.get(), in_shapes, out_shapes);
    ASSERT_EQ(out_shapes.front(), exp_shape);
}

View File

@@ -8,7 +8,7 @@ add_library(ngraph_test_util STATIC EXCLUDE_FROM_ALL ${UTIL_SRC})
ie_faster_build(ngraph_test_util UNITY)
target_link_libraries(ngraph_test_util PUBLIC openvino::runtime gtest gmock)
target_link_libraries(ngraph_test_util PUBLIC openvino::runtime openvino::core::dev gtest gmock)
target_include_directories(ngraph_test_util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_include_directories(ngraph_test_util PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/..")

View File

@@ -0,0 +1,31 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "type_prop.hpp"

#include <algorithm>
#include <functional>
#include <iterator>
#include <vector>

#include "dimension_tracker.hpp"
#include "openvino/core/dimension.hpp"
/// \brief Collects the dimension label of every dimension in \p p_shape.
/// \param p_shape  Shape to read labels from.
/// \return Vector of labels, one per dimension; 0 means "no label set".
std::vector<size_t> get_shape_labels(const ov::PartialShape& p_shape) {
    std::vector<size_t> labels;
    // Reserve once: rank is known up front, avoids reallocations.
    labels.reserve(p_shape.size());
    // Qualify std::transform/std::back_inserter explicitly instead of relying
    // on ADL to find them through the iterator's namespace.
    std::transform(p_shape.cbegin(), p_shape.cend(), std::back_inserter(labels), [](const ov::Dimension& dim) {
        return ov::DimensionTracker::get_label(dim);
    });
    return labels;
}
/// \brief Applies \p labels to the dimensions of \p p_shape, one per dimension.
/// A label value of 0 leaves the corresponding dimension unlabeled.
void set_shape_labels(ov::PartialShape& p_shape, const std::vector<size_t>& labels) {
    ASSERT_EQ(labels.size(), p_shape.size());
    auto dim_it = p_shape.begin();
    for (const auto label : labels) {
        if (label > 0) {
            ov::DimensionTracker::set_label(*dim_it, label);
        }
        ++dim_it;
    }
}

View File

@ -5,6 +5,7 @@
#pragma once
#include "gtest/gtest.h"
#include "openvino/core/partial_shape.hpp"
#define EXPECT_HAS_SUBSTRING(haystack, needle) EXPECT_PRED_FORMAT2(testing::IsSubstring, needle, haystack)
@@ -14,3 +15,7 @@ struct PrintToDummyParamName {
return "dummy" + std::to_string(info.index);
}
};
std::vector<size_t> get_shape_labels(const ov::PartialShape& p_shape);
void set_shape_labels(ov::PartialShape& p_shape, const std::vector<size_t>& labels);