This commit is contained in:
Evgeny Kotov 2023-03-23 15:28:00 +01:00
parent 1ca78f643e
commit 06d6fbf0e8
7 changed files with 139 additions and 97 deletions

View File

@ -82,11 +82,17 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
const bool has_mvn = ov::op::util::has_op_with_type<ov::opset8::MVN>(model) ||
ov::op::util::has_op_with_type<ov::op::v0::MVN>(model);
if (has_convolution || has_maxpool || has_mvn) {
if (ov::op::util::has_op_with_type<ov::opset8::MatMul>(model))
std::cout << "[EMUTEX DEBUG] MatMul node" << std::endl;
}
ov::pass::Manager manager;
manager.register_pass<ov::pass::InitNodeInfo>();
// In OV API 2.0 (IRv10) the default conversion to fp32 (inputs, outputs and weights) is disabled,
// and we need to run the ConvertPrecision transformation to support old networks.
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::f16, ngraph::element::f32}});
EMUTEX_DEBUG_VISUALIZE("start");
manager.register_pass<ov::pass::ConvertMVN1ToMVN6>();
manager.register_pass<ov::intel_gna::pass::DecomposeMVN>();
manager.register_pass<ov::pass::CommonOptimizations>();
@ -131,13 +137,16 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
manager.register_pass<ov::intel_gna::pass::InsertCopyBeforeLayerToBeEliminated>();
// TODO enable this transformation for networks without convolutions
if (has_convolution || has_maxpool || has_mvn) {
EMUTEX_DEBUG_VISUALIZE("before");
manager.register_pass<ov::intel_gna::pass::TransposeNCHW>();
EMUTEX_DEBUG_VISUALIZE("after_TransposeNCHW");
manager.register_pass<ov::intel_gna::pass::ReshapeTransposeSubstitute>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingGeneral>();
manager.register_pass<ov::pass::ReshapeSequenceFusion>();
manager.register_pass<ov::pass::TransposeToReshape>();
manager.register_pass<ov::intel_gna::pass::GnaConvolutionFusion>();
EMUTEX_DEBUG_VISUALIZE("after");
}
manager.register_pass<ov::intel_gna::pass::RemoveInputsProcessing>(subgraph_cpu_map);
manager.register_pass<ov::intel_gna::pass::RemoveOutputsProcessing>(subgraph_cpu_map);
@ -205,6 +214,8 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
// Operations Max and Min aren't supported
pass_config->disable<ov::pass::ConcatReduceFusion>();
EMUTEX_DEBUG_VISUALIZE("finish");
manager.run_passes(model);
if (has_slice && (has_convolution || has_maxpool || has_mvn)) {
@ -212,6 +223,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
manager.register_pass<ov::pass::InitNodeInfo>();
manager.register_pass<ov::pass::SliceToStridedSlice>(true);
manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
EMUTEX_DEBUG_VISUALIZE("finish1");
manager.run_passes(model);
}

View File

@ -17,6 +17,7 @@
#include "transformations/gather_sinking_transpose_reshape.hpp"
#include "transformations/gather_sinking_reshape.hpp"
#include "transformations/gather_sinking_split.hpp"
#include "transformations/gather_sinking_matmul.hpp"
using namespace ov;
using namespace ov::pass::pattern;
@ -28,6 +29,7 @@ GatherSinkingGeneralForward::GatherSinkingGeneralForward() {
add_matcher<GatherSinkingUnaryForward>();
add_matcher<GatherSinkingBinaryForward>();
add_matcher<GatherSinkingTransposeReshapeForward>();
add_matcher<GatherSinkingMatmulForward>();
add_matcher<GatherSinkingFuse>();
}
@ -56,6 +58,5 @@ bool GatherSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.run_passes(f);
}
return false;
}

View File

@ -13,6 +13,8 @@
#include "transformations/rt_info/gather_sinking_attr.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
#include "../debug_new_pass.hpp"
using namespace ov;
using namespace ov::opset9;
using namespace ov::pass::pattern;

View File

@ -0,0 +1,100 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/gather_sinking_matmul.hpp"
#include <openvino/cc/ngraph/itt.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/gather_sinking_attr.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace gather_sinking;
using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::rt_info;
namespace {
/*
Builds the inverse permutation of the given gather indices: applying the
original gather and then the reversed one (in either order) is a no-op.
Works only with the positive index form (no negative indices); every value
in `indexes` must be a unique position in [0, indexes.size()).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
    std::vector<int64_t> out(indexes.size());
    for (size_t i = 0; i < indexes.size(); i++) {
        // at() guards against malformed (out-of-range) index values;
        // explicit cast avoids a size_t -> int64_t conversion warning.
        out.at(indexes[i]) = static_cast<int64_t>(i);
    }
    return out;
}
} // namespace
// Matcher pass that sinks a Gather located on MatMul input 0 forward through
// the MatMul: a compensating Gather with reversed indices is inserted on
// input 1, the original Gather is bypassed, and an equivalent Gather is
// re-applied on the MatMul output so overall semantics are preserved.
GatherSinkingMatmulForward::GatherSinkingMatmulForward() {
    MATCHER_SCOPE(GatherSinkingMatmulForward);

    auto gather_indices_label = wrap_type<Constant>();
    auto gather_axis_label = wrap_type<Constant>();
    auto gather_label = wrap_type<Gather>({any_input(), gather_indices_label, gather_axis_label});
    auto matmul_label = wrap_type<MatMul>({gather_label, any_input()});

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto gather_indices = as_type_ptr<Constant>(pattern_to_output.at(gather_indices_label).get_node_shared_ptr());
        auto gather_axis = as_type_ptr<Constant>(pattern_to_output.at(gather_axis_label).get_node_shared_ptr());
        auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        auto matmul = as_type_ptr<MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());

        // Data producer feeding the matched Gather on MatMul input 0; the
        // Gather is bypassed on that edge below.
        // TODO: also support a Gather arriving on MatMul input 1.
        auto gather_parent = matmul->input_value(0).get_node()->input_value(0);

        // Insert a compensating Gather on the other MatMul input; the
        // reversed indices cancel the removed Gather's effect on the product.
        const size_t gather_axis_value_new = 0;  // TODO: derive from the original gather axis
        auto gather_axis_new1 = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value_new);
        auto gather_indices_values = ReverseGatherIndexes(gather_indices->cast_vector<int64_t>());
        auto gather_indices_new1 =
            std::make_shared<Constant>(element::i64, Shape{gather_indices_values.size()}, gather_indices_values);
        auto gather_new1 = std::make_shared<Gather>(matmul->input_value(1), gather_indices_new1, gather_axis_new1);
        matmul->input(1).replace_source_output(gather_new1->output(0));

        // Remove the original input Gather by re-wiring input 0 to its parent.
        matmul->input(0).replace_source_output(gather_parent);

        // Re-apply the original Gather on the MatMul output.
        auto matmul_consumers = matmul->output(0).get_target_inputs();
        auto gather_axis_new2 = gather_axis->clone_with_new_inputs({});
        auto gather_indices_new2 = gather_indices->clone_with_new_inputs({});
        auto gather_new2 = std::make_shared<Gather>(matmul->output(0), gather_indices_new2, gather_axis_new2);
        for (auto& consumer : matmul_consumers) {
            consumer.replace_source_output(gather_new2);
        }
        SwapFriendlyNames(gather_new2, matmul);

        copy_runtime_info(
            gather,
            {gather_new1, gather_indices_new1, gather_axis_new1, gather_new2, gather_indices_new2, gather_axis_new2});
        register_new_node(gather_new1);
        register_new_node(gather_new2);
        gather_sinking::UpdateForwardGatherSinkingAbility(gather_new2);
        return true;
    };

    auto m = std::make_shared<Matcher>(matmul_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,22 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace intel_gna {
namespace pass {
/**
 * @brief Matcher pass that sinks a Gather placed on a MatMul input forward
 * through the MatMul (see gather_sinking_matmul.cpp for the transformation);
 * registered as part of the general forward gather-sinking pipeline.
 */
class GatherSinkingMatmulForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GatherSinkingMatmulForward", "0");
GatherSinkingMatmulForward();
};
} // namespace pass
} // namespace intel_gna
} // namespace ov

View File

@ -19,8 +19,6 @@
#include "transformations/rt_info/gather_sinking_attr.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
#include "../debug_new_pass.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
@ -119,37 +117,6 @@ std::vector<int64_t> NormalizeGatherIndices(const std::vector<int64_t>& indices)
}
return normalized;
}
#if 0
// NOTE(review): dead code — excluded from the build by #if 0.
// Holds the first Transpose found among a node's output consumers, together
// with its permutation Constant and the index of the producing output.
struct OutputTranspose {
OutputTranspose() = default;
Transpose* transpose = {};
Constant* const_node = {};
int output_idx = {};
};
// Scans every output of `node` for a consumer Transpose whose permutation
// input is a Constant; returns the first match found, or a default-initialized
// (all-null) OutputTranspose when there is none.
OutputTranspose FindFirstOutputTranspose(NodePtr node) {
for (size_t output_idx = 0; output_idx < node->get_output_size(); ++output_idx) {
for (auto& input : node->get_output_target_inputs(output_idx)) {
auto transpose_node = dynamic_cast<Transpose*>(input.get_node());
if (!transpose_node)
continue;
auto const_node = dynamic_cast<Constant*>(transpose_node->input_value(1).get_node());
if (!const_node)
continue;
{
OutputTranspose output_transpose;
output_transpose.transpose = transpose_node;
output_transpose.const_node = const_node;
output_transpose.output_idx = output_idx;
return output_transpose;
}
}
}
return {};
}
#endif
} // namespace
@ -231,66 +198,3 @@ GatherSinkingSplitBackward::GatherSinkingSplitBackward() {
auto m = std::make_shared<Matcher>(gather_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
#if 0
// NOTE(review): dead code — excluded from the build by #if 0. As written it
// would not compile if enabled: `gather_indices`, `gather_axis`,
// `output_transpose.gather_axis` and `gather_label` are referenced but never
// defined in this scope (OutputTranspose has no `gather_axis` member).
// Apparent intent (unverified): sink a Gather backward through a Split by
// building an equivalent Gather on the Split input.
GatherSinkingSplitTransposeBackward::GatherSinkingSplitTransposeBackward() {
MATCHER_SCOPE(GatherSinkingSplitTransposeBackward);
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Gather>({any_input(), transpose_const_label}, IsSplitSinked);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = as_type_ptr<Gather>(pattern_to_output.at(transpose_label).get_node_shared_ptr());
auto split = FindInputNode<Split>(transpose);
// Split axis must be a Constant with a resolvable axis value.
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
if (!split_axis_constant) {
return false;
}
int64_t split_axis;
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
return false;
}
OutputTranspose output_transpose = FindFirstOutputTranspose(split);
// TODO
//
intel_gna_debug::Print("gather_indices", gather_indices);
//
// Start from an identity permutation over the Split input, then remap the
// slice that belongs to the affected output with the gather indices.
std::vector<int64_t> new_indices(split->get_input_shape(0)[gather_axis]);
std::iota(new_indices.begin(), new_indices.end(), 0);
const size_t base = output_transpose.output_idx * split->get_output_shape(0)[split_axis];
for (size_t i = 0; i < gather_indices.size(); ++i) {
new_indices[base + i] = base + gather_indices[i];
}
intel_gna_debug::Print("new_indices", new_indices);
auto split_input = split->input_value(0);
auto new_indices_const = std::make_shared<Constant>(output_transpose.gather_axis->get_element_type(),
Shape{new_indices.size()},
new_indices);
auto new_axis_const = output_transpose.gather_axis->clone_with_new_inputs({});
auto new_gather = std::make_shared<Gather>(split_input, new_indices_const, new_axis_const);
split->input(0).replace_source_output(new_gather->output(0));
copy_runtime_info(split_input.get_node_shared_ptr(), {new_gather, new_indices_const, new_axis_const});
register_new_node(new_gather);
// Bypass single-output consumers of the affected Split output.
for (auto& input : split->get_output_target_inputs(output_transpose.output_idx)) {
Node* consumer = input.get_node();
if (consumer->get_output_size() != 1)
continue;
consumer->output(0).replace(split->output(output_transpose.output_idx));
}
return true;
};
auto m = std::make_shared<Matcher>(gather_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
#endif

View File

@ -325,6 +325,7 @@ bool CanPropagateGatherForwardThrough(Node* node) {
CHECK_GATHER_SINKING_SUPPORTED(ov::op::util::BinaryElementwiseArithmetic, node);
CHECK_GATHER_SINKING_SUPPORTED(Gather, node);
CHECK_GATHER_SINKING_SUPPORTED(Reshape, node);
CHECK_GATHER_SINKING_SUPPORTED(MatMul, node);
return false;
}