Moved AlgebraicSimplification and NopElimination passes to IE (#1859)

* Moved AlgebraicSimplification and NopElimination passes to IE

* Fixed headerfiles
This commit is contained in:
Gleb Kazantaev
2020-08-19 18:55:11 +03:00
committed by GitHub
parent 016c696869
commit 301d6b50e3
12 changed files with 302 additions and 534 deletions

View File

@@ -1,34 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"
namespace ngraph
{
namespace pass
{
class AlgebraicSimplification;
}
}
class NGRAPH_API ngraph::pass::AlgebraicSimplification : public FunctionPass
{
public:
virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};

View File

@@ -1,31 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "ngraph/pass/pass.hpp"
namespace ngraph
{
namespace pass
{
class NGRAPH_API NopElimination : public FunctionPass
{
public:
bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
};
}
}

View File

@@ -1,309 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <memory>
#include <numeric>
#include <set>
#include "algebraic_simplification.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset2.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/rt_info.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
//`simplify_gather`, optimizes gather if Gather is gathering the
// whole input tensor
static bool simplify_gather(std::shared_ptr<Node> node)
{
if (auto gather = as_type_ptr<opset3::Gather>(node))
{
// check if we are gathering the whole input
auto data = gather->input_value(0);
auto indices = gather->input_value(1);
// we need to know data and indices shape to infer if gather is Nop
if (data.get_partial_shape().is_dynamic() || indices.get_partial_shape().is_dynamic())
{
return false;
}
// if rank of data and gather output dont match, we will skip
if (data.get_shape().size() != node->get_shape().size())
{
return false;
}
auto axis = gather->get_axis();
if (axis == opset3::Gather::AXIS_NOT_SET_VALUE)
{
NGRAPH_DEBUG << "axis value not set";
return false;
}
// case_1 : if the input tensor is of shape (4, 1, 4)
// and axis = 1, then the gather would be simply
// gathering the whole input tensor, so we can optimize this
// op has Nop
if (data.get_shape()[axis] == 1 && data.get_shape() == node->get_shape())
{
return replace_output_update_name(gather->output(0), gather->input_value(0));
}
// case_2 : if the input tensor is of shape (4, 3, 4)
// we need to check the contents of indices, if indices
// is 1D tensor of value {0, 1, 2}, we can optimize this
// op has Nop
// check if the indices is constant
auto constant_indices =
as_type_ptr<opset3::Constant>(gather->input_value(1).get_node_shared_ptr());
if (!constant_indices)
{
return false;
}
else
{
// if ref_inidices == indices, we are capturing the
// entire input tensor
std::vector<int64_t> ref_indices(data.get_shape()[axis], 0);
std::iota(ref_indices.begin(), ref_indices.end(), 0);
if (ref_indices == constant_indices->cast_vector<int64_t>())
{
return replace_output_update_name(gather->output(0), gather->input_value(0));
}
}
}
return false;
}
// optimizes `gather->shapeof` into `shapeof->gather` for 0D indices
// other cases into Concat of shapeof/gather(data) + shapeof(indices)
static bool simplify_gather_shapeof(shared_ptr<Node> node)
{
auto gather = as_type_ptr<opset3::Gather>(node->input_value(0).get_node_shared_ptr());
if (!gather)
{
return false;
}
auto gather_in_rank = gather->get_input_partial_shape(0).rank();
auto indices_rank = gather->get_input_partial_shape(1).rank();
auto axis = gather->get_axis();
if (gather_in_rank.is_dynamic() || indices_rank.is_dynamic() ||
axis == opset3::Gather::AXIS_NOT_SET_VALUE)
{
NGRAPH_DEBUG << gather << " cannot simplify gather->shapeof";
return false;
}
auto zero_axis = opset3::Constant::create<int64_t>(element::i64, Shape{}, {0});
NodeVector new_ops;
auto new_shapeof = make_shared<opset3::ShapeOf>(gather->input_value(0));
new_ops.push_back(new_shapeof);
std::shared_ptr<Node> replace_op;
if (indices_rank.get_length() == 0)
{
std::vector<int64_t> vi(gather_in_rank.get_length());
std::iota(vi.begin(), vi.end(), 0);
vi.erase(vi.begin() + axis);
auto new_indices = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
replace_op = make_shared<opset3::Gather>(new_shapeof, new_indices, zero_axis);
new_ops.push_back(replace_op);
}
else
{
NodeVector concat_inputs;
if (axis > 0)
{
std::vector<int64_t> vi(axis);
std::iota(vi.begin(), vi.end(), 0);
auto indices = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
auto gather = make_shared<opset3::Gather>(new_shapeof, indices, zero_axis);
new_ops.push_back(gather);
concat_inputs.push_back(gather);
}
auto shapeof_indices = make_shared<opset3::ShapeOf>(gather->input_value(1));
new_ops.push_back(shapeof_indices);
concat_inputs.push_back(shapeof_indices);
if (gather_in_rank.get_length() - 1 > axis)
{
std::vector<int64_t> vi(gather_in_rank.get_length() - (axis + 1));
std::iota(vi.begin(), vi.end(), axis + 1);
auto indices = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
auto gather = make_shared<opset3::Gather>(new_shapeof, indices, zero_axis);
new_ops.push_back(gather);
concat_inputs.push_back(gather);
}
replace_op = make_shared<opset3::Concat>(concat_inputs, 0);
new_ops.push_back(replace_op);
}
replace_op->set_friendly_name(node->get_friendly_name());
copy_runtime_info(node, new_ops);
replace_node(node, replace_op);
return true;
}
static bool replace_transpose_with_reshape(shared_ptr<Node> transpose)
{
auto data = transpose->input_value(0);
const auto input_shape = transpose->input(0).get_partial_shape();
if (input_shape.rank().is_dynamic())
{
return false;
}
const auto input_shape_rank = input_shape.rank().get_length();
auto order = as_type_ptr<opset3::Constant>(transpose->input_value(1).get_node_shared_ptr());
if (!order)
{
return false;
}
const auto order_value = order->cast_vector<int64_t>();
// Check that transpose order without 1 dims has an ascending order
int64_t last_dim(-1);
for (size_t i = 0; i < input_shape_rank; ++i)
{
if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1)
{
if (order_value[i] < last_dim)
{
return false;
}
last_dim = order_value[i];
}
}
// Transpose operation can be removed if original transpose order is sorted
// or dimension that changes their places equal to 1
using DimensionToPosition = struct
{
Dimension dim;
size_t pos;
};
std::vector<DimensionToPosition> dims;
for (size_t i = 0; i < input_shape_rank; ++i)
{
if (order_value[i] != i)
{
dims.push_back({input_shape[order_value[i]], i});
}
}
// If number of dimensions != 1 to move equal to 0 we can remove this Transpose
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
return !(item.dim.is_static() && item.dim.get_length() == 1);
}) == 0)
{
return replace_output_update_name(transpose->output(0), transpose->input_value(0));
}
// Transpose can be replaced with Reshape in two ways:
// 1. Reshape with dims as Constant
// 2. Reshape with dims as input (ShapeOf->Gather)
//
// The first case is possible only if one or less dynamic dimensions changes their position
// For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
// with Constant {0, 3, -1, 1} but if input_shape {?, 1, 1, ?} and order {1, 0, 3, 2} transpose
// cannot be replaced int the same way and in this case its only possible to use Gather(ShapeOf,
// order)
Output<Node> reshape_dim;
NodeVector new_ops;
if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
return item.dim.is_dynamic();
}) < 2)
{
vector<int64_t> reshape_value(input_shape_rank, 0);
for (const auto& item : dims)
{
reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
}
reshape_dim =
opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
}
else
{
auto shape_of = make_shared<opset3::ShapeOf>(data);
new_ops.push_back(shape_of);
reshape_dim = make_shared<opset3::Gather>(
shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
new_ops.push_back(reshape_dim.get_node_shared_ptr());
}
auto reshape_op = make_shared<opset3::Reshape>(data, reshape_dim, true);
new_ops.push_back(reshape_op);
reshape_op->set_friendly_name(transpose->get_friendly_name());
copy_runtime_info(transpose, new_ops);
replace_node(transpose, reshape_op);
return true;
}
bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
{
static const unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
{{opset3::Gather::type_info, simplify_gather},
{opset2::ShapeOf::type_info, simplify_gather_shapeof},
{opset3::ShapeOf::type_info, simplify_gather_shapeof},
{opset3::Transpose::type_info, replace_transpose_with_reshape}};
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
if (op::is_output(n) || op::is_parameter(n))
{
continue;
}
auto eh = ops_to_simplifiers.find(n->get_type_info());
if (eh != ops_to_simplifiers.end())
{
replaced = eh->second(n) || replaced;
}
}
return replaced;
}

View File

@@ -1,451 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/non_zero.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/squeeze.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/unsqueeze.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
using namespace std;
using namespace ngraph;
#define TI(x) x::type_info
static bool eliminate_nop(const std::shared_ptr<Node>& node)
{
// skip if shapes are dynamic
if (node->get_input_partial_shape(0).is_dynamic() ||
node->get_output_partial_shape(0).is_dynamic())
{
return false;
}
if (node->get_input_shape(0) == node->get_output_shape(0))
{
return replace_output_update_name(node->output(0), node->input_value(0));
}
return false;
}
static bool eliminate_sum(const std::shared_ptr<Node>& node)
{
auto sum = as_type_ptr<op::v0::Sum>(node);
if (sum->get_reduction_axes().empty())
{
return replace_output_update_name(node->output(0), node->input_value(0));
}
return false;
}
static bool eliminate_convert(const std::shared_ptr<Node>& node)
{
bool is_out_type_agnostic = false;
static const std::set<NodeTypeInfo> type_agnostic{TI(opset3::NonZero)};
if (node->output(0).get_target_inputs().size() == 1)
{
Input<Node> out = *node->output(0).get_target_inputs().begin();
is_out_type_agnostic = type_agnostic.count(out.get_node()->get_type_info()) == 1;
}
auto convert = as_type_ptr<opset3::Convert>(node);
auto input = convert->input_value(0);
if (convert->get_convert_element_type() == input.get_element_type() || is_out_type_agnostic)
{
if (is_out_type_agnostic && is_type<opset3::Convert>(input.get_node()))
{
input = input.get_node()->input_value(0);
}
return replace_output_update_name(node->output(0), input);
}
return false;
}
static bool eliminate_concat(const std::shared_ptr<Node>& node)
{
auto node_input = node->input_value(0);
// remove concat with single input
if (node->get_input_size() == 1)
{
return replace_output_update_name(node->output(0), node_input);
}
return false;
}
static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node)
{
auto input = node->input_value(0);
// check if reshape is not identity op
if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic())
{
NGRAPH_DEBUG << node << " has dynamic shapes.";
return false;
}
// remove identity op
if (input.get_shape() == node->get_output_shape(0))
{
return replace_output_update_name(node->output(0), input);
}
// eliminate redundant reshape, squeeze, or unsqueeze
if (is_type<opset3::Squeeze>(input.get_node()) ||
is_type<opset3::Unsqueeze>(input.get_node()) || is_type<opset3::Reshape>(input.get_node()))
{
auto shape = node->get_output_shape(0);
std::vector<int64_t> vi;
vi.assign(shape.begin(), shape.end());
auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
auto new_reshape =
make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
return replace_node_update_name(node, new_reshape);
}
return false;
}
static size_t count_unknown_dims(const PartialShape& ps)
{
size_t rc = 0;
if (ps.is_static())
{
return rc;
}
for (auto i = 0; i < ps.rank().get_length(); i++)
{
if (ps[i].is_dynamic())
{
rc += 1;
}
}
return rc;
}
static bool replace_squeeze_unsqueeze(const std::shared_ptr<Node>& node)
{
auto shape_ps = node->get_output_partial_shape(0);
if (shape_ps.rank().get_length() == 0)
{
return false;
}
if (count_unknown_dims(shape_ps) > 1)
{
return false;
}
std::vector<int64_t> target_shape;
for (auto i = 0; i < shape_ps.rank().get_length(); i++)
{
if (shape_ps[i].is_dynamic())
{
target_shape.emplace_back(-1);
}
else
{
target_shape.emplace_back(shape_ps[i].get_length());
}
}
shared_ptr<Node> reshape;
auto input = node->input_value(0).get_node_shared_ptr();
auto pat =
opset3::Constant::create<int64_t>(element::i64, Shape{target_shape.size()}, target_shape);
if (is_type<opset3::Reshape>(input) || is_type<opset3::Squeeze>(input) ||
is_type<opset3::Unsqueeze>(input))
{
reshape = make_shared<opset3::Reshape>(input->input_value(0), pat, false);
}
else
{
reshape = make_shared<opset3::Reshape>(node->input_value(0), pat, false);
}
// skip if reshape is nop
if (reshape->get_input_partial_shape(0).same_scheme(shape_ps))
{
return replace_output_update_name(node->output(0), reshape->input_value(0));
}
else
{
return replace_node_update_name(node, reshape);
}
return false;
}
static std::vector<int64_t> get_unsqueeze_axes(const PartialShape& data_shape,
const PartialShape& out_shape)
{
std::vector<int64_t> axes;
size_t i = 0;
for (auto o = 0; o < out_shape.rank().get_length(); o++)
{
if (i < data_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[o]))
{
i += 1;
continue;
}
if (out_shape[o].is_static() && out_shape[o] == 1)
{
axes.push_back(o);
}
}
return axes;
}
static std::vector<int64_t> get_squeeze_axes(const PartialShape& data_shape,
const PartialShape& out_shape)
{
std::vector<int64_t> axes;
size_t out_i = 0;
for (auto i = 0; i < data_shape.rank().get_length(); i++)
{
if (out_i < out_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[out_i]))
{
out_i += 1;
continue;
}
if (data_shape[i].is_static() && data_shape[i] == 1)
{
axes.push_back(i);
}
}
return axes;
}
static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
{
auto out_shape = node->get_output_partial_shape(0);
// try to replace all squeeze/unsqueeze with reshape
if (out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2)
{
return replace_squeeze_unsqueeze(node);
}
auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
auto input = unsqueeze->input_value(0).get_node_shared_ptr();
auto squeeze = as_type_ptr<opset3::Squeeze>(input);
auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
if (unsqueeze->get_output_partial_shape(0).same_scheme(
new_unsq->get_output_partial_shape(0)))
{
return replace_node_update_name(unsqueeze, new_unsq);
}
return false;
};
// eliminate redundant squeeze->unsqueeze
if (squeeze)
{
const auto& data_shape = squeeze->input_value(0).get_partial_shape();
if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
unsqueeze->input_value(1).get_node_shared_ptr()))
{
return replace_output_update_name(unsqueeze->output(0), squeeze->input_value(0));
}
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
{
return false;
}
if (out_shape.rank().get_length() > data_shape.rank().get_length())
{
// check if single unsqueeze can handle this
auto axes = get_unsqueeze_axes(data_shape, out_shape);
if (axes.size() + data_shape.rank().get_length() == out_shape.rank().get_length())
{
return replace_unsqueeze_only(axes);
}
}
if (out_shape.rank().get_length() < data_shape.rank().get_length())
{
// check if single squeeze can handle this
auto axes = get_squeeze_axes(data_shape, out_shape);
if (data_shape.rank().get_length() - axes.size() == out_shape.rank().get_length())
{
auto axes_const =
opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
if (unsqueeze->get_output_partial_shape(0).same_scheme(
new_sq->get_output_partial_shape(0)))
{
return replace_node_update_name(unsqueeze, new_sq);
}
return false;
}
}
return false;
}
// eliminate redundant unsqueeze->unsqueeze
auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
if (unsqueeze_i)
{
const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
{
return false;
}
auto axes = get_unsqueeze_axes(data_shape, out_shape);
return replace_unsqueeze_only(axes);
}
return false;
}
static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
{
auto out_shape = node->get_output_partial_shape(0);
// try to replace all unsqueeze/squeeze with reshape
if (out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2)
{
return replace_squeeze_unsqueeze(node);
}
auto squeeze = as_type_ptr<opset3::Squeeze>(node);
auto input = squeeze->input_value(0).get_node_shared_ptr();
auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
if (squeeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0)))
{
return replace_node_update_name(squeeze, new_sq);
}
return false;
};
// eliminate redundant unsqueeze->squeeze
if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
{
PartialShape data_shape;
if (op::is_parameter(input))
{
data_shape = unsqueeze->input(0).get_partial_shape();
}
else
{
data_shape = input->input(0).get_partial_shape();
}
if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
squeeze->input_value(1).get_node_shared_ptr()))
{
return replace_output_update_name(squeeze->output(0), unsqueeze->input_value(0));
}
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
{
return false;
}
if (out_shape.rank().get_length() < data_shape.rank().get_length())
{
// check if single squeeze can handle this
auto axes = get_squeeze_axes(data_shape, out_shape);
if (data_shape.rank().get_length() == out_shape.rank().get_length() + axes.size())
{
return replace_squeeze_only(axes);
}
}
if (out_shape.rank().get_length() > data_shape.rank().get_length())
{
// check if single unsqueeze can handle this
auto axes = get_unsqueeze_axes(data_shape, out_shape);
if (data_shape.rank().get_length() + axes.size() == out_shape.rank().get_length())
{
auto axes_const =
opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
if (squeeze->get_output_partial_shape(0).same_scheme(
new_unsq->get_output_partial_shape(0)))
{
replace_output_update_name(squeeze, new_unsq);
return true;
}
}
}
return false;
}
// eliminate redundant squeeze->squeeze
if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
{
PartialShape data_shape;
if (op::is_parameter(input))
{
data_shape = squeeze_i->input(0).get_partial_shape();
}
else
{
data_shape = input->input(0).get_partial_shape();
}
if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
{
return false;
}
auto axes = get_squeeze_axes(data_shape, out_shape);
return replace_squeeze_only(axes);
}
return false;
}
static bool eliminate_stop_gradient(const std::shared_ptr<Node>& node)
{
replace_output_update_name(node->output(0), node->input_value(0));
return true;
}
bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
{
static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
dispatcher{{TI(op::v0::Pad), &eliminate_nop},
{TI(opset3::Pad), &eliminate_nop},
{TI(op::v0::Sum), &eliminate_sum},
{TI(opset3::Convert), &eliminate_convert},
{TI(op::v0::Slice), &eliminate_nop},
{TI(op::v0::StopGradient), &eliminate_stop_gradient},
{TI(opset3::Reshape), &eliminate_reshape_v1},
{TI(opset3::Concat), &eliminate_concat},
{TI(opset3::Squeeze), &eliminate_squeeze},
{TI(opset3::Unsqueeze), &eliminate_unsqueeze},
{TI(op::v0::Broadcast), &eliminate_nop}};
bool clobbered = false;
for (const auto& n : function->get_ops())
{
// Work around a warning [-Wpotentially-evaluated-expression]
const Node& node = *n;
auto handler = dispatcher.find(node.get_type_info());
if (handler != dispatcher.end())
{
clobbered = handler->second(n) || clobbered;
}
}
return clobbered;
}