initial
This commit is contained in:
parent
1ca78f643e
commit
06d6fbf0e8
@ -82,11 +82,17 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
|
||||
const bool has_mvn = ov::op::util::has_op_with_type<ov::opset8::MVN>(model) ||
|
||||
ov::op::util::has_op_with_type<ov::op::v0::MVN>(model);
|
||||
|
||||
if (has_convolution || has_maxpool || has_mvn) {
|
||||
if (ov::op::util::has_op_with_type<ov::opset8::MatMul>(model))
|
||||
std::cout << "[EMUTEX DEBUG] MatMul node" << std::endl;
|
||||
}
|
||||
|
||||
ov::pass::Manager manager;
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
// In OV API 2.0(IRv10) default conversion to fp32 (inputs, outputs and weights) is disabled
|
||||
// and we need to run the ConvertPrecision transformation to support old networks.
|
||||
manager.register_pass<ov::pass::ConvertPrecision>(precisions_map{{ngraph::element::f16, ngraph::element::f32}});
|
||||
EMUTEX_DEBUG_VISUALIZE("start");
|
||||
manager.register_pass<ov::pass::ConvertMVN1ToMVN6>();
|
||||
manager.register_pass<ov::intel_gna::pass::DecomposeMVN>();
|
||||
manager.register_pass<ov::pass::CommonOptimizations>();
|
||||
@ -131,13 +137,16 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
|
||||
manager.register_pass<ov::intel_gna::pass::InsertCopyBeforeLayerToBeEliminated>();
|
||||
// TODO enable this transformation for networks without convolutions
|
||||
if (has_convolution || has_maxpool || has_mvn) {
|
||||
EMUTEX_DEBUG_VISUALIZE("before");
|
||||
manager.register_pass<ov::intel_gna::pass::TransposeNCHW>();
|
||||
EMUTEX_DEBUG_VISUALIZE("after_TransposeNCHW");
|
||||
manager.register_pass<ov::intel_gna::pass::ReshapeTransposeSubstitute>();
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
manager.register_pass<ov::intel_gna::pass::GatherSinkingGeneral>();
|
||||
manager.register_pass<ov::pass::ReshapeSequenceFusion>();
|
||||
manager.register_pass<ov::pass::TransposeToReshape>();
|
||||
manager.register_pass<ov::intel_gna::pass::GnaConvolutionFusion>();
|
||||
EMUTEX_DEBUG_VISUALIZE("after");
|
||||
}
|
||||
manager.register_pass<ov::intel_gna::pass::RemoveInputsProcessing>(subgraph_cpu_map);
|
||||
manager.register_pass<ov::intel_gna::pass::RemoveOutputsProcessing>(subgraph_cpu_map);
|
||||
@ -205,6 +214,8 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
|
||||
// Operations Max and Min aren't supported
|
||||
pass_config->disable<ov::pass::ConcatReduceFusion>();
|
||||
|
||||
EMUTEX_DEBUG_VISUALIZE("finish");
|
||||
|
||||
manager.run_passes(model);
|
||||
|
||||
if (has_slice && (has_convolution || has_maxpool || has_mvn)) {
|
||||
@ -212,6 +223,7 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
|
||||
manager.register_pass<ov::pass::InitNodeInfo>();
|
||||
manager.register_pass<ov::pass::SliceToStridedSlice>(true);
|
||||
manager.register_pass<ngraph::pass::ConvertStridedSliceToCropMatcher>();
|
||||
EMUTEX_DEBUG_VISUALIZE("finish1");
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "transformations/gather_sinking_transpose_reshape.hpp"
|
||||
#include "transformations/gather_sinking_reshape.hpp"
|
||||
#include "transformations/gather_sinking_split.hpp"
|
||||
#include "transformations/gather_sinking_matmul.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::pass::pattern;
|
||||
@ -28,6 +29,7 @@ GatherSinkingGeneralForward::GatherSinkingGeneralForward() {
|
||||
add_matcher<GatherSinkingUnaryForward>();
|
||||
add_matcher<GatherSinkingBinaryForward>();
|
||||
add_matcher<GatherSinkingTransposeReshapeForward>();
|
||||
add_matcher<GatherSinkingMatmulForward>();
|
||||
add_matcher<GatherSinkingFuse>();
|
||||
}
|
||||
|
||||
@ -56,6 +58,5 @@ bool GatherSinkingGeneral::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
manager.run_passes(f);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -13,6 +13,8 @@
|
||||
#include "transformations/rt_info/gather_sinking_attr.hpp"
|
||||
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||
|
||||
#include "../debug_new_pass.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::pass::pattern;
|
||||
|
@ -0,0 +1,100 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/gather_sinking_matmul.hpp"
|
||||
|
||||
#include <openvino/cc/ngraph/itt.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/gather_sinking_attr.hpp"
|
||||
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::op::util;
|
||||
using namespace gather_sinking;
|
||||
using namespace ov::intel_gna::pass;
|
||||
using namespace ov::intel_gna::rt_info;
|
||||
|
||||
namespace {
|
||||
/*
 Builds the inverse of a set of gather indices: for every position i the result satisfies
 out[indexes[i]] == i, so a Gather with the returned indices placed right after (or before)
 a Gather with the initial indices on the same axis leaves the data order unchanged.

 Works only with the positive form of indices (no negative values). An index that is
 negative or not smaller than indexes.size() is rejected by std::vector::at, which throws
 std::out_of_range. The input is expected to mention every position exactly once; positions
 that are never mentioned keep the value 0 in the result.
 */
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
    std::vector<int64_t> out(indexes.size());
    for (size_t i = 0; i < indexes.size(); ++i) {
        // Explicit conversions: the index is used as an unsigned position and the loop
        // counter is stored as a signed gather index. Range checking is left to at().
        const auto target = static_cast<size_t>(indexes[i]);
        out.at(target) = static_cast<int64_t>(i);
    }
    return out;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/**
 * @brief Registers a matcher that moves a Gather from the first MatMul input to the MatMul output.
 *
 * Pattern: Gather(any, Constant indices, Constant axis) feeding MatMul input 0, any node
 * feeding MatMul input 1.
 *
 * On a match the callback
 *   1. inserts a new Gather with reversed indices (ReverseGatherIndexes) and axis 0 in front
 *      of MatMul input 1;
 *   2. reconnects MatMul input 0 directly to the data input of the matched Gather;
 *   3. inserts a Gather built from clones of the matched indices/axis constants after the
 *      MatMul output and redirects every former MatMul consumer to it.
 *
 * NOTE(review): the value of the matched gather axis and the MatMul transpose flags are not
 * inspected; the axis of the Gather added on input 1 is hard-coded to 0. Confirm that steps
 * 1 and 3 are both wanted for every gather axis before relying on this pass.
 */
GatherSinkingMatmulForward::GatherSinkingMatmulForward() {
    MATCHER_SCOPE(GatherSinkingMatmulForward);
    auto gather_indices_label = wrap_type<Constant>();
    auto gather_axis_label = wrap_type<Constant>();
    auto gather_label = wrap_type<Gather>({any_input(), gather_indices_label, gather_axis_label});
    auto matmul_label = wrap_type<MatMul>({gather_label, any_input()});

    ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
        const auto& pattern_to_output = m.get_pattern_value_map();
        auto gather_indices = as_type_ptr<Constant>(pattern_to_output.at(gather_indices_label).get_node_shared_ptr());
        auto gather_axis = as_type_ptr<Constant>(pattern_to_output.at(gather_axis_label).get_node_shared_ptr());
        auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
        auto matmul = as_type_ptr<MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());

        // TODO: only the "Gather on MatMul input 0" layout is handled; the input
        // indices 0 and 1 used below are hard-coded.
        // The matched Gather is the producer of MatMul input 0, so its data input is the
        // value the MatMul has to consume once the Gather is removed.
        const auto gather_parent = gather->input_value(0);

        // insert input gather
        // TODO: derive the axis from the matched gather axis instead of hard-coding it.
        const size_t gather_axis_value_new = 0;

        auto gather_axis_new1 = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value_new);
        auto gather_indices_values = ReverseGatherIndexes(gather_indices->cast_vector<int64_t>());
        auto gather_indices_new1 =
            std::make_shared<Constant>(element::i64, Shape{gather_indices_values.size()}, gather_indices_values);
        auto gather_new1 = std::make_shared<Gather>(matmul->input_value(1), gather_indices_new1, gather_axis_new1);

        matmul->input(1).replace_source_output(gather_new1->output(0));

        // remove input gather
        matmul->input(0).replace_source_output(gather_parent);

        // insert output gather
        // The consumers are collected before the new Gather is created so that the new
        // Gather itself is not among the inputs that get redirected.
        auto matmul_consumers = matmul->output(0).get_target_inputs();

        auto gather_axis_new2 = gather_axis->clone_with_new_inputs({});
        auto gather_indices_new2 = gather_indices->clone_with_new_inputs({});
        auto gather_new2 = std::make_shared<Gather>(matmul->output(0), gather_indices_new2, gather_axis_new2);

        for (auto& consumer : matmul_consumers) {
            consumer.replace_source_output(gather_new2);
        }

        SwapFriendlyNames(gather_new2, matmul);

        copy_runtime_info(
            gather,
            {gather_new1, gather_indices_new1, gather_axis_new1, gather_new2, gather_indices_new2, gather_axis_new2});

        register_new_node(gather_new1);
        register_new_node(gather_new2);

        gather_sinking::UpdateForwardGatherSinkingAbility(gather_new2);

        return true;
    };

    auto m = std::make_shared<Matcher>(matmul_label, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
|
@ -0,0 +1,22 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_gna {
|
||||
namespace pass {
|
||||
|
||||
/**
 * @brief Matcher pass for a MatMul whose first input is produced by a Gather with constant
 * indices and axis.
 *
 * The callback registered by the constructor relocates that Gather to the MatMul output and
 * adds a Gather with reversed indices on the second MatMul input.
 */
class GatherSinkingMatmulForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingMatmulForward", "0");
|
||||
// Builds the Gather -> MatMul pattern and registers the matcher callback.
GatherSinkingMatmulForward();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace intel_gna
|
||||
} // namespace ov
|
@ -19,8 +19,6 @@
|
||||
#include "transformations/rt_info/gather_sinking_attr.hpp"
|
||||
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||
|
||||
#include "../debug_new_pass.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
@ -119,37 +117,6 @@ std::vector<int64_t> NormalizeGatherIndices(const std::vector<int64_t>& indices)
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
#if 0
|
||||
// Result record filled by FindFirstOutputTranspose. Every member is value-initialized
// (null pointers / 0), which is also the state returned when nothing was found.
struct OutputTranspose {
|
||||
OutputTranspose() = default;
|
||||
// Transpose node found among the consumers of the inspected node (non-owning).
Transpose* transpose = {};
|
||||
// Constant node connected to input 1 of that Transpose (non-owning).
Constant* const_node = {};
|
||||
// Index of the inspected node's output that the Transpose consumes.
int output_idx = {};
|
||||
};
|
||||
|
||||
// Walks the outputs of `node` in index order and, for each output, its consumers.
// Returns a record for the first consumer that is a Transpose whose input 1 is a Constant;
// returns a default-constructed OutputTranspose when no such consumer exists.
OutputTranspose FindFirstOutputTranspose(NodePtr node) {
    const size_t outputs_count = node->get_output_size();
    for (size_t idx = 0; idx != outputs_count; ++idx) {
        for (auto& consumer_input : node->get_output_target_inputs(idx)) {
            auto* transpose = dynamic_cast<Transpose*>(consumer_input.get_node());
            if (transpose == nullptr) {
                continue;
            }

            auto* const_input = dynamic_cast<Constant*>(transpose->input_value(1).get_node());
            if (const_input != nullptr) {
                OutputTranspose found;
                found.transpose = transpose;
                found.const_node = const_input;
                found.output_idx = idx;
                return found;
            }
        }
    }

    return OutputTranspose{};
}
|
||||
#endif
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -231,66 +198,3 @@ GatherSinkingSplitBackward::GatherSinkingSplitBackward() {
|
||||
auto m = std::make_shared<Matcher>(gather_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
#if 0
|
||||
// Constructor of a matcher pass that is currently compiled out (it sits between `#if 0`
// and `#endif`).
//
// Pattern: a Gather with any first input and a Constant second input that satisfies the
// IsSplitSinked predicate. The callback looks up the Split feeding the matched node, reads
// its constant axis, builds a new Gather in front of the Split data input and then replaces
// the single-output consumers of the selected Split output with that Split output.
//
// NOTE(review): the body uses `gather_indices`, `gather_axis`, `gather_label` and
// `output_transpose.gather_axis`, none of which are declared in the visible code, so the
// block presumably does not compile as written - confirm before enabling it.
GatherSinkingSplitTransposeBackward::GatherSinkingSplitTransposeBackward() {
|
||||
MATCHER_SCOPE(GatherSinkingSplitTransposeBackward);
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Gather>({any_input(), transpose_const_label}, IsSplitSinked);
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose = as_type_ptr<Gather>(pattern_to_output.at(transpose_label).get_node_shared_ptr());
|
||||
|
||||
auto split = FindInputNode<Split>(transpose);
|
||||
auto split_axis_constant = as_type_ptr<Constant>(split->input_value(1).get_node_shared_ptr());
|
||||
if (!split_axis_constant) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64_t split_axis;
|
||||
if (!GetSplitAxis(split_axis_constant, split->input_value(0).get_partial_shape().rank(), split_axis)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
OutputTranspose output_transpose = FindFirstOutputTranspose(split);
|
||||
|
||||
// TODO
|
||||
//
|
||||
intel_gna_debug::Print("gather_indices", gather_indices);
|
||||
|
||||
//
|
||||
std::vector<int64_t> new_indices(split->get_input_shape(0)[gather_axis]);
|
||||
std::iota(new_indices.begin(), new_indices.end(), 0);
|
||||
|
||||
const size_t base = output_transpose.output_idx * split->get_output_shape(0)[split_axis];
|
||||
for (size_t i = 0; i < gather_indices.size(); ++i) {
|
||||
new_indices[base + i] = base + gather_indices[i];
|
||||
}
|
||||
|
||||
intel_gna_debug::Print("new_indices", new_indices);
|
||||
|
||||
auto split_input = split->input_value(0);
|
||||
auto new_indices_const = std::make_shared<Constant>(output_transpose.gather_axis->get_element_type(),
|
||||
Shape{new_indices.size()},
|
||||
new_indices);
|
||||
auto new_axis_const = output_transpose.gather_axis->clone_with_new_inputs({});
|
||||
auto new_gather = std::make_shared<Gather>(split_input, new_indices_const, new_axis_const);
|
||||
split->input(0).replace_source_output(new_gather->output(0));
|
||||
copy_runtime_info(split_input.get_node_shared_ptr(), {new_gather, new_indices_const, new_axis_const});
|
||||
register_new_node(new_gather);
|
||||
|
||||
for (auto& input : split->get_output_target_inputs(output_transpose.output_idx)) {
|
||||
Node* consumer = input.get_node();
|
||||
if (consumer->get_output_size() != 1)
|
||||
continue;
|
||||
consumer->output(0).replace(split->output(output_transpose.output_idx));
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(gather_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
#endif
|
||||
|
@ -325,6 +325,7 @@ bool CanPropagateGatherForwardThrough(Node* node) {
|
||||
CHECK_GATHER_SINKING_SUPPORTED(ov::op::util::BinaryElementwiseArithmetic, node);
|
||||
CHECK_GATHER_SINKING_SUPPORTED(Gather, node);
|
||||
CHECK_GATHER_SINKING_SUPPORTED(Reshape, node);
|
||||
CHECK_GATHER_SINKING_SUPPORTED(MatMul, node);
|
||||
return false;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user