Fixed is_on_constant_path() usage in all places (#19239)
* Fixed matmul weights check in snippets_mark_skipped
* fix
* ConvertMatMulToFC: is_on_constant_path fix
* [TESTS] added SplitMatMulConcat subgraph test
* MarkDequantizationSubgraph: is_on_constant_path fix
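All of the changes below hinge on what is handed to ov::op::util::is_on_constant_path(): each call site now passes the matched ov::Output<ov::Node> instead of the node returned by get_node_shared_ptr(), and the new SplitMatMulConcat subgraph test covers the case where that output comes from a multi-output VariadicSplit feeding a MatMul. A minimal before/after sketch of the weights predicate, adapted from the ConvertMatMulToFC hunk (the include path is an assumption, it is not shown in this diff):

#include "openvino/core/node.hpp"
#include "transformations/utils/utils.hpp"  // assumed location of ov::op::util::is_on_constant_path

// Before the fix: the check received the node producing the matched output.
auto weights_path_old = [](const ov::Output<ov::Node>& output) {
    return ov::op::util::is_on_constant_path(output.get_node_shared_ptr());
};

// After the fix: the check receives the matched ov::Output<ov::Node> itself,
// consistent with the other call sites updated in this commit.
auto weights_path_new = [](const ov::Output<ov::Node>& output) {
    return ov::op::util::is_on_constant_path(output);
};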
@@ -34,7 +34,7 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
    ov::matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
        const auto& pattern_map = m.get_pattern_value_map();
        auto convert = pattern_map.at(convert_pattern).get_node_shared_ptr();
-       auto input = pattern_map.at(input_pattern).get_node_shared_ptr();
+       auto input = pattern_map.at(input_pattern);
        const auto multiply = m.get_match_root();

        if (transformation_callback(multiply)) {
@@ -48,12 +48,12 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
                if (node && std::find(precisions.begin(), precisions.end(), node->get_input_element_type(0)) !=
                                precisions.end()) {
                    convert = node;
-                   input = convert->get_input_node_shared_ptr(0);
+                   input = convert->input_value(0);
                }
            }
        }

-       const auto& input_precision = input->get_output_element_type(0);
+       const auto& input_precision = input.get_element_type();
        // validation by Convert operation input precisions
        if (std::find(precisions.begin(), precisions.end(), input_precision) == precisions.end()) {
            return false;
@@ -18,7 +18,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
    MATCHER_SCOPE(ConvertMatMulToFC);
    auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
    auto weights_path = [](const ov::Output<ov::Node>& output) {
-       return ov::op::util::is_on_constant_path(output.get_node_shared_ptr());
+       return ov::op::util::is_on_constant_path(output);
    };
    auto weights_m = ngraph::pattern::any_input(weights_path);
    auto matmul_m = ngraph::pattern::wrap_type<ngraph::op::v0::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
@@ -254,7 +254,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
    ov::PartialShape matmul_shape;
    for (const auto &parent_out : node->input_values()) {
        const auto parent = parent_out.get_node_shared_ptr();
-       if (ov::op::util::is_on_constant_path(parent)) {
+       if (ov::op::util::is_on_constant_path(parent_out)) {
            bias_shape = parent_out.get_shape();
            num_non_const_inputs++;
        } else {
@@ -265,7 +265,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
        // first check that weights are constant and both activations and weights have static shape
        if (grandparents.size() == 2 &&
                grandparents[1].get_partial_shape().is_static() &&
-               (ov::op::util::is_on_constant_path(grandparents[1].get_node_shared_ptr()))) {
+               (ov::op::util::is_on_constant_path(grandparents[1]))) {
            auto rank_a = grandparents[0].get_partial_shape().rank().get_length();
            auto rank_w = grandparents[1].get_partial_shape().rank().get_length();
            if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3)
@@ -0,0 +1,147 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"

using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;

namespace SubgraphTestsDefinitions {

/*
    ---------------
    |    Input    |
    ---------------
           |
    ---------------
    |VariadicSplit|
    ---------------
      |        |
  ---------    |
  |MatMul |    |
  ---------    |
      |        |
    ---------------
    |   Concat    |
    ---------------
           |
    ---------------
    |   Output    |
    ---------------
*/

using SplitMatMulConcatParams = std::tuple<
        std::vector<InputShape>,    // input shapes
        std::pair<bool, bool>       // transposeA, transposeB
>;

class SplitMatMulConcatTest : public testing::WithParamInterface<SplitMatMulConcatParams>,
                              virtual public SubgraphBaseTest, public CPUTestsBase {
public:
    static std::string getTestCaseName(testing::TestParamInfo<SplitMatMulConcatParams> obj) {
        std::vector<InputShape> inputShapes;
        std::pair<bool, bool> transpose;

        std::tie(inputShapes, transpose) = obj.param;

        std::ostringstream result;
        for (const auto& shape : inputShapes) {
            result << ov::test::utils::partialShape2str({shape.first}) << "_";
        }
        result << "TS=";
        for (const auto& shape : inputShapes) {
            result << "(";
            if (!shape.second.empty()) {
                auto itr = shape.second.begin();
                do {
                    result << ov::test::utils::vec2str(*itr);
                } while (++itr != shape.second.end() && result << "_");
            }
            result << ")_";
        }
        result << "transpose_a=" << transpose.first << "_";
        result << "transpose_b=" << transpose.second << "_";

        return result.str();
    }

protected:
    template<typename T>
    void transposeShape(T& shape) {
        IE_ASSERT(shape.size() > 1);
        std::swap(*(shape.end() - 1), *(shape.end() - 2));
    }

    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;

        std::vector<InputShape> inputShapes;
        std::pair<bool, bool> transpose;

        std::tie(inputShapes, transpose) = this->GetParam();

        init_input_shapes(inputShapes);

        bool transpA = transpose.first;
        bool transpB = transpose.second;

        if (transpA) {
            transposeShape(inputDynamicShapes[0]);
            for (auto& shapes : targetStaticShapes) {
                transposeShape(shapes[0]);
            }
        }
        if (transpB) {
            transposeShape(inputDynamicShapes[1]);
            for (auto& shapes : targetStaticShapes) {
                transposeShape(shapes[1]);
            }
        }

        const auto& inShapeA = inputDynamicShapes[0];
        const auto& inShapeB = inputDynamicShapes[1];

        auto params = builder::makeDynamicParams(ElementType::f32, {inShapeA});
        auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<opset1::Parameter>(params));
        std::shared_ptr<Node> inputB = builder::makeConstant<float>(ElementType::f32, inShapeB.get_shape(), {}, true);

        auto split = builder::makeVariadicSplit(paramOuts[0], {1, 1}, 0);

        auto matMul = builder::makeMatMul(split->output(0), inputB, transpA, transpB);

        auto concat = builder::makeConcat({matMul, split->output(1)}, 0);

        function = CPUTestsBase::makeNgraphFunction(ElementType::f32, params, concat, "FullyConnected");
    }
};

TEST_P(SplitMatMulConcatTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED();
    run();
}

namespace {

const std::vector<std::pair<bool, bool>> transposeParams = {
        {false, true},
};

const std::vector<std::vector<InputShape>> inputShapes2D = {
        static_shapes_to_test_representation({{2, 3}, {3, 3}}),
};

const auto testParams2D_FP32_smoke = ::testing::Combine(
        ::testing::ValuesIn(inputShapes2D),
        ::testing::ValuesIn(transposeParams));

INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP32, SplitMatMulConcatTest, testParams2D_FP32_smoke,
                         SplitMatMulConcatTest::getTestCaseName);

} // namespace

} // namespace SubgraphTestsDefinitions