[Offline Transformations] EfficientNet_b0 Pruning Transformation Enabling (#8926)
* Add include guard to file_utils.cpp
* Rebase src
* Rename acc tests, fix rebase
* Revert debug changes
* Fix linter
* Move ac tests to new template
* Test updated
* Fix result operations sharing an output tensor with the previous op
* Pruning test visualization option
* Add ac support for all test cases with pruned kernels
* Remove redundant files
* Enable extended pruning logging by env variable
* Adjust pruning tests
* Remove pruning extended debug env var
* Enable init masks only in debug mode
* Set result mask on the input tensor instead of the output tensor via a separate key
* Bug fix / Test coverage
* Fix comments
parent 4734b17e52
commit 4693f7b854
@@ -149,19 +149,24 @@ public:
        m_dependencies.push_back(mask.get());
    }

    /* Modify the state of this mask by the corresponding callback,
       which returns the modification success status (bool), and then
       modify all dependent masks by their corresponding callbacks */
    bool apply_callback(Mask::Ptr mask) {
        // TODO: in case the callback returns false we need to propagate the original value
        const auto & ref_state = Mask(*this);
        // Modify this mask by the received mask
        if (!m_callbacks.at(mask.get())(shared_from_this())) {
            return false;
        }

        // In case this mask was already visited and wasn't changed by the
        // callback call - stop the recursion
        if (!m_need_initialization && *this == ref_state) {
            return true;
        }

        // Mark the mask as visited
        m_need_initialization = false;

        // Recursively apply callbacks for each dependent mask
        for (const auto & m_dependency : m_dependencies) {
            if (!m_dependency->apply_callback(shared_from_this())) {
                return false;
            }
@@ -185,13 +190,21 @@ public:
        }
    }

    /* Ask the mask to update its dependencies
       even if the mask value wasn't changed by a callback */
    void initialize_dependencies() {
        m_need_initialization = true;
    }

private:
    bool m_is_shape_like{false};

    // Masks dependent on this mask vs methods specifying how
    // this mask will be modified by the corresponding dependent mask
    std::map<Mask *, std::function<bool(Mask::Ptr)>> m_callbacks;

    // Vector of all dependent masks
    std::vector<Mask *> m_dependencies;

    // Flag used as a visiting label (visited or not) during a mask apply call
    bool m_need_initialization{true};
};
@@ -203,4 +216,14 @@ Mask::Ptr getMask(const Output<Node> & output);

void setMask(Output<Node> output, const Mask::Ptr & mask);

void setMask(Input<Node> node, const Mask::Ptr & mask);

#ifdef ENABLE_OPENVINO_DEBUG
/* Get the mask which was defined in the InitMasks matcher pass */
Mask::Ptr getInitMask(const Output<Node> & output);

/* Set the mask which was defined in the InitMasks matcher pass */
void setInitMask(Output<Node> output, const Mask::Ptr & mask);
#endif

}  // namespace ngraph
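Note: apply_callback() above drives mask propagation; each mask stores one callback per linked mask and recurses into its dependents until nothing changes. Below is a minimal self-contained sketch of the same mechanism using a simplified stand-in type (ToyMask is illustrative, not the real ngraph::Mask, and is not part of the committed diff):

    #include <functional>
    #include <iostream>
    #include <map>
    #include <memory>
    #include <set>
    #include <vector>

    // Simplified stand-in for ngraph::Mask: one set of pruned channel ids plus
    // the callback/dependency bookkeeping shown in the diff above.
    struct ToyMask : std::enable_shared_from_this<ToyMask> {
        using Ptr = std::shared_ptr<ToyMask>;

        std::set<int> pruned;

        void add_callback(std::function<bool(Ptr)> cb, const Ptr & mask) {
            m_callbacks[mask.get()] = std::move(cb);
            m_dependencies.push_back(mask.get());
        }

        bool apply_callback(const Ptr & mask) {
            const auto ref_state = pruned;                   // snapshot to detect changes
            if (!m_callbacks.at(mask.get())(shared_from_this()))
                return false;
            if (!m_need_initialization && pruned == ref_state)
                return true;                                 // visited and unchanged: stop recursion
            m_need_initialization = false;                   // mark as visited
            for (auto * dep : m_dependencies)                // recurse into dependent masks
                if (!dep->apply_callback(shared_from_this()))
                    return false;
            return true;
        }

    private:
        std::map<ToyMask *, std::function<bool(Ptr)>> m_callbacks;
        std::vector<ToyMask *> m_dependencies;
        bool m_need_initialization{true};
    };

    int main() {
        auto a = std::make_shared<ToyMask>();
        auto b = std::make_shared<ToyMask>();
        auto a_raw = a.get();
        // b mirrors a, the way PassThrough-style propagation wires masks together.
        b->add_callback([a_raw](ToyMask::Ptr cur) { cur->pruned = a_raw->pruned; return true; }, a);
        a->add_callback([](ToyMask::Ptr) { return true; }, b);

        a->pruned = {1, 3};
        b->apply_callback(a);                                // pull a's state into b
        std::cout << b->pruned.size() << std::endl;          // prints 2
    }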
@@ -56,6 +56,9 @@ ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
            }

            setMask(const_node, mask);
#ifdef ENABLE_OPENVINO_DEBUG
            setInitMask(const_node, mask);
#endif
            if (!mask->all_dims_are_empty()) {
                NGRAPH_DEBUG << "MASK (" << const_node->get_friendly_name() << ") " << *mask << std::endl;
            }
@@ -14,16 +14,21 @@ namespace ngraph {

Mask::Ptr getMask(const Output<const Node> & output) {
    auto &rtInfo = output.get_rt_info();
-    if (!rtInfo.count(Mask::get_type_info_static())) return nullptr;
-
-    const auto &attr = rtInfo.at(Mask::get_type_info_static());
+    const auto attr_it = rtInfo.find(Mask::get_type_info_static());
+    if (attr_it == rtInfo.end()) return nullptr;
+
+    const auto &attr = attr_it->second;
    return attr.as<Mask::Ptr>();
}

Mask::Ptr getMask(const Output<Node> & output) {
    auto &rtInfo = output.get_rt_info();
-    if (!rtInfo.count(Mask::get_type_info_static())) return nullptr;
-    const auto &attr = rtInfo.at(Mask::get_type_info_static());
+
+    const auto attr_it = rtInfo.find(Mask::get_type_info_static());
+    if (attr_it == rtInfo.end()) return nullptr;
+
+    const auto &attr = attr_it->second;
    return attr.as<Mask::Ptr>();
}
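Note: the change above replaces a count() check followed by at() with a single find(), so the rt_info map is searched once instead of twice. The same idiom on a plain std::map, for illustration only (not part of the committed diff):

    #include <map>
    #include <memory>
    #include <string>

    struct Mask;  // stand-in; only the lookup pattern matters here

    std::shared_ptr<Mask> lookup(const std::map<std::string, std::shared_ptr<Mask>> & rt_info,
                                 const std::string & key) {
        // Before: two lookups into the map.
        //   if (!rt_info.count(key)) return nullptr;
        //   return rt_info.at(key);

        // After: a single lookup via find().
        const auto it = rt_info.find(key);
        if (it == rt_info.end()) return nullptr;
        return it->second;
    }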
@@ -32,6 +37,31 @@ void setMask(Output<Node> output, const Mask::Ptr & mask) {
    rtInfo[Mask::get_type_info_static()] = mask;
}

void setMask(Input<Node> node, const Mask::Ptr & mask) {
    auto &rtInfo = node.get_rt_info();
    rtInfo[Mask::get_type_info_static()] = mask;
}

#ifdef ENABLE_OPENVINO_DEBUG
static const char g_init_mask_key[] = "InitMask";

Mask::Ptr getInitMask(const Output<Node> & output) {
    auto &rtInfo = output.get_rt_info();

    const auto attr_it = rtInfo.find(g_init_mask_key);
    if (attr_it == rtInfo.end()) return nullptr;

    const auto &attr = attr_it->second;
    return attr.as<Mask::Ptr>();
}

void setInitMask(Output<Node> output, const Mask::Ptr & mask) {
    auto &rtInfo = output.get_rt_info();
    auto copy_mask = std::make_shared<Mask>();
    std::copy(mask->begin(), mask->end(), std::back_inserter(*copy_mask));
    rtInfo[g_init_mask_key] = copy_mask;
}
#endif

std::ostream & operator<< (std::ostream & out, const Mask & mask) {
    out << "[ ";
    for (auto & dim : mask) {
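Note: setInitMask() above stores a copy of the mask rather than the shared pointer itself. Masks are mutated in place during propagation, so storing the pointer would let later passes silently rewrite the recorded initial state. A small illustration of the difference (not part of the committed diff):

    #include <iostream>
    #include <memory>
    #include <vector>

    int main() {
        using Mask = std::vector<int>;                      // stand-in for the real mask type
        auto mask = std::make_shared<Mask>(Mask{1, 2, 3});

        auto alias = mask;                                  // shares state with `mask`
        auto snapshot = std::make_shared<Mask>(*mask);      // independent copy, as setInitMask does

        mask->push_back(4);                                 // later in-place mutation

        std::cout << alias->size() << std::endl;            // 4: the alias sees the change
        std::cout << snapshot->size() << std::endl;         // 3: the snapshot is preserved
    }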
@@ -5,6 +5,8 @@
#include "pruning.hpp"
#include "mask_attribute.hpp"

#include <algorithm>

#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset5.hpp>
@@ -22,6 +24,7 @@ class GroupConvolution;
class GroupConvolutionReshape;
class Elementwise;
class PassThrough;
class Reduce;
class StopPropagation;
class FakeQuantize;
class Concat;
@@ -228,7 +231,11 @@ public:
                return false;
            }
            auto input_mask_row = input_mask.get();
-            auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
+            // Check whether the reshape mask was already initialized during the StopPropagation pass
+            auto output_mask = getMask(m_output);
+            if (!output_mask)
+                output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
+
            auto output_mask_row = output_mask.get();

            // Depthwise Convolution is pruned only by input channels (== groups) ->
@@ -301,7 +308,6 @@ public:
            }

-            InitConstMask({0, 1}).apply(m_weights.get_node_shared_ptr());

            auto weights_mask = getMask(m_weights);
            if (!weights_mask) {
                NGRAPH_DEBUG << "No weights mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
@@ -561,9 +567,12 @@ public:
class ngraph::pass::mask_propagation::PassThrough : public MatcherPass {
public:
    PassThrough() {
-        auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp,
-                                           opset6::Convert, opset6::ConvertLike, opset6::AvgPool, opset6::MaxPool,
-                                           opset6::ROIPooling, opset6::PSROIPooling, opset6::Pad>();
+        auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp, opset6::Swish,
+                                           opset6::Elu, opset6::HardSigmoid, opset6::PRelu, opset6::Mish,
+                                           opset6::Softmax, opset6::SoftPlus, opset6::Convert, opset6::ConvertLike,
+                                           opset6::AvgPool, opset6::MaxPool, opset6::ROIPooling, opset6::PSROIPooling,
+                                           opset6::Pad>();

        ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
            const auto & pattern_map = m.get_pattern_value_map();
@@ -582,19 +591,92 @@ public:
        }
    };

class ngraph::pass::mask_propagation::Reduce : public MatcherPass {
public:
    Reduce() {
        auto inputs = pattern::any_input();
        auto weights = pattern::wrap_type<opset6::Constant>();
        auto pooling_by_reduce = pattern::wrap_type<opset6::ReduceMin, opset6::ReduceMax, opset6::ReduceMean>({inputs, weights});

        ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
            const auto & pattern_map = m.get_pattern_value_map();
            const auto m_weights = pattern_map.at(weights);
            const auto & m_input = pattern_map.at(inputs);
            const auto & m_output = pattern_map.at(pooling_by_reduce);

            // Check that the reduce operation reduces only dimensions without masks
            if (auto input_mask = getMask(m_input)) {
                auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
                const auto constant = std::dynamic_pointer_cast<opset6::Constant>(m_weights.get_node_shared_ptr());
                const auto reduce_dims = constant->cast_vector<int64_t>();

                auto input_mask_row = input_mask.get();
                auto output_mask_row = output_mask.get();
                input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
                    cur_mask->copy_value_from_mask(output_mask_row);
                    return true;
                }, output_mask);
                output_mask->add_callback([input_mask_row, reduce_dims](Mask::Ptr cur_mask) -> bool {
                    // Propagate masks through a dimension only if this dimension isn't reduced
                    for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim)
                        if (std::find(reduce_dims.begin(), reduce_dims.end(), dim) == reduce_dims.end())
                            cur_mask->at(dim) = input_mask_row->at(dim);
                        else if (cur_mask->at(dim) != input_mask_row->at(dim))
                            cur_mask->initialize_dependencies();
                    return true;
                }, input_mask);

                // Invalidate the current mask and its parent masks
                output_mask->apply_callback(input_mask);
                setMask(m_output, output_mask);
            }

            return true;
        };

        auto m = std::make_shared<ngraph::pattern::Matcher>(pooling_by_reduce, "PassThroughReduceMaskPropagation");
        register_matcher(m, callback);
    }
};

class ngraph::pass::mask_propagation::StopPropagation : public MatcherPass {
public:
    StopPropagation() {
        auto any_node = pattern::any_input();

-        ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
+        ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
            const auto & pattern_map = m.get_pattern_value_map();
            const auto & m_output = pattern_map.at(any_node);
            const auto & node = m.get_match_root();

+            auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
+            bool any_input_with_masks = false;
            for (const auto & input : node->input_values()) {
-                if (auto mask = getMask(input)) {
-                    // Invalidate the current mask and its parent masks
-                    mask->invalidate();
-                    NGRAPH_DEBUG << "Invalidate masks for " << *input.get_node() << " because " << node << " is unknown\n";
+                if (auto input_mask = getMask(input)) {
+                    auto input_mask_row = input_mask.get();
+                    input_mask->add_callback([](Mask::Ptr cur_mask) -> bool {
+                        cur_mask->clean_dim_values();
+                        return true;
+                    }, output_mask);
+                    output_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool {
+                        cur_mask->copy_value_from_mask(input_mask_row);
+                        return true;
+                    }, input_mask);
+
+                    // Invalidate the current mask and its parent masks
+                    output_mask->apply_callback(input_mask);
+                    NGRAPH_DEBUG << "Invalidate masks for " << *input.get_node() << " because " << node << " is in scope of stop ops.\n";
+                    any_input_with_masks = true;
                }
            }
+            if (any_input_with_masks) {
+                // Set the mask on the stop op's first input tensor to prevent mask rewriting for
+                // nodes which share an output tensor with the previous node.
+                if (ngraph::is_type<opset6::Result>(m_output.get_node_shared_ptr()))
+                    setMask(*m_output.get_node()->inputs().begin(), output_mask);
+                else
+                    setMask(m_output, output_mask);
+            }
            return true;
        };
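Note: the Result special case above exists because a Result node reads the same output tensor as its producer, so writing the mask to m_output could clobber the mask already stored on the producer's output. A hedged sketch (not part of the committed diff; it assumes the same opset6 API and the Input/Output get_rt_info() accessors used elsewhere in this patch) showing that a node's input port carries its own rt_info map:

    #include <iostream>
    #include <memory>
    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset6.hpp>

    int main() {
        using namespace ngraph;
        auto param = std::make_shared<opset6::Parameter>(element::f32, Shape{1, 8});
        auto relu = std::make_shared<opset6::Relu>(param);
        auto result = std::make_shared<opset6::Result>(relu);

        // The Result's input port and the Relu's output port expose separate
        // rt_info maps even though they refer to the same tensor, which is why
        // StopPropagation can park the mask on the Result's input without
        // touching what is stored on the producer's output.
        bool distinct = &result->input(0).get_rt_info() != &relu->output(0).get_rt_info();
        std::cout << std::boolalpha << distinct << std::endl;  // expected: true
    }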
@@ -610,6 +692,7 @@ ngraph::pass::PropagateMasks::PropagateMasks() {
    add_matcher<mask_propagation::GroupConvolution>();
    add_matcher<mask_propagation::Elementwise>();
    add_matcher<mask_propagation::PassThrough>();
    add_matcher<mask_propagation::Reduce>();
    add_matcher<mask_propagation::FakeQuantize>();
    add_matcher<mask_propagation::Concat>();
    add_matcher<mask_propagation::StopPropagation>();
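Note: a hedged sketch (not part of the committed diff) of how these passes are typically driven through the pass manager. The pass names are the ones referenced in this patch, while the exact set and ordering of registrations (init, propagate, shrink) is an assumption based on the comments above:

    #include <memory>
    #include <ngraph/function.hpp>
    #include <ngraph/pass/manager.hpp>
    #include "pruning.hpp"  // assumed to declare InitMasks, PropagateMasks and ShrinkWeights

    void run_pruning(const std::shared_ptr<ngraph::Function> & f) {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::InitMasks>();        // seed masks on Constants
        manager.register_pass<ngraph::pass::PropagateMasks>();   // the matcher passes registered above
        manager.register_pass<ngraph::pass::ShrinkWeights>();    // materialize pruning on the weights
        manager.run_passes(f);
    }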
@@ -31,6 +31,25 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
        total_weights_count += shape_size(const_shape);

        auto mask = getMask(const_node->output(0));

#ifdef ENABLE_OPENVINO_DEBUG
        auto init_mask = getInitMask(const_node->output(0));

        if (!mask && init_mask)
            NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name() << "\nInit mask: " << *init_mask;

        if (mask && init_mask) {
            for (size_t dim = 0; dim < init_mask->size(); ++dim) {
                auto& dim_init_set = (*init_mask)[dim];
                auto& dim_current_set = (*mask)[dim];
                if (!dim_init_set.empty() && !std::includes(dim_current_set.begin(), dim_current_set.end(),
                                                            dim_init_set.begin(), dim_init_set.end())) {
                    NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name()
                                 << "\nInit mask: " << *init_mask << "\nCurrent mask: " << *mask;
                    break;
                }
            }
        }
#endif

        if (!mask) continue;

        auto last_output = const_node->output(0);
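Note: the debug check above relies on std::includes, which requires both ranges to be sorted and reports whether every element of the second range appears in the first, i.e. whether every initially pruned channel is still pruned. A standalone illustration (not part of the committed diff):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <set>

    int main() {
        // std::set iterates in sorted order, satisfying std::includes' precondition.
        std::set<uint64_t> init_set    = {2, 5};            // channels pruned by InitMasks
        std::set<uint64_t> current_set = {2, 5, 7};         // channels pruned after propagation

        // true: every initially pruned channel is still pruned, so the mask is intact.
        bool intact = std::includes(current_set.begin(), current_set.end(),
                                    init_set.begin(), init_set.end());
        std::cout << std::boolalpha << intact << std::endl; // prints true
    }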
@@ -2,6 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//

#ifndef FILE_UTILS_CPP
#define FILE_UTILS_CPP

#include <cstring>
#include <fstream>
#include <string>
@@ -130,3 +133,5 @@ std::string getIELibraryPath() {
}

}  // namespace InferenceEngine

#endif
File diff suppressed because it is too large