Moved AlgebraicSimplification and NopElimination passes to IE (#1859)
* Moved AlgebraicSimplification and NopElimination passes to IE
* Fixed header files
@@ -1,34 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"

namespace ngraph
{
    namespace pass
    {
        class AlgebraicSimplification;
    }
}

class NGRAPH_API ngraph::pass::AlgebraicSimplification : public FunctionPass
{
public:
    virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
@@ -1,31 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class NGRAPH_API NopElimination : public FunctionPass
        {
        public:
            bool run_on_function(std::shared_ptr<ngraph::Function> function) override;
        };
    }
}
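Both passes are ordinary FunctionPasses, so after the move they would typically still be registered through a pass manager. A minimal usage sketch (assuming the standard ngraph::pass::Manager API; the header names after the move are an assumption here):

    #include "ngraph/pass/manager.hpp"
    #include "algebraic_simplification.hpp"
    #include "nop_elimination.hpp"

    void run_graph_cleanup(std::shared_ptr<ngraph::Function> f)
    {
        ngraph::pass::Manager manager;
        // NopElimination strips identity ops first; AlgebraicSimplification then
        // rewrites the remaining Gather/ShapeOf/Transpose patterns.
        manager.register_pass<ngraph::pass::NopElimination>();
        manager.register_pass<ngraph::pass::AlgebraicSimplification>();
        manager.run_passes(f);
    }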
@@ -1,309 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>
#include <numeric>
#include <set>

#include "algebraic_simplification.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/gather.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset2.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/rt_info.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

// `simplify_gather` optimizes a Gather that gathers the whole input tensor,
// i.e. the Gather is effectively a Nop
static bool simplify_gather(std::shared_ptr<Node> node)
{
    if (auto gather = as_type_ptr<opset3::Gather>(node))
    {
        // check if we are gathering the whole input
        auto data = gather->input_value(0);
        auto indices = gather->input_value(1);

        // we need to know the data and indices shapes to infer whether the gather is a Nop
        if (data.get_partial_shape().is_dynamic() || indices.get_partial_shape().is_dynamic())
        {
            return false;
        }
        // if the ranks of data and gather output don't match, skip
        if (data.get_shape().size() != node->get_shape().size())
        {
            return false;
        }

        auto axis = gather->get_axis();

        if (axis == opset3::Gather::AXIS_NOT_SET_VALUE)
        {
            NGRAPH_DEBUG << "axis value not set";
            return false;
        }

        // case_1 : if the input tensor is of shape (4, 1, 4)
        // and axis = 1, then the gather would simply be
        // gathering the whole input tensor, so we can optimize this
        // op as a Nop

        if (data.get_shape()[axis] == 1 && data.get_shape() == node->get_shape())
        {
            return replace_output_update_name(gather->output(0), gather->input_value(0));
        }

        // case_2 : if the input tensor is of shape (4, 3, 4)
        // we need to check the contents of indices; if indices
        // is a 1D tensor of value {0, 1, 2}, we can optimize this
        // op as a Nop

        // check if the indices are constant
        auto constant_indices =
            as_type_ptr<opset3::Constant>(gather->input_value(1).get_node_shared_ptr());
        if (!constant_indices)
        {
            return false;
        }
        else
        {
            // if ref_indices == indices, we are capturing the
            // entire input tensor
            std::vector<int64_t> ref_indices(data.get_shape()[axis], 0);
            std::iota(ref_indices.begin(), ref_indices.end(), 0);
            if (ref_indices == constant_indices->cast_vector<int64_t>())
            {
                return replace_output_update_name(gather->output(0), gather->input_value(0));
            }
        }
    }
    return false;
}
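// A minimal illustration of the two cases above (not part of the original file; it relies on
// the includes already present in this translation unit). Gathering indices {0, 1, 2} along
// axis 0 of a {3, 4} tensor reproduces the input exactly, so the Gather can be dropped.
static bool simplify_gather_example()
{
    auto data = make_shared<opset3::Parameter>(element::f32, Shape{3, 4});
    auto indices = opset3::Constant::create<int64_t>(element::i64, Shape{3}, {0, 1, 2});
    auto axis = opset3::Constant::create<int64_t>(element::i64, Shape{}, {0});
    auto gather = make_shared<opset3::Gather>(data, indices, axis);
    auto f = make_shared<Function>(NodeVector{gather}, ParameterVector{data});
    // Returns true because simplify_gather reconnects the output directly to the input.
    return pass::AlgebraicSimplification().run_on_function(f);
}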

// optimizes `gather->shapeof` into `shapeof->gather` for 0D indices;
// other cases are rewritten into a Concat of Gather(ShapeOf(data)) pieces and ShapeOf(indices)
static bool simplify_gather_shapeof(shared_ptr<Node> node)
{
    auto gather = as_type_ptr<opset3::Gather>(node->input_value(0).get_node_shared_ptr());
    if (!gather)
    {
        return false;
    }
    auto gather_in_rank = gather->get_input_partial_shape(0).rank();
    auto indices_rank = gather->get_input_partial_shape(1).rank();
    auto axis = gather->get_axis();
    if (gather_in_rank.is_dynamic() || indices_rank.is_dynamic() ||
        axis == opset3::Gather::AXIS_NOT_SET_VALUE)
    {
        NGRAPH_DEBUG << gather << " cannot simplify gather->shapeof";
        return false;
    }

    auto zero_axis = opset3::Constant::create<int64_t>(element::i64, Shape{}, {0});
    NodeVector new_ops;
    auto new_shapeof = make_shared<opset3::ShapeOf>(gather->input_value(0));
    new_ops.push_back(new_shapeof);
    std::shared_ptr<Node> replace_op;
    if (indices_rank.get_length() == 0)
    {
        std::vector<int64_t> vi(gather_in_rank.get_length());
        std::iota(vi.begin(), vi.end(), 0);
        vi.erase(vi.begin() + axis);
        auto new_indices = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
        replace_op = make_shared<opset3::Gather>(new_shapeof, new_indices, zero_axis);
        new_ops.push_back(replace_op);
    }
    else
    {
        NodeVector concat_inputs;
        if (axis > 0)
        {
            std::vector<int64_t> vi(axis);
            std::iota(vi.begin(), vi.end(), 0);
            auto indices = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
            auto gather = make_shared<opset3::Gather>(new_shapeof, indices, zero_axis);
            new_ops.push_back(gather);
            concat_inputs.push_back(gather);
        }
        auto shapeof_indices = make_shared<opset3::ShapeOf>(gather->input_value(1));
        new_ops.push_back(shapeof_indices);

        concat_inputs.push_back(shapeof_indices);

        if (gather_in_rank.get_length() - 1 > axis)
        {
            std::vector<int64_t> vi(gather_in_rank.get_length() - (axis + 1));
            std::iota(vi.begin(), vi.end(), axis + 1);
            auto indices = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
            auto gather = make_shared<opset3::Gather>(new_shapeof, indices, zero_axis);
            new_ops.push_back(gather);
            concat_inputs.push_back(gather);
        }
        replace_op = make_shared<opset3::Concat>(concat_inputs, 0);
        new_ops.push_back(replace_op);
    }
    replace_op->set_friendly_name(node->get_friendly_name());
    copy_runtime_info(node, new_ops);
    replace_node(node, replace_op);
    return true;
}
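// Worked example (illustrative, not in the original source): for data of shape {2, 3, 4},
// scalar indices and axis = 1, ShapeOf(Gather(data, i, 1)) is {2, 4}; the rewrite computes
// the same result as Gather(ShapeOf(data), {0, 2}, 0) without ever gathering the data tensor.
// With non-scalar indices, e.g. indices of shape {5} and axis = 1, the result {2, 5, 4} is
// assembled as Concat(Gather(ShapeOf(data), {0}), ShapeOf(indices), Gather(ShapeOf(data), {2})).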

static bool replace_transpose_with_reshape(shared_ptr<Node> transpose)
{
    auto data = transpose->input_value(0);
    const auto input_shape = transpose->input(0).get_partial_shape();
    if (input_shape.rank().is_dynamic())
    {
        return false;
    }

    const auto input_shape_rank = input_shape.rank().get_length();

    auto order = as_type_ptr<opset3::Constant>(transpose->input_value(1).get_node_shared_ptr());
    if (!order)
    {
        return false;
    }

    const auto order_value = order->cast_vector<int64_t>();

    // Check that the transpose order, ignoring dimensions equal to 1, is ascending
    int64_t last_dim(-1);
    for (size_t i = 0; i < input_shape_rank; ++i)
    {
        if (input_shape[order_value[i]].is_dynamic() || input_shape[order_value[i]] != 1)
        {
            if (order_value[i] < last_dim)
            {
                return false;
            }
            last_dim = order_value[i];
        }
    }

    // The Transpose can be removed if the original transpose order is already sorted,
    // or if the only dimensions that change position are equal to 1
    using DimensionToPosition = struct
    {
        Dimension dim;
        size_t pos;
    };
    std::vector<DimensionToPosition> dims;
    for (size_t i = 0; i < input_shape_rank; ++i)
    {
        if (order_value[i] != i)
        {
            dims.push_back({input_shape[order_value[i]], i});
        }
    }

    // If no dimension of size != 1 has to move, we can remove this Transpose
    if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
            return !(item.dim.is_static() && item.dim.get_length() == 1);
        }) == 0)
    {
        return replace_output_update_name(transpose->output(0), transpose->input_value(0));
    }

    // Transpose can be replaced with Reshape in two ways:
    // 1. Reshape with dims as Constant
    // 2. Reshape with dims as input (ShapeOf->Gather)
    //
    // The first case is possible only if at most one dynamic dimension changes its position.
    // For example: input_shape {?, 3, 1, ?} and order {0, 1, 3, 2} can be replaced with Reshape
    // with Constant {0, 3, -1, 1}, but if input_shape is {?, 1, 1, ?} and order is {1, 0, 3, 2}
    // the transpose cannot be replaced in the same way; in this case it is only possible to use
    // Gather(ShapeOf, order)

    Output<Node> reshape_dim;
    NodeVector new_ops;

    if (count_if(dims.begin(), dims.end(), [](const DimensionToPosition& item) {
            return item.dim.is_dynamic();
        }) < 2)
    {
        vector<int64_t> reshape_value(input_shape_rank, 0);
        for (const auto& item : dims)
        {
            reshape_value[item.pos] = item.dim.is_dynamic() ? -1 : item.dim.get_length();
        }
        reshape_dim =
            opset3::Constant::create(element::i64, Shape{reshape_value.size()}, reshape_value);
    }
    else
    {
        auto shape_of = make_shared<opset3::ShapeOf>(data);
        new_ops.push_back(shape_of);
        reshape_dim = make_shared<opset3::Gather>(
            shape_of, order, opset3::Constant::create(element::i64, Shape{1}, {0}));
        new_ops.push_back(reshape_dim.get_node_shared_ptr());
    }

    auto reshape_op = make_shared<opset3::Reshape>(data, reshape_dim, true);
    new_ops.push_back(reshape_op);

    reshape_op->set_friendly_name(transpose->get_friendly_name());
    copy_runtime_info(transpose, new_ops);
    replace_node(transpose, reshape_op);

    return true;
}
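// Worked example (illustrative, not in the original source): input_shape {4, 1, 3} with
// order {0, 2, 1} only moves a dimension of size 1, so the Transpose is rewritten as
// Reshape(data, {0, 3, 1}, special_zero=true), producing the expected {4, 3, 1} output.
// If two or more of the moved dimensions were dynamic, the pattern could not be expressed
// as a Constant and the ShapeOf->Gather(order) form would be used instead.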

bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
{
    static const unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> ops_to_simplifiers =
        {{opset3::Gather::type_info, simplify_gather},
         {opset2::ShapeOf::type_info, simplify_gather_shapeof},
         {opset3::ShapeOf::type_info, simplify_gather_shapeof},
         {opset3::Transpose::type_info, replace_transpose_with_reshape}};

    bool replaced = false;
    for (auto n : f->get_ordered_ops())
    {
        if (op::is_output(n) || op::is_parameter(n))
        {
            continue;
        }

        auto eh = ops_to_simplifiers.find(n->get_type_info());
        if (eh != ops_to_simplifiers.end())
        {
            replaced = eh->second(n) || replaced;
        }
    }
    return replaced;
}
@@ -1,451 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/non_zero.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/shape_of.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/squeeze.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/unsqueeze.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"

using namespace std;
using namespace ngraph;

#define TI(x) x::type_info

static bool eliminate_nop(const std::shared_ptr<Node>& node)
{
    // skip if shapes are dynamic
    if (node->get_input_partial_shape(0).is_dynamic() ||
        node->get_output_partial_shape(0).is_dynamic())
    {
        return false;
    }

    if (node->get_input_shape(0) == node->get_output_shape(0))
    {
        return replace_output_update_name(node->output(0), node->input_value(0));
    }
    return false;
}

static bool eliminate_sum(const std::shared_ptr<Node>& node)
{
    auto sum = as_type_ptr<op::v0::Sum>(node);
    if (sum->get_reduction_axes().empty())
    {
        return replace_output_update_name(node->output(0), node->input_value(0));
    }
    return false;
}

static bool eliminate_convert(const std::shared_ptr<Node>& node)
{
    bool is_out_type_agnostic = false;
    static const std::set<NodeTypeInfo> type_agnostic{TI(opset3::NonZero)};
    if (node->output(0).get_target_inputs().size() == 1)
    {
        Input<Node> out = *node->output(0).get_target_inputs().begin();
        is_out_type_agnostic = type_agnostic.count(out.get_node()->get_type_info()) == 1;
    }
    auto convert = as_type_ptr<opset3::Convert>(node);
    auto input = convert->input_value(0);
    if (convert->get_convert_element_type() == input.get_element_type() || is_out_type_agnostic)
    {
        if (is_out_type_agnostic && is_type<opset3::Convert>(input.get_node()))
        {
            input = input.get_node()->input_value(0);
        }
        return replace_output_update_name(node->output(0), input);
    }
    return false;
}
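// Example of what this handler removes (illustrative, not in the original source): a Convert
// from f32 to f32 is an identity and is bypassed; a Convert whose only consumer is a NonZero
// is also bypassed, because the pass treats NonZero as type-agnostic - it only looks at which
// elements are non-zero, not at their element type.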

static bool eliminate_concat(const std::shared_ptr<Node>& node)
{
    auto node_input = node->input_value(0);

    // remove concat with single input
    if (node->get_input_size() == 1)
    {
        return replace_output_update_name(node->output(0), node_input);
    }
    return false;
}

static bool eliminate_reshape_v1(const std::shared_ptr<Node>& node)
{
    auto input = node->input_value(0);
    // check if reshape is not identity op
    if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic())
    {
        NGRAPH_DEBUG << node << " has dynamic shapes.";
        return false;
    }
    // remove identity op
    if (input.get_shape() == node->get_output_shape(0))
    {
        return replace_output_update_name(node->output(0), input);
    }
    // eliminate redundant reshape, squeeze, or unsqueeze
    if (is_type<opset3::Squeeze>(input.get_node()) ||
        is_type<opset3::Unsqueeze>(input.get_node()) || is_type<opset3::Reshape>(input.get_node()))
    {
        auto shape = node->get_output_shape(0);
        std::vector<int64_t> vi;
        vi.assign(shape.begin(), shape.end());
        auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
        auto new_reshape =
            make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
        return replace_node_update_name(node, new_reshape);
    }

    return false;
}
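// Worked example (illustrative, not in the original source): Reshape({2, 3, 4} -> {6, 4})
// followed by Reshape(-> {2, 12}) collapses into a single Reshape(data, {2, 12}); a Reshape
// whose output shape equals its input shape is dropped entirely.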

static size_t count_unknown_dims(const PartialShape& ps)
{
    size_t rc = 0;
    if (ps.is_static())
    {
        return rc;
    }
    for (auto i = 0; i < ps.rank().get_length(); i++)
    {
        if (ps[i].is_dynamic())
        {
            rc += 1;
        }
    }
    return rc;
}

static bool replace_squeeze_unsqueeze(const std::shared_ptr<Node>& node)
{
    auto shape_ps = node->get_output_partial_shape(0);
    if (shape_ps.rank().get_length() == 0)
    {
        return false;
    }
    if (count_unknown_dims(shape_ps) > 1)
    {
        return false;
    }
    std::vector<int64_t> target_shape;
    for (auto i = 0; i < shape_ps.rank().get_length(); i++)
    {
        if (shape_ps[i].is_dynamic())
        {
            target_shape.emplace_back(-1);
        }
        else
        {
            target_shape.emplace_back(shape_ps[i].get_length());
        }
    }

    shared_ptr<Node> reshape;
    auto input = node->input_value(0).get_node_shared_ptr();
    auto pat =
        opset3::Constant::create<int64_t>(element::i64, Shape{target_shape.size()}, target_shape);

    if (is_type<opset3::Reshape>(input) || is_type<opset3::Squeeze>(input) ||
        is_type<opset3::Unsqueeze>(input))
    {
        reshape = make_shared<opset3::Reshape>(input->input_value(0), pat, false);
    }
    else
    {
        reshape = make_shared<opset3::Reshape>(node->input_value(0), pat, false);
    }

    // skip if reshape is nop
    if (reshape->get_input_partial_shape(0).same_scheme(shape_ps))
    {
        return replace_output_update_name(node->output(0), reshape->input_value(0));
    }
    else
    {
        return replace_node_update_name(node, reshape);
    }
    return false;
}
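// Worked example (illustrative, not in the original source): a Squeeze or Unsqueeze whose
// output partial shape is {?, 3, 1} has a single unknown dimension, so it is replaced by a
// Reshape with pattern {-1, 3, 1}; with two or more unknown dimensions the target pattern
// would be ambiguous, and the node is left untouched.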

static std::vector<int64_t> get_unsqueeze_axes(const PartialShape& data_shape,
                                               const PartialShape& out_shape)
{
    std::vector<int64_t> axes;
    size_t i = 0;
    for (auto o = 0; o < out_shape.rank().get_length(); o++)
    {
        if (i < data_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[o]))
        {
            i += 1;
            continue;
        }
        if (out_shape[o].is_static() && out_shape[o] == 1)
        {
            axes.push_back(o);
        }
    }
    return axes;
}

static std::vector<int64_t> get_squeeze_axes(const PartialShape& data_shape,
                                             const PartialShape& out_shape)
{
    std::vector<int64_t> axes;
    size_t out_i = 0;
    for (auto i = 0; i < data_shape.rank().get_length(); i++)
    {
        if (out_i < out_shape.rank().get_length() && data_shape[i].same_scheme(out_shape[out_i]))
        {
            out_i += 1;
            continue;
        }
        if (data_shape[i].is_static() && data_shape[i] == 1)
        {
            axes.push_back(i);
        }
    }
    return axes;
}
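// Worked example (illustrative, not in the original source):
//   get_unsqueeze_axes({2, 3}, {1, 2, 3, 1})  -> {0, 3}   (positions where size-1 dims are inserted)
//   get_squeeze_axes({2, 1, 3, 1}, {2, 3})    -> {1, 3}   (size-1 axes that are removed)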

static bool eliminate_unsqueeze(const std::shared_ptr<Node>& node)
{
    auto out_shape = node->get_output_partial_shape(0);
    // try to replace all squeeze/unsqueeze with reshape
    if (out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2)
    {
        return replace_squeeze_unsqueeze(node);
    }

    auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(node);
    auto input = unsqueeze->input_value(0).get_node_shared_ptr();
    auto squeeze = as_type_ptr<opset3::Squeeze>(input);
    auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
        auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
        auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
        if (unsqueeze->get_output_partial_shape(0).same_scheme(
                new_unsq->get_output_partial_shape(0)))
        {
            return replace_node_update_name(unsqueeze, new_unsq);
        }
        return false;
    };
    // eliminate redundant squeeze->unsqueeze
    if (squeeze)
    {
        const auto& data_shape = squeeze->input_value(0).get_partial_shape();
        if (ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
                                      unsqueeze->input_value(1).get_node_shared_ptr()))
        {
            return replace_output_update_name(unsqueeze->output(0), squeeze->input_value(0));
        }
        if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
        {
            return false;
        }
        if (out_shape.rank().get_length() > data_shape.rank().get_length())
        {
            // check if single unsqueeze can handle this
            auto axes = get_unsqueeze_axes(data_shape, out_shape);
            if (axes.size() + data_shape.rank().get_length() == out_shape.rank().get_length())
            {
                return replace_unsqueeze_only(axes);
            }
        }
        if (out_shape.rank().get_length() < data_shape.rank().get_length())
        {
            // check if single squeeze can handle this
            auto axes = get_squeeze_axes(data_shape, out_shape);
            if (data_shape.rank().get_length() - axes.size() == out_shape.rank().get_length())
            {
                auto axes_const =
                    opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
                auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
                if (unsqueeze->get_output_partial_shape(0).same_scheme(
                        new_sq->get_output_partial_shape(0)))
                {
                    return replace_node_update_name(unsqueeze, new_sq);
                }
                return false;
            }
        }
        return false;
    }
    // eliminate redundant unsqueeze->unsqueeze
    auto unsqueeze_i = as_type_ptr<opset3::Unsqueeze>(input);
    if (unsqueeze_i)
    {
        const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
        if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
        {
            return false;
        }
        auto axes = get_unsqueeze_axes(data_shape, out_shape);
        return replace_unsqueeze_only(axes);
    }

    return false;
}
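// A minimal illustration (not part of the original file; it relies on the includes already in
// this translation unit): a Squeeze immediately undone by an Unsqueeze with the same axes is
// folded away by the pass.
static bool squeeze_unsqueeze_cancel_example()
{
    auto p = make_shared<opset3::Parameter>(element::f32, Shape{2, 1, 3});
    auto axes = opset3::Constant::create<int64_t>(element::i64, Shape{1}, {1});
    auto squeeze = make_shared<opset3::Squeeze>(p, axes);           // {2, 3}
    auto unsqueeze = make_shared<opset3::Unsqueeze>(squeeze, axes); // back to {2, 1, 3}
    auto f = make_shared<Function>(NodeVector{unsqueeze}, ParameterVector{p});
    // Returns true: after the pass the function output is wired directly to the Parameter.
    return pass::NopElimination().run_on_function(f);
}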

static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
{
    auto out_shape = node->get_output_partial_shape(0);
    // try to replace all unsqueeze/squeeze with reshape
    if (out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2)
    {
        return replace_squeeze_unsqueeze(node);
    }

    auto squeeze = as_type_ptr<opset3::Squeeze>(node);
    auto input = squeeze->input_value(0).get_node_shared_ptr();
    auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
        auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
        auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
        if (squeeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0)))
        {
            return replace_node_update_name(squeeze, new_sq);
        }
        return false;
    };
    // eliminate redundant unsqueeze->squeeze
    if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
    {
        PartialShape data_shape;
        if (op::is_parameter(input))
        {
            data_shape = unsqueeze->input(0).get_partial_shape();
        }
        else
        {
            data_shape = input->input(0).get_partial_shape();
        }
        if (ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
                                      squeeze->input_value(1).get_node_shared_ptr()))
        {
            return replace_output_update_name(squeeze->output(0), unsqueeze->input_value(0));
        }
        if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
        {
            return false;
        }
        if (out_shape.rank().get_length() < data_shape.rank().get_length())
        {
            // check if single squeeze can handle this
            auto axes = get_squeeze_axes(data_shape, out_shape);
            if (data_shape.rank().get_length() == out_shape.rank().get_length() + axes.size())
            {
                return replace_squeeze_only(axes);
            }
        }
        if (out_shape.rank().get_length() > data_shape.rank().get_length())
        {
            // check if single unsqueeze can handle this
            auto axes = get_unsqueeze_axes(data_shape, out_shape);
            if (data_shape.rank().get_length() + axes.size() == out_shape.rank().get_length())
            {
                auto axes_const =
                    opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
                auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
                if (squeeze->get_output_partial_shape(0).same_scheme(
                        new_unsq->get_output_partial_shape(0)))
                {
                    replace_output_update_name(squeeze, new_unsq);
                    return true;
                }
            }
        }
        return false;
    }
    // eliminate redundant squeeze->squeeze
    if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
    {
        PartialShape data_shape;
        if (op::is_parameter(input))
        {
            data_shape = squeeze_i->input(0).get_partial_shape();
        }
        else
        {
            data_shape = input->input(0).get_partial_shape();
        }
        if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic())
        {
            return false;
        }
        auto axes = get_squeeze_axes(data_shape, out_shape);
        return replace_squeeze_only(axes);
    }
    return false;
}
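// Worked example (illustrative, not in the original source): when the output still has two or
// more dynamic dimensions, the Reshape-based rewrite above is skipped and the pairwise rules
// apply instead: an Unsqueeze followed by a Squeeze with identical axes constants is
// short-circuited to the Unsqueeze's input, and a partial round trip is collapsed into a
// single Squeeze or Unsqueeze covering only the axes that actually change.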

static bool eliminate_stop_gradient(const std::shared_ptr<Node>& node)
{
    replace_output_update_name(node->output(0), node->input_value(0));
    return true;
}

bool pass::NopElimination::run_on_function(std::shared_ptr<Function> function)
{
    static const std::unordered_map<NodeTypeInfo, std::function<bool(const std::shared_ptr<Node>&)>>
        dispatcher{{TI(op::v0::Pad), &eliminate_nop},
                   {TI(opset3::Pad), &eliminate_nop},
                   {TI(op::v0::Sum), &eliminate_sum},
                   {TI(opset3::Convert), &eliminate_convert},
                   {TI(op::v0::Slice), &eliminate_nop},
                   {TI(op::v0::StopGradient), &eliminate_stop_gradient},
                   {TI(opset3::Reshape), &eliminate_reshape_v1},
                   {TI(opset3::Concat), &eliminate_concat},
                   {TI(opset3::Squeeze), &eliminate_squeeze},
                   {TI(opset3::Unsqueeze), &eliminate_unsqueeze},
                   {TI(op::v0::Broadcast), &eliminate_nop}};

    bool clobbered = false;

    for (const auto& n : function->get_ops())
    {
        // Work around a warning [-Wpotentially-evaluated-expression]
        const Node& node = *n;
        auto handler = dispatcher.find(node.get_type_info());
        if (handler != dispatcher.end())
        {
            clobbered = handler->second(n) || clobbered;
        }
    }

    return clobbered;
}