Files
openvino/src/plugins/intel_gna/transformations/remove_extra_reshapes.cpp
Nadezhda Ageeva 56581dbe2e [GNA] Support new Kaldi IRs (#9474)
* Support new Kaldi IRs (generated in NHWC layout)

* Update tests with activation and fq

* Cleanup

* Fix reordering FQ and MaxPool and problem with overflow

* Fix win

* Update src/plugins/intel_gna/transformations/unfuse_reshape_and_transpose.hpp

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>

* Update src/plugins/intel_gna/transformations/unfuse_reshape_and_transpose.cpp

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>

* Update inference-engine/tests/unit/gna/ngraph/transformations/gna_unfuse_reshape_and_transpose.cpp

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>

* Code review

Co-authored-by: Elizaveta Lobanova <elizaveta.lobanova@intel.com>
2022-01-17 14:16:23 +03:00

35 lines
1.3 KiB
C++

// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/cc/ngraph/itt.hpp>
#include "transformations/remove_extra_reshapes.hpp"
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
using namespace GNAPluginNS;
NGRAPH_RTTI_DEFINITION(RemoveExtraReshapes, "RemoveExtraReshapes", 0);

/**
 * @brief Registers a matcher that removes a shape-preserving Reshape feeding a MaxPool.
 *
 * Pattern: Reshape (static input shape == static output shape) -> MaxPool
 * Action:  consumers of the Reshape output are reconnected to the Reshape's input
 *          via ngraph::replace_output_update_name().
 */
RemoveExtraReshapes::RemoveExtraReshapes() {
    MATCHER_SCOPE(RemoveExtraReshapes);

    // Match only Reshapes that do not change the shape. Partial shapes are inspected so
    // that a Reshape with dynamic dimensions is simply not matched, instead of
    // get_input_shape()/get_output_shape() throwing during pattern matching.
    const auto reshape = ngraph::pattern::wrap_type<ngraph::opset7::Reshape>(
        [](const ngraph::Output<ngraph::Node>& value) {
            const auto* node = value.get_node();
            const auto& input_shape = node->get_input_partial_shape(0);
            const auto& output_shape = node->get_output_partial_shape(0);
            return input_shape.is_static() && output_shape.is_static() &&
                   input_shape.to_shape() == output_shape.to_shape();
        });
    const auto pooling = ngraph::pattern::wrap_type<ngraph::opset7::MaxPool>({reshape});

    ngraph::matcher_pass_callback callback = [reshape](ngraph::pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        const auto reshape_node = pattern_map.at(reshape).get_node_shared_ptr();
        // replace_output_update_name() returns false when it declines to perform the
        // replacement; report a graph modification only when one actually happened.
        return ngraph::replace_output_update_name(reshape_node->output(0), reshape_node->input_value(0));
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(pooling, matcher_name);
    this->register_matcher(m, callback);
}