[Offline Transformations] EfficientNet_b0 Pruning Transformation Enabling (#8926)

* Add include guard to file_utils.cpp

* Rebase src

* Rename acc tests, fix rebase

* Revert debug changes

* Fix linter

* Move ac tests to new template

* Test updated

* Fix result operations sharing an output tensor with the previous op

* Pruning test visualization option

* Add ac support for all test cases with pruned kernels

* Remove redundant files

* Enable extended pruning logging by env variable

* Adjust pruning tests

* Remove pruning extended debug env var

* Enable init masks only in debug mode

* Set result mask to input tensor instead of output tensor by separate key

* Bug fix / Test coverage

* Fix comments
Authored by Daniil Lyakhov on 2022-01-14 18:43:38 +03:00; committed by GitHub
parent 4734b17e52
commit 4693f7b854
7 changed files with 1051 additions and 91 deletions


@@ -149,19 +149,24 @@ public:
m_dependencies.push_back(mask.get());
}
/* Modify the state of this mask by the corresponding callback,
which returns a success status (bool), and then
modify all dependent masks by their corresponding callbacks */
bool apply_callback(Mask::Ptr mask) {
// TODO: if the callback returns false, we need to propagate the original value
const auto & ref_state = Mask(*this);
// Modify this mask by the received mask
if (!m_callbacks.at(mask.get())(shared_from_this())) {
return false;
}
// If this mask was already visited and wasn't changed by the
// callback call, stop the recursion
if (!m_need_initialization && *this == ref_state) {
return true;
}
// Mark mask as visited
m_need_initialization = false;
// recursively apply callbacks for each dependent mask
for (const auto & m_dependency : m_dependencies) {
if (!m_dependency->apply_callback(shared_from_this())) {
return false;
@@ -185,13 +190,21 @@ public:
}
}
/* Ask the mask to update its dependencies
even if the mask value wasn't changed by the callback */
void initialize_dependencies() {
m_need_initialization = true;
}
private:
bool m_is_shape_like{false};
// Maps masks dependent on this mask to callbacks specifying how
// this mask will be modified by the corresponding dependent mask
std::map<Mask *, std::function<bool(Mask::Ptr)>> m_callbacks;
// Vector of all dependent masks
std::vector<Mask *> m_dependencies;
// Flag used as a visited label during the mask application call
bool m_need_initialization{true};
};
@@ -203,4 +216,14 @@ Mask::Ptr getMask(const Output<Node> & output);
void setMask(Output<Node> output, const Mask::Ptr & mask);
void setMask(Input<Node> node, const Mask::Ptr & mask);
#ifdef ENABLE_OPENVINO_DEBUG
/* Get the mask that was defined by the InitMasks matcher pass */
Mask::Ptr getInitMask(const Output<Node> & output);
/* Set the mask that was defined by the InitMasks matcher pass */
void setInitMask(Output<Node> output, const Mask::Ptr & mask);
#endif
} // namespace ngraph
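
For context, the visited-flag recursion in Mask::apply_callback above can be reduced to a standalone sketch. This is illustrative only: ToyMask and its members are stand-ins for the ngraph Mask class, not its actual API. Each mask stores, per dependent mask, a callback that rewrites its own state; apply_callback runs the callback for the triggering mask, stops if the state is unchanged and the mask was already visited, and otherwise recurses into all dependents.

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <vector>

// Toy mask: one set of pruned indices, plus the dependency/callback
// bookkeeping mirroring the Mask class in the diff above.
struct ToyMask : std::enable_shared_from_this<ToyMask> {
    using Ptr = std::shared_ptr<ToyMask>;
    std::set<int> pruned;
    std::map<ToyMask*, std::function<bool(Ptr)>> callbacks;
    std::vector<ToyMask*> dependencies;
    bool need_initialization{true};

    void add_callback(std::function<bool(Ptr)> cb, Ptr mask) {
        callbacks[mask.get()] = std::move(cb);
        dependencies.push_back(mask.get());
    }

    bool apply_callback(Ptr mask) {
        const auto ref = pruned;  // snapshot of the previous state
        if (!callbacks.at(mask.get())(shared_from_this()))
            return false;
        // Already visited and unchanged by the callback -> stop recursion.
        if (!need_initialization && pruned == ref)
            return true;
        need_initialization = false;  // mark as visited
        for (auto* dep : dependencies)
            if (!dep->apply_callback(shared_from_this()))
                return false;
        return true;
    }
};

int main() {
    auto a = std::make_shared<ToyMask>();
    auto b = std::make_shared<ToyMask>();
    // a and b mirror each other: a two-node dependency cycle.
    a->add_callback([&](ToyMask::Ptr m) { m->pruned = b->pruned; return true; }, b);
    b->add_callback([&](ToyMask::Ptr m) { m->pruned = a->pruned; return true; }, a);
    a->pruned = {1, 3};
    b->apply_callback(a);  // propagate a's state; recursion stops once stable
    std::cout << "b pruned dims: " << b->pruned.size() << "\n";  // prints 2
}

The visited flag is what makes the cycle above terminate: without it, the two masks would keep re-triggering each other even after the states converge.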


@@ -56,6 +56,9 @@ ngraph::pass::InitConstMask::InitConstMask(const ngraph::AxisSet & dims,
}
setMask(const_node, mask);
#ifdef ENABLE_OPENVINO_DEBUG
setInitMask(const_node, mask);
#endif
if (!mask->all_dims_are_empty()) {
NGRAPH_DEBUG << "MASK (" << const_node->get_friendly_name() << ") " << *mask << std::endl;
}


@@ -14,16 +14,21 @@ namespace ngraph {
Mask::Ptr getMask(const Output<const Node> & output) {
auto &rtInfo = output.get_rt_info();
if (!rtInfo.count(Mask::get_type_info_static())) return nullptr;
const auto &attr = rtInfo.at(Mask::get_type_info_static());
const auto attr_it = rtInfo.find(Mask::get_type_info_static());
if (attr_it == rtInfo.end()) return nullptr;
const auto &attr = attr_it->second;
return attr.as<Mask::Ptr>();
}
Mask::Ptr getMask(const Output<Node> & output) {
auto &rtInfo = output.get_rt_info();
if (!rtInfo.count(Mask::get_type_info_static())) return nullptr;
const auto &attr = rtInfo.at(Mask::get_type_info_static());
const auto attr_it = rtInfo.find(Mask::get_type_info_static());
if (attr_it == rtInfo.end()) return nullptr;
const auto &attr = attr_it->second;
return attr.as<Mask::Ptr>();
}
@@ -32,6 +37,31 @@ void setMask(Output<Node> output, const Mask::Ptr & mask) {
rtInfo[Mask::get_type_info_static()] = mask;
}
void setMask(Input<Node> node, const Mask::Ptr & mask) {
auto &rtInfo = node.get_rt_info();
rtInfo[Mask::get_type_info_static()] = mask;
}
#ifdef ENABLE_OPENVINO_DEBUG
static const char g_init_mask_key[] = "InitMask";
Mask::Ptr getInitMask(const Output<Node> & output) {
auto &rtInfo = output.get_rt_info();
const auto attr_it = rtInfo.find(g_init_mask_key);
if (attr_it == rtInfo.end()) return nullptr;
const auto &attr = attr_it->second;
return attr.as<Mask::Ptr>();
}
void setInitMask(Output<Node> output, const Mask::Ptr & mask) {
auto &rtInfo = output.get_rt_info();
auto copy_mask = std::make_shared<Mask>();
std::copy(mask->begin(), mask->end(), std::back_inserter(*copy_mask));
rtInfo[g_init_mask_key] = copy_mask;
}
#endif
std::ostream & operator<< (std::ostream & out, const Mask & mask) {
out << "[ ";
for (auto & dim : mask) {
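
Two details of this file are worth spelling out. First, the getMask change replaces a count()/at() pair, which performs two map lookups, with a single find(). A minimal sketch of the idiom, with illustrative names rather than the ngraph rt_info API:

#include <map>
#include <string>

// Two lookups: count() walks the map, then at() walks it again.
int get_twice(const std::map<std::string, int>& m, const std::string& key) {
    if (!m.count(key)) return -1;
    return m.at(key);
}

// One lookup: find() returns an iterator usable both for the
// existence check and for the value access.
int get_once(const std::map<std::string, int>& m, const std::string& key) {
    const auto it = m.find(key);
    if (it == m.end()) return -1;
    return it->second;
}

Second, note that setInitMask stores a copy of the mask (std::copy into a freshly allocated Mask) rather than the shared pointer itself, so later in-place mask propagation cannot mutate the recorded initial state that the debug check in ShrinkWeights compares against.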


@@ -5,6 +5,8 @@
#include "pruning.hpp"
#include "mask_attribute.hpp"
#include <algorithm>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/opsets/opset5.hpp>
@@ -22,6 +24,7 @@ class GroupConvolution;
class GroupConvolutionReshape;
class Elementwise;
class PassThrough;
class Reduce;
class StopPropagation;
class FakeQuantize;
class Concat;
@@ -228,7 +231,11 @@ public:
return false;
}
auto input_mask_row = input_mask.get();
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
// Check whether the reshape mask was already initialized during the StopPropagation pass
auto output_mask = getMask(m_output);
if (!output_mask)
output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
auto output_mask_row = output_mask.get();
// Depthwise Convolution pruned only by input channels (== groups) ->
@@ -301,7 +308,6 @@ public:
}
InitConstMask({0, 1}).apply(m_weights.get_node_shared_ptr());
auto weights_mask = getMask(m_weights);
if (!weights_mask) {
NGRAPH_DEBUG << "No weights mask for: " << m_output.get_node()->get_friendly_name() << std::endl;
@@ -561,9 +567,12 @@ public:
class ngraph::pass::mask_propagation::PassThrough : public MatcherPass {
public:
PassThrough() {
auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp,
opset6::Convert, opset6::ConvertLike, opset6::AvgPool, opset6::MaxPool,
opset6::ROIPooling, opset6::PSROIPooling, opset6::Pad>();
auto unary_op = pattern::wrap_type<op::util::UnaryElementwiseArithmetic, opset6::Clamp, opset6::Swish,
opset6::Elu, opset6::HardSigmoid, opset6::PRelu, opset6::Mish,
opset6::Softmax, opset6::SoftPlus, opset6::Convert, opset6::ConvertLike,
opset6::AvgPool, opset6::MaxPool, opset6::ROIPooling, opset6::PSROIPooling,
opset6::Pad>();
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
@@ -582,20 +591,93 @@ public:
}
};
class ngraph::pass::mask_propagation::Reduce : public MatcherPass {
public:
Reduce() {
auto inputs = pattern::any_input();
auto weights = pattern::wrap_type<opset6::Constant>();
auto pooling_by_reduce = pattern::wrap_type<opset6::ReduceMin, opset6::ReduceMax, opset6::ReduceMean>({inputs, weights});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto m_weights = pattern_map.at(weights);
const auto & m_input = pattern_map.at(inputs);
const auto & m_output = pattern_map.at(pooling_by_reduce);
// Check that the reduce operation reduces only dimensions without masks
if (auto input_mask = getMask(m_input)) {
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
const auto constant = std::dynamic_pointer_cast<opset6::Constant>(m_weights.get_node_shared_ptr());
const auto reduce_dims = constant->cast_vector<int64_t>();
auto input_mask_row = input_mask.get();
auto output_mask_row = output_mask.get();
input_mask->add_callback([output_mask_row](Mask::Ptr cur_mask) -> bool {
cur_mask->copy_value_from_mask(output_mask_row);
return true;
}, output_mask);
output_mask->add_callback([input_mask_row, reduce_dims](Mask::Ptr cur_mask) -> bool{
// Propagate masks through a dimension only if that dimension isn't reduced
for (size_t dim = 0; dim < std::min(cur_mask->size(), input_mask_row->size()); ++dim)
if (std::find(reduce_dims.begin(), reduce_dims.end(), dim) == reduce_dims.end())
cur_mask->at(dim) = input_mask_row->at(dim);
else if (cur_mask->at(dim) != input_mask_row->at(dim))
cur_mask->initialize_dependencies();
return true;
}, input_mask);
// Invalidate current mask and its parent masks
output_mask->apply_callback(input_mask);
setMask(m_output, output_mask);
}
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(pooling_by_reduce, "PassThroughReduceMaskPropagation");
register_matcher(m, callback);
}
};
class ngraph::pass::mask_propagation::StopPropagation : public MatcherPass {
public:
StopPropagation() {
auto any_node = pattern::any_input();
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto & pattern_map = m.get_pattern_value_map();
const auto & m_output = pattern_map.at(any_node);
const auto & node = m.get_match_root();
auto output_mask = std::make_shared<Mask>(m_output.get_partial_shape().rank().get_length());
bool any_input_with_masks = false;
for (const auto & input : node->input_values()) {
if (auto mask = getMask(input)) {
if (auto input_mask = getMask(input)) {
auto input_mask_row = input_mask.get();
input_mask->add_callback([](Mask::Ptr cur_mask) -> bool {
cur_mask->clean_dim_values();
return true;
}, output_mask);
output_mask->add_callback([input_mask_row](Mask::Ptr cur_mask) -> bool{
cur_mask->copy_value_from_mask(input_mask_row);
return true;
}, input_mask);
// Invalidate current mask and its parent masks
mask->invalidate();
NGRAPH_DEBUG << "Invalidate masks for " << *input.get_node() << " because " << node << " is unknown\n";
output_mask->apply_callback(input_mask);
NGRAPH_DEBUG << "Invalidate masks for " << *input.get_node() << " because " << node << " is in scope of stop ops.\n";
any_input_with_masks = true;
}
}
if (any_input_with_masks) {
// Set the mask on the stop op's first input tensor to prevent mask rewriting for
// nodes which share an output tensor with the previous node.
if (ngraph::is_type<opset6::Result>(m_output.get_node_shared_ptr()))
setMask(*m_output.get_node()->inputs().begin(), output_mask);
else
setMask(m_output, output_mask);
}
return true;
};
@@ -610,6 +692,7 @@ ngraph::pass::PropagateMasks::PropagateMasks() {
add_matcher<mask_propagation::GroupConvolution>();
add_matcher<mask_propagation::Elementwise>();
add_matcher<mask_propagation::PassThrough>();
add_matcher<mask_propagation::Reduce>();
add_matcher<mask_propagation::FakeQuantize>();
add_matcher<mask_propagation::Concat>();
add_matcher<mask_propagation::StopPropagation>();
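
To see what the output-mask callback in the new Reduce pass computes, here is a standalone sketch of its update rule. ToyMask and propagate_through_reduce are illustrative stand-ins, not ngraph types, and the initialize_dependencies branch is omitted for brevity: pruned channels survive a ReduceMin/Max/Mean only along dimensions that the op does not reduce.

#include <algorithm>
#include <cstdint>
#include <set>
#include <vector>

using ToyMask = std::vector<std::set<int64_t>>;  // per-dimension pruned indices

// Mirror of the output_mask callback in the Reduce pass: copy the
// input mask only through dimensions the op does not reduce.
ToyMask propagate_through_reduce(const ToyMask& input_mask,
                                 const std::vector<int64_t>& reduce_dims,
                                 size_t output_rank) {
    ToyMask out(output_rank);
    for (size_t dim = 0; dim < std::min(out.size(), input_mask.size()); ++dim)
        if (std::find(reduce_dims.begin(), reduce_dims.end(),
                      static_cast<int64_t>(dim)) == reduce_dims.end())
            out[dim] = input_mask[dim];  // dimension kept -> mask flows through
    return out;
}

int main() {
    // NCHW tensor with channels 2 and 5 pruned; ReduceMean over the
    // spatial dims {2, 3} (global average pooling, as in EfficientNet)
    // keeps the channel mask, since dim 1 is not reduced.
    const auto out = propagate_through_reduce({{}, {2, 5}, {}, {}}, {2, 3}, 4);
    return out[1].count(2) && out[1].count(5) ? 0 : 1;
}

This is exactly the case the pass enables for EfficientNet_b0: the ReduceMean that implements global pooling no longer stops channel-mask propagation.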


@@ -31,6 +31,25 @@ bool ngraph::pass::ShrinkWeights::run_on_model(const std::shared_ptr<ngraph::Fun
total_weights_count += shape_size(const_shape);
auto mask = getMask(const_node->output(0));
#ifdef ENABLE_OPENVINO_DEBUG
auto init_mask = getInitMask(const_node->output(0));
if (!mask && init_mask)
NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name() << "\nInit mask: " << *init_mask;
if (mask && init_mask) {
for (size_t dim = 0; dim < init_mask->size(); ++dim) {
auto& dim_init_set = (*init_mask)[dim];
auto& dim_current_set = (*mask)[dim];
if (!dim_init_set.empty() && !std::includes(dim_current_set.begin(), dim_current_set.end(),
dim_init_set.begin(), dim_init_set.end())) {
NGRAPH_DEBUG << "Mask was ruined for node:" << const_node->get_friendly_name()
<< "\nInit mask: " << *init_mask << "\nCurrent mask: " << *mask;
break;
}
}
}
#endif
if (!mask) continue;
auto last_output = const_node->output(0);
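
The debug check above relies on std::includes, which requires both ranges to be sorted; that holds here because each mask dimension is an ordered set. A minimal sketch of the "mask was ruined" condition, outside the ngraph types:

#include <algorithm>
#include <iostream>
#include <set>

int main() {
    // Channels the InitMasks pass originally marked as prunable ...
    std::set<size_t> init_dims{2, 5, 7};
    // ... and what survived propagation: 5 disappeared, so the
    // current mask no longer contains the initial one.
    std::set<size_t> current_dims{2, 7};
    const bool ruined = !init_dims.empty() &&
                        !std::includes(current_dims.begin(), current_dims.end(),
                                       init_dims.begin(), init_dims.end());
    std::cout << (ruined ? "Mask was ruined\n" : "Mask is consistent\n");
}

Because the check only runs under ENABLE_OPENVINO_DEBUG and compares against the copied init mask, it adds no cost to release builds while flagging passes that silently drop prunable channels.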


@@ -2,6 +2,9 @@
// SPDX-License-Identifier: Apache-2.0
//
#ifndef FILE_UTILS_CPP
#define FILE_UTILS_CPP
#include <cstring>
#include <fstream>
#include <string>
@@ -130,3 +133,5 @@ std::string getIELibraryPath() {
}
} // namespace InferenceEngine
#endif