setBatchSize: getting rid of ConstantFolding (#2842)
* setBatchSize: getting rid of setBatchSize
* Trigger CI
* Feedback addressed
* Trigger CI
* f -> specialized_function
parent 5fa569cbd5
commit 347e1206d5
@@ -23,8 +23,6 @@ namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MimicSetBatchSize;
class TRANSFORMATIONS_API DisableCFForPriorBoxes;
class TRANSFORMATIONS_API EnableCFForPriorBoxes;

} // namespace pass
} // namespace ngraph
@@ -41,23 +39,7 @@ class TRANSFORMATIONS_API EnableCFForPriorBoxes;
 * This transformation should be executed only during a setBatchSize method call
 */

class ngraph::pass::MimicSetBatchSize: public ngraph::pass::MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    MimicSetBatchSize();
};

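For orientation (not part of this commit), a minimal sketch of how a pass declared here is typically driven through the standard ngraph::pass::Manager API; the helper name apply_mimic_set_batch_size is illustrative only:

    #include <memory>

    #include <ngraph/function.hpp>
    #include <ngraph/pass/manager.hpp>
    #include <transformations/smart_reshape/mimic_set_batch_size.hpp>

    // Hypothetical helper: registers and runs MimicSetBatchSize on a function,
    // the same way the setBatchSize pipeline in set_batch_size.cpp does further down.
    void apply_mimic_set_batch_size(const std::shared_ptr<ngraph::Function>& f) {
        ngraph::pass::Manager manager;
        manager.register_pass<ngraph::pass::MimicSetBatchSize>();
        manager.run_passes(f);
    }
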
/**
 * @ingroup ie_transformation_common_api
 * @brief DisableCFForPriorBoxes and EnableCFForPriorBoxes transformations are needed to avoid unnecessary PriorBox folding
 */
class ngraph::pass::DisableCFForPriorBoxes: public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};

class ngraph::pass::EnableCFForPriorBoxes: public ngraph::pass::FunctionPass {
class ngraph::pass::MimicSetBatchSize : public ngraph::pass::FunctionPass {
public:
    NGRAPH_RTTI_DECLARATION;
    bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
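For context (not part of this commit), a minimal sketch of the bracket these two PriorBox passes form around constant folding; it mirrors the Disable/ConstantFolding/Enable registrations visible in the set_batch_size.cpp hunk further down:

    // Assumed pipeline fragment: f is a std::shared_ptr<ngraph::Function>.
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::DisableCFForPriorBoxes>();  // mark ShapeOf sub-graphs feeding PriorBox as not-to-be-folded
    manager.register_pass<ngraph::pass::ConstantFolding>();         // fold the rest of the function
    manager.register_pass<ngraph::pass::EnableCFForPriorBoxes>();   // remove the markers again
    manager.run_passes(f);
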
@@ -2,35 +2,38 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <ngraph/pass/constant_folding.hpp>
#include <transformations/smart_reshape/mimic_set_batch_size.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::MimicSetBatchSize, "MimicSetBatchSize", 0);

ngraph::pass::MimicSetBatchSize::MimicSetBatchSize() {
    auto reshape_label = ngraph::pattern::wrap_type<opset5::Reshape>({pattern::any_input(pattern::has_static_dim(0)),
                                                                      ngraph::pattern::wrap_type<opset5::Constant>()},
        [](const Output<Node> &output) { return output.get_partial_shape().rank().is_static() && output.get_partial_shape().rank().get_length() > 1; });
bool ngraph::pass::MimicSetBatchSize::run_on_function(std::shared_ptr<ngraph::Function> f) {
    // extracting ratio of out to in 0-index dimension value from the folded function
    auto specialized_function = ngraph::clone_function(*f);
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::ConstantFolding>();
    manager.run_passes(specialized_function);

    matcher_pass_callback callback = [=](pattern::Matcher &m) -> bool {
        const auto & reshape = m.get_match_root();
        auto pattern = std::dynamic_pointer_cast<opset5::Constant>(reshape->get_input_node_shared_ptr(1));
        if (!pattern)
            return false;
    std::map<std::string, float> scale;
    for (const auto & node : specialized_function->get_ops()) {
        if (const auto & reshape = std::dynamic_pointer_cast<opset5::Reshape>(node)) {
            const auto in_pshape = reshape->get_input_partial_shape(0), out_pshape = reshape->get_output_partial_shape(0);
            if (in_pshape.rank().is_dynamic() || in_pshape.rank().get_length() <= 1 || in_pshape[0].is_dynamic() ||
                out_pshape.rank().is_dynamic() || out_pshape.rank().get_length() <= 1 || out_pshape[0].is_dynamic())
                continue;
            const auto & pattern = std::dynamic_pointer_cast<opset5::Constant>(reshape->get_input_node_shared_ptr(1));
            if (pattern && pattern->cast_vector<int64_t>()[0] > 0) {
                scale[reshape->get_friendly_name()] = static_cast<float>(out_pshape[0].get_length()) / static_cast<float>(in_pshape[0].get_length());
            }
        }
    }
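    // Illustrative (hypothetical shapes): if the folded clone contains a Reshape whose 0-th
    // dimension goes from 2 on the input to 6 on the output, the stored value is
    // scale["that_reshape"] = 6.f / 2.f = 3.f; below, the matching Reshape in the original
    // function gets its batch dimension recomputed as ceil(new_input_batch * 3) while the
    // remaining dimensions are gathered from its original pattern input.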
    // apply transformation to original function
    bool transformed = false;
    for (auto & reshape : f->get_ops()) {
        if (!is_type<opset5::Reshape>(reshape) || !scale.count(reshape->get_friendly_name()) || reshape->get_output_partial_shape(0).rank().is_dynamic())
            continue;

        const auto & pattern_vector = pattern->cast_vector<int64_t>();
        if (pattern_vector.empty() || pattern_vector[0] < 1)
            return false;

        // mimicking old setBatchSize style (copied):
        // float diff = static_cast<float>(dims.at(0)) / static_cast<float>(originalBatchSize);
        // dims.at(0) = static_cast<size_t>(std::ceil(size * diff));

        const auto & old_input_batch = static_cast<float>(reshape->get_input_partial_shape(0)[0].get_length());
        const auto & old_output_batch = static_cast<float>(pattern_vector[0]);

        const auto & scale = old_output_batch / old_input_batch;

        const auto & shape_of = std::make_shared<opset5::ShapeOf>(reshape->get_input_source_output(0), pattern->get_element_type());
        const auto & shape_of = std::make_shared<opset5::ShapeOf>(reshape->get_input_source_output(0), reshape->get_input_element_type(1));
        const auto & new_input_batch = std::make_shared<ngraph::opset5::Gather>(
            shape_of, ngraph::opset5::Constant::create(ngraph::element::i64, {1}, std::vector<int64_t>{0}),
            ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
@@ -39,75 +42,18 @@ ngraph::pass::MimicSetBatchSize::MimicSetBatchSize() {
            std::make_shared<opset5::Ceiling>(
                std::make_shared<opset5::Multiply>(
                    std::make_shared<opset5::Convert>(new_input_batch, element::f32),
                    opset5::Constant::create(element::f32, {1}, {scale}))),
            pattern->get_element_type());
                    opset5::Constant::create(element::f32, {1}, {scale[reshape->get_friendly_name()]}))),
            reshape->get_input_element_type(1));

        auto new_reshape_pattern = new_output_batch;
        const auto rank = pattern_vector.size();
        if (rank > 1) {
            std::vector<int64_t> non_batch_dims(rank - 1);
            std::iota(non_batch_dims.begin(), non_batch_dims.end(), 1);
            const auto & non_batch_dims_node = std::make_shared<ngraph::opset5::Gather>(
                pattern,
                ngraph::opset5::Constant::create(ngraph::element::i64, {non_batch_dims.size()}, non_batch_dims),
                ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
            new_reshape_pattern = std::make_shared<opset5::Concat>(OutputVector{new_reshape_pattern, non_batch_dims_node}, 0);
        }
        std::vector<int64_t> non_batch_dims(reshape->get_output_partial_shape(0).rank().get_length() - 1);
        std::iota(non_batch_dims.begin(), non_batch_dims.end(), 1);
        const auto & non_batch_dims_node = std::make_shared<ngraph::opset5::Gather>(
            reshape->input_value(1),
            ngraph::opset5::Constant::create(ngraph::element::i64, {non_batch_dims.size()}, non_batch_dims),
            ngraph::opset5::Constant::create(ngraph::element::i64, {}, std::vector<int64_t>{0}));
        auto new_reshape_pattern = std::make_shared<opset5::Concat>(OutputVector{new_output_batch, non_batch_dims_node}, 0);
        reshape->input(1).replace_source_output(new_reshape_pattern->output(0));
        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_label, "MimicSetBatchSize");
    register_matcher(m, callback);
}

void set_folding_for_PriorBox(std::shared_ptr<ngraph::Node> prior_box, bool flag) {
    std::string rt_info_disable_cf = "DISABLED_CONSTANT_FOLDING";
    static std::unordered_set<ngraph::NodeTypeInfo> allowed_to_skip = {
        ngraph::opset1::Convert::type_info,
        ngraph::opset1::StridedSlice::type_info,
    };
    static std::unordered_set<ngraph::NodeTypeInfo> types_to_find = {
        ngraph::opset1::ShapeOf::type_info,
        ngraph::opset3::ShapeOf::type_info,
    };

    std::deque<std::shared_ptr<ngraph::Node>> nodes;
    nodes.push_back(prior_box->get_input_node_shared_ptr(0));
    nodes.push_back(prior_box->get_input_node_shared_ptr(1));

    while (!nodes.empty()) {
        auto curr_node = nodes.front();
        nodes.pop_front();
        if (allowed_to_skip.count(curr_node->get_type_info())) {
            nodes.push_back(curr_node->get_input_node_shared_ptr(0));
        } else if (types_to_find.count(curr_node->get_type_info())) {
            auto& rt_info = curr_node->get_rt_info();
            if (flag && rt_info.count(rt_info_disable_cf))
                rt_info.erase(rt_info_disable_cf);
            if (!flag)
                rt_info[rt_info_disable_cf];
        }
        transformed = true;
    }
    return transformed;
}

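// Note on set_folding_for_PriorBox above: constant folding is switched off for a ShapeOf node
// by inserting the "DISABLED_CONSTANT_FOLDING" key into its rt_info (flag == false) and switched
// back on by erasing that key (flag == true); the walk only crosses Convert and StridedSlice
// nodes on its way from the PriorBox inputs to the ShapeOf sources.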

NGRAPH_RTTI_DEFINITION(ngraph::pass::DisableCFForPriorBoxes, "DisableCFForPriorBoxes", 0);

bool ngraph::pass::DisableCFForPriorBoxes::run_on_function(std::shared_ptr<ngraph::Function> f) {
    for (const auto & node : f->get_ops())
        if (ngraph::is_type<opset1::PriorBox>(node) || ngraph::is_type<opset1::PriorBoxClustered>(node)) {
            set_folding_for_PriorBox(node, false);
        }
    return false;
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::EnableCFForPriorBoxes, "EnableCFForPriorBoxes", 0);

bool ngraph::pass::EnableCFForPriorBoxes::run_on_function(std::shared_ptr<ngraph::Function> f) {
    for (const auto & node : f->get_ops())
        if (ngraph::is_type<opset1::PriorBox>(node) || ngraph::is_type<opset1::PriorBoxClustered>(node)) {
            set_folding_for_PriorBox(node, true);
        }
    return false;
}

@@ -22,15 +22,9 @@ bool ngraph::pass::SetBatchSize::run_on_function(std::shared_ptr<ngraph::Functio
    ngraph::pass::Manager manager;
    // This pass must be called first in pipeline
    manager.register_pass<ngraph::pass::InitNodeInfo>();

    manager.register_pass<ngraph::pass::DisableCFForPriorBoxes>();
    manager.register_pass<ngraph::pass::ConstantFolding>();
    manager.register_pass<ngraph::pass::EnableCFForPriorBoxes>();
    manager.register_pass<ngraph::pass::SharedSqueeze>();
    manager.register_pass<ngraph::pass::SqueezeStridedSlice>();
    manager.register_pass<ngraph::pass::StridedSliceSqueeze>();
    manager.register_pass<ngraph::pass::ReshapeTo1D>();

    manager.register_pass<ngraph::pass::MimicSetBatchSize>();
    manager.run_passes(f);
    return true;
@@ -2,7 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <transformations/common_optimizations/optimize_strided_slice.hpp>
#include <transformations/itt.hpp>
#include <transformations/smart_reshape/strided_slice_squeeze.hpp>

#include <ngraph/ngraph.hpp>
@@ -10,9 +10,8 @@
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <transformations/itt.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceSqueeze, "StridedSliceSqueeze", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::StridedSliceSqueeze, "ngraph::pass::StridedSliceSqueeze", 0);

ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() {
    auto ss_label = ngraph::pattern::wrap_type<opset5::StridedSlice>(pattern::consumers_count(1));
@@ -21,17 +20,10 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() {
    matcher_pass_callback callback = [](pattern::Matcher &m) -> bool {
        const auto & squeeze = m.get_match_root();
        const auto & const_axes = std::dynamic_pointer_cast<ngraph::opset5::Constant>(squeeze->get_input_node_shared_ptr(1));

        auto slice = std::dynamic_pointer_cast<ngraph::opset5::StridedSlice>(squeeze->get_input_node_shared_ptr(0));
        if (!const_axes || !slice)
            return false;

        const auto & slice_plan = get_slice_plan(slice);
        if (slice_plan.begins.empty() || slice_plan.reshape_in_shape != slice_plan.reshape_out_shape || !slice_plan.reverse_axes.empty())
            return false;

        const auto & axes = normalize_axes(squeeze->description(), const_axes->cast_vector<int64_t>(), squeeze->get_input_partial_shape(0).rank());

        auto begin = std::dynamic_pointer_cast<ngraph::opset5::Constant>(slice->input_value(1).get_node_shared_ptr());
        auto end = std::dynamic_pointer_cast<ngraph::opset5::Constant>(slice->input_value(2).get_node_shared_ptr());
        auto strides = std::dynamic_pointer_cast<ngraph::opset5::Constant>(slice->input_value(3).get_node_shared_ptr());
@@ -47,17 +39,28 @@
        auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_shrink_axis_mask();
        auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_ellipsis_mask();

        auto is_zero_vec = [](const std::vector<int64_t> & mask){ return std::all_of(mask.begin(), mask.end(), [](const int64_t& i){ return i == 0; }); };
        if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask))
            return false;
        if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i){ return i == 1; }))
            return false;

        const auto & axes = normalize_axes(squeeze->description(), const_axes->cast_vector<int64_t>(), squeeze->get_input_partial_shape(0).rank());
        for (const auto & axis : axes) {
            if ((slice_plan.ends[axis] - slice_plan.begins[axis]) != 1 && slice_plan.strides[axis] == 1)
                return false;
            begin_vec[axis] = slice_plan.begins[axis];
            end_vec[axis] = slice_plan.ends[axis];
            strides_vec[axis] = 1;
            begin_mask[axis] = 0;
            end_mask[axis] = 0;
            new_axis_mask[axis] = 0;
            if (begin_mask[axis]) { // corresponding dimension of the begin input is ignored. starting from 0
                begin_vec[axis] = 0;
                end_vec[axis] = 1;
                begin_mask[axis] = 0;
                end_mask[axis] = 0;
            } else { // corresponding dimension of the begin input is used for slicing start
                if (begin_vec[axis] == -1) { // slicing the latest slice
                    end_mask[axis] = 1;
                } else {
                    end_vec[axis] = begin_vec[axis] + 1;
                    end_mask[axis] = 0;
                }
            }
            shrink_axis_mask[axis] = 1;
            ellipsis_mask[axis] = 0;
        }

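        // Illustrative (hypothetical values) for the begin_mask-based branch above: squeezing axis 1
        // when begin_vec[1] == 2 and begin_mask[1] == 0 yields begin 2, end 3 along that axis with
        // shrink_axis_mask[1] = 1, so the rebuilt StridedSlice already drops the dimension and the
        // trailing Squeeze becomes redundant.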
        auto new_slice = std::make_shared<opset5::StridedSlice>(
@@ -72,10 +75,10 @@ ngraph::pass::StridedSliceSqueeze::StridedSliceSqueeze() {
        copy_runtime_info(slice, new_slice);
        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(squeeze_label, "StridedSliceSqueeze");
    auto m = std::make_shared<ngraph::pattern::Matcher>(squeeze_label, "ngraph::pass::StridedSliceSqueeze");
    register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::SqueezeStridedSlice, "SqueezeStridedSlice", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::SqueezeStridedSlice, "ngraph::pass::SqueezeStridedSlice", 0);

ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
    auto squeeze_label = ngraph::pattern::wrap_type<opset5::Squeeze>(
@@ -89,12 +92,6 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
        if (!const_axes || !slice)
            return false;

        const auto & slice_plan = get_slice_plan(slice);
        if (slice_plan.begins.empty() || slice_plan.reshape_in_shape != slice_plan.reshape_out_shape || !slice_plan.reverse_axes.empty())
            return false;

        auto axes = normalize_axes(squeeze->description(), const_axes->cast_vector<int64_t>(), squeeze->get_input_partial_shape(0).rank());
        std::sort(axes.begin(), axes.end());
        auto begin = std::dynamic_pointer_cast<ngraph::opset5::Constant>(slice->input_value(1).get_node_shared_ptr());
        auto end = std::dynamic_pointer_cast<ngraph::opset5::Constant>(slice->input_value(2).get_node_shared_ptr());
        auto strides = std::dynamic_pointer_cast<ngraph::opset5::Constant>(slice->input_value(3).get_node_shared_ptr());
@@ -110,6 +107,14 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
        auto shrink_axis_mask = slice->get_shrink_axis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_shrink_axis_mask();
        auto ellipsis_mask = slice->get_ellipsis_mask().empty() ? std::vector<int64_t>(begin_mask.size(), 0) : slice->get_ellipsis_mask();

        auto is_zero_vec = [](const std::vector<int64_t> & mask){ return std::all_of(mask.begin(), mask.end(), [](const int64_t& i){ return i == 0; }); };
        if (!is_zero_vec(new_axis_mask) || !is_zero_vec(shrink_axis_mask) || !is_zero_vec(ellipsis_mask))
            return false;
        if (!std::all_of(strides_vec.begin(), strides_vec.end(), [](const int64_t& i){ return i == 1; }))
            return false;

        auto axes = normalize_axes(squeeze->description(), const_axes->cast_vector<int64_t>(), squeeze->get_input_partial_shape(0).rank());
        std::sort(axes.begin(), axes.end());
        for (const auto & axis : axes) {
            begin_vec.insert(begin_vec.begin() + axis, 0);
            end_vec.insert(end_vec.begin() + axis, 1);
@@ -133,13 +138,13 @@ ngraph::pass::SqueezeStridedSlice::SqueezeStridedSlice() {
        copy_runtime_info(slice, new_slice);
        return true;
    };
    auto m = std::make_shared<ngraph::pattern::Matcher>(ss_label, "SqueezeStridedSlice");
    auto m = std::make_shared<ngraph::pattern::Matcher>(ss_label, "ngraph::pass::SqueezeStridedSlice");
    register_matcher(m, callback);
}

NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedSqueeze, "SharedSqueeze", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::SharedSqueeze, "ngraph::pass::SharedSqueeze", 0);

bool squeezes_perform_the_same(std::shared_ptr<ngraph::opset1::Squeeze> lhs, std::shared_ptr<ngraph::opset1::Squeeze> rhs) {
bool squeezes_perform_the_same(std::shared_ptr<ngraph::opset5::Squeeze> lhs, std::shared_ptr<ngraph::opset5::Squeeze> rhs) {
    size_t l_input_size = lhs->inputs().size(), r_input_size = rhs->inputs().size();
    if (l_input_size != r_input_size)
        return false;
@@ -148,8 +153,8 @@ bool squeezes_perform_the_same(std::shared_ptr<ngraph::opset1::Squeeze> lhs, std
    const auto rank = lhs->get_input_partial_shape(0).rank();
    if (rank.is_dynamic())
        return false;
    const auto l_axes = std::dynamic_pointer_cast<ngraph::opset1::Constant>(lhs->get_input_node_shared_ptr(1));
    const auto r_axes = std::dynamic_pointer_cast<ngraph::opset1::Constant>(rhs->get_input_node_shared_ptr(1));
    const auto l_axes = std::dynamic_pointer_cast<ngraph::opset5::Constant>(lhs->get_input_node_shared_ptr(1));
    const auto r_axes = std::dynamic_pointer_cast<ngraph::opset5::Constant>(rhs->get_input_node_shared_ptr(1));
    if (l_axes && r_axes)
        return normalize_axes(lhs->description(), l_axes->cast_vector<int64_t>(), rank) ==
               normalize_axes(rhs->description(), r_axes->cast_vector<int64_t>(), rank);
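    // Illustrative (hypothetical axes): on a 4-D input, Squeeze(axes = {-1}) and Squeeze(axes = {3})
    // normalize to the same axis set, so two such Squeezes sharing an input compare equal here.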
@@ -161,7 +166,7 @@ bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr<ngraph::Functi
    bool graph_rewritten = false;

    std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::opset1::Squeeze>>> source_to_squeeze;
    std::map<ngraph::Output<Node>, std::vector<std::shared_ptr<ngraph::opset5::Squeeze>>> source_to_squeeze;
    for (const auto & node : f->get_ordered_ops()) {
        // Recursively apply transformation for sub-graph based operations
        if (auto sub_graph_node = std::dynamic_pointer_cast<op::util::SubGraphOp>(node)) {
@@ -169,7 +174,7 @@ bool ngraph::pass::SharedSqueeze::run_on_function(std::shared_ptr<ngraph::Functi
                graph_rewritten |= run_on_function(sub_graph);
            }
        }
        if (auto squeeze = std::dynamic_pointer_cast<ngraph::opset1::Squeeze>(node)) {
        if (auto squeeze = std::dynamic_pointer_cast<ngraph::opset5::Squeeze>(node)) {
            source_to_squeeze[squeeze->input_value(0)].push_back(squeeze);
        }
    }