Remove old transformation passes (#1463)
parent 3be1f6b6fa
commit 534fe35c0a
@@ -17,7 +17,6 @@
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/algebraic_simplification.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/convert_opset1_to_legacy/convert_convolutions.hpp>
#include <ngraph_ops/convolution_ie.hpp>
@@ -17,7 +17,6 @@
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/algebraic_simplification.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/convert_opset1_to_legacy/convert_convolutions.hpp>
#include <ngraph_ops/convolution_ie.hpp>
@@ -18,7 +18,6 @@
#include <pybind11/stl.h>

#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "pyngraph/passes/manager.hpp"

namespace py = pybind11;
@@ -28,6 +27,4 @@ void regclass_pyngraph_passes_Manager(py::module m)
    py::class_<ngraph::pass::Manager, std::shared_ptr<ngraph::pass::Manager>> manager(m, "Manager");
    manager.doc() = "ngraph.impl.pass.Manager wraps ngraph::pass::Manager";
    manager.def("run_passes", &ngraph::pass::Manager::run_passes);
    manager.def("register_pass",
                &ngraph::pass::Manager::register_pass<ngraph::pass::ReshapeElimination>);
}
@@ -454,13 +454,6 @@ set (SRC
    partial_shape.hpp
    pass/algebraic_simplification.cpp
    pass/algebraic_simplification.hpp
    pass/assign_layout.hpp
    pass/implicit_broadcast_elimination.hpp
    pass/implicit_broadcast_elimination.cpp
    pass/batch_fusion.hpp
    pass/batch_fusion.cpp
    pass/common_function_collection.cpp
    pass/common_function_collection.hpp
    pass/constant_folding_arithmetic_reduction.cpp
    pass/constant_folding_convert.cpp
    pass/constant_folding_dequantize.cpp
@@ -480,62 +473,25 @@ set (SRC
    pass/constant_folding_transpose.cpp
    pass/constant_folding.cpp
    pass/constant_folding.hpp
    pass/constant_to_broadcast.cpp
    pass/convert_fp32_to_fp16.hpp
    pass/convert_fp32_to_fp16.cpp
    pass/core_fusion.cpp
    pass/core_fusion.hpp
    pass/cse.cpp
    pass/cse.hpp
    pass/dump_sorted.cpp
    pass/dump_sorted.hpp
    pass/dyn_elimination.cpp
    pass/dyn_elimination.hpp
    pass/fused_op_decomposition.cpp
    pass/fused_op_decomposition.hpp
    pass/get_output_element_elimination.cpp
    pass/get_output_element_elimination.hpp
    pass/graph_rewrite.cpp
    pass/graph_rewrite.hpp
    pass/like_replacement.cpp
    pass/like_replacement.hpp
    pass/liveness.cpp
    pass/liveness.hpp
    pass/manager.cpp
    pass/manager.hpp
    pass/manager_state.hpp
    pass/memory_layout.cpp
    pass/memory_layout.hpp
    pass/memory_visualize.cpp
    pass/memory_visualize.hpp
    pass/nop_elimination.cpp
    pass/nop_elimination.hpp
    pass/pass.cpp
    pass/pass.hpp
    pass/pass_config.cpp
    pass/pass_config.hpp
    pass/propagate_cacheability.cpp
    pass/propagate_cacheability.hpp
    pass/reshape_elimination.cpp
    pass/reshape_elimination.hpp
    pass/reshape_sinking.cpp
    pass/reshape_sinking.hpp
    pass/serialize.cpp
    pass/serialize.hpp
    pass/shape_relevance.cpp
    pass/shape_relevance.hpp
    pass/validate_graph.cpp
    pass/validate_graph.hpp
    pass/validate.cpp
    pass/validate.hpp
    pass/visualize_tree.cpp
    pass/visualize_tree.hpp
    pass/zero_dim_tensor_elimination.cpp
    pass/zero_dim_tensor_elimination.cpp
    pass/zero_dim_tensor_elimination.hpp
    pass/zero_dim_tensor_elimination.hpp
    pass/concat_fusion.hpp
    pass/concat_fusion.cpp
    pass/pass_util.hpp
    pass/pass_util.cpp
    pattern/matcher.cpp
@@ -1,58 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <exception>
#include <sstream>

#include "ngraph/descriptor/output.hpp"
#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        template <typename LT>
        class AssignLayout : public NodePass
        {
        public:
            virtual bool run_on_node(std::shared_ptr<Node> node) override
            {
                try
                {
                    for (size_t i = 0; i < node->get_output_size(); ++i)
                    {
                        auto tv = &node->output(i).get_tensor();
                        if (nullptr == tv->get_tensor_layout())
                        {
                            auto layout = std::make_shared<LT>(*tv);
                            tv->set_tensor_layout(layout);
                        }
                    }
                }
                catch (const std::exception& e)
                {
                    std::stringstream ss;
                    ss << "Error with node " << *node << ": ";
                    ss << e.what();
                    throw std::invalid_argument(ss.str());
                }
                return false;
            }
        };
    }
}
@@ -1,189 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <array>
#include <map>
#include <memory>
#include <numeric>
#include <stack>
#include <typeindex>
#include <unordered_map>

#include "batch_fusion.hpp"

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/group_conv.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"

using namespace ngraph;

#define TI(x) std::type_index(typeid(x))

std::shared_ptr<Node> set_or_check_if_same(std::shared_ptr<Node> oldn, std::shared_ptr<Node> newn)
{
    if (!oldn)
    {
        return newn;
    }
    else
    {
        if (oldn != newn)
        {
            NGRAPH_DEBUG << " different data nodes";
            return nullptr;
        }
        return oldn;
    }
}

static bool is_trivial_convolution(std::shared_ptr<op::Convolution> conv)
{
    Strides stride_1{1, 1};
    CoordinateDiff pad_0{0, 0};
    return conv->get_window_dilation_strides() == stride_1 &&
           conv->get_data_dilation_strides() == stride_1 && conv->get_padding_above() == pad_0 &&
           conv->get_padding_below() == pad_0;
}

std::shared_ptr<Node> fuse_group_convolution(const std::shared_ptr<Node>& n)
{
    Shape win_size_1{1, 1, 1, 1};
    auto data_label = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 4, 9});
    auto weights_label = std::make_shared<pattern::op::Label>(element::f32, Shape{4, 2, 3});

    auto slice_data = std::make_shared<op::Slice>(
        data_label, Coordinate{0, 0, 0}, Coordinate{1, 2, 9}, Strides{1, 1, 1});
    auto slice_weights = std::make_shared<op::Slice>(
        weights_label, Coordinate{0, 0, 0}, Coordinate{2, 2, 3}, Strides{1, 1, 1});

    auto slice_weights_label =
        std::make_shared<pattern::op::Label>(slice_weights, nullptr, NodeVector{slice_weights});
    auto conv = std::make_shared<op::Convolution>(slice_data, slice_weights_label);
    auto matcher = std::make_shared<pattern::Matcher>(conv);

    NGRAPH_DEBUG << "In simplify_concat (group convolution) for " << n->get_name();

    std::shared_ptr<Node> data;
    std::shared_ptr<Node> weights;

    auto concat = std::static_pointer_cast<op::Concat>(n);
    std::shared_ptr<op::Convolution> sconv;

    NodeVector slices;

    const size_t CHANNEL = 1;
    if (concat->get_concatenation_axis() != CHANNEL)
    {
        NGRAPH_DEBUG << "concatenating on an axis different from channel";
        return {nullptr};
    }

    for (auto val : n->input_values())
    {
        auto arg = val.get_node_shared_ptr();
        if (!matcher->match(arg))
        {
            NGRAPH_DEBUG << arg->get_name() << " doesn't match";
            return {nullptr};
        }

        sconv = std::static_pointer_cast<op::Convolution>(arg);

        if (arg->get_input_shape(0).size() != 4)
        {
            NGRAPH_DEBUG << "convolution data's rank isn't equal to 4";
            return {nullptr};
        }

        if (!is_trivial_convolution(sconv))
        {
            NGRAPH_DEBUG << arg->get_name() << " isn't trivial convolution";
            return {nullptr};
        }

        auto pattern_map = matcher->get_pattern_map();
        data = set_or_check_if_same(data, pattern_map[data_label]);
        weights = set_or_check_if_same(weights, pattern_map[weights_label]);

        if (!data || !weights)
        {
            NGRAPH_DEBUG << "data or weights nodes are different among slices";
            return {nullptr};
        }

        const size_t IC = 1;
        auto slice = pattern_map[slice_weights_label];
        if (weights->get_shape().at(IC) != slice->get_shape().at(IC))
        {
            slices.push_back(slice);
        }
    }

    // TF-flavoured group convolution needs channels re-arranged
    // MKLDNN requires group slicing to be done on OC
    // MKLDNN [4,2,-]
    // ordering w00 w01 w10 w11 w20 w21 w30 w31 produces g00 g01 g10 g11
    // whereas
    // TF [2,4,-]
    // ordering w00 w01 w02 w03 w10 w11 w12 w13 produces g00 g10 g01 g11
    const size_t CONCAT_AXIS_OC = 0;
    if (!slices.empty())
    {
        weights = std::make_shared<op::Concat>(slices, CONCAT_AXIS_OC);
    }

    auto new_conv = std::make_shared<op::GroupConvolution>(data,
                                                           weights,
                                                           sconv->get_window_movement_strides(),
                                                           sconv->get_window_dilation_strides(),
                                                           sconv->get_padding_below(),
                                                           sconv->get_padding_above(),
                                                           sconv->get_data_dilation_strides(),
                                                           n->input_values().size());

    return move(new_conv);
}

bool ngraph::pass::BatchFusion::run_on_function(std::shared_ptr<Function> func)
{
    bool modified = false;

    for (auto n : func->get_ordered_ops())
    {
        const Node& node = *n;
        if (TI(node) == TI(op::Concat))
        {
            if (m_fusion_type.is_set(FusionType::REGULAR_FUSIONS))
            {
                if (auto fused_conv = fuse_group_convolution(n))
                {
                    func->replace_node(n, fused_conv);
                    modified = true;
                }
            }
        }
    }
    return modified;
}
@@ -1,40 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class NGRAPH_API BatchFusion : public ngraph::pass::FunctionPass
        {
        public:
            BatchFusion(FusionTypeMask type = FusionType::ALL_FUSIONS)
                : FunctionPass()
                , m_fusion_type(type)
            {
                set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
            }
            virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

        private:
            FusionTypeMask m_fusion_type;
        };
    }
}
@@ -1,113 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <sstream>

#include "common_function_collection.hpp"
#include "ngraph/op/util/op_types.hpp"

using namespace std;
using namespace ngraph;

pass::CommonFunctionCollection::CommonFunctionCollection(function<string(Node&, string)> emitter,
                                                         unordered_map<Node*, Node*>& result_map,
                                                         string& emitted_functions)
    : m_emit_op_as_function(emitter)
    , m_node_function_map(result_map)
    , m_emitted_functions(emitted_functions)
{
}

pass::CommonFunctionCollection::~CommonFunctionCollection()
{
}

bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>& functions)
{
    // This for loop creates a collection of functions that are called more than once
    // and emits them as globally callable functions.

    // match_function_map `key` contains the entire string of the function emitted for the
    // `value` Node*
    unordered_map<string, Node*> match_function_map;
    stringstream ss;
    const string function_name = "__f__";
    for (const shared_ptr<Function>& current_function : functions)
    {
        for (const shared_ptr<Node>& n : current_function->get_ordered_ops())
        {
            if (op::is_constant(n) || op::is_parameter(n))
            {
                continue;
            }
            if (op::is_op(n))
            {
                auto op = std::static_pointer_cast<op::Op>(n);
                auto annotations = op->get_op_annotations();
                // If an op is passed through, do not add it to the common function
                // collection so that the emitter can decide to eliminate it if desired
                if (annotations && annotations->get_in_place_oi_pairs().size() > 0)
                {
                    continue;
                }
            }

            Node& node = *n;

            // First emit the op as a function, something like this:
            // static void __f__(float* _arg0, float *_out1)
            // {
            //     op specific code here
            // }
            //
            // Then do a simple string compare in match_function_map to see if there is
            // another op that emits the exact same code.
            // If a match is found then the current node is mapped to call the original node's
            // function and the original node is *also* mapped to call the original node's function.
            // We also emit the static function declaration to m_emitted_functions when the match
            // is found the first time.
            string match_function = m_emit_op_as_function(node, function_name);
            auto it = match_function_map.find(match_function);
            if (it != match_function_map.end())
            {
                m_node_function_map.insert({&node, it->second});
                if (m_node_function_map.find(it->second) == m_node_function_map.end())
                {
                    m_node_function_map.insert({it->second, it->second});

                    // All of the functions are created with the same name `__f__` so here
                    // we rename it to something unique so we can compile everything when done.
                    auto offset = match_function.find(function_name);
                    string emitted_function = match_function;
                    string match_function_name = create_function_name(*it->second);
                    emitted_function.replace(offset, function_name.size(), match_function_name);
                    ss << emitted_function << "\n";
                }
            }
            else
            {
                match_function_map.insert({match_function, &node});
            }
        }
    }
    m_emitted_functions = ss.str();
    return false;
}

string pass::CommonFunctionCollection::create_function_name(const Node& node)
{
    return "func_" + node.get_name();
}
@@ -1,62 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <unordered_map>

#include "ngraph/code_writer.hpp"
#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class CommonFunctionCollection;
    }
}

class NGRAPH_API ngraph::pass::CommonFunctionCollection : public ModulePass
{
public:
    /// \brief Create the CommonFunctionCollection pass
    /// \param function_emitter - This is a function that takes a reference to a Node and a string.
    ///     The string is the name of the emitted function and the body of the function is
    ///     the code for the op.
    /// \param result_map - This is a mapping of source node -> emitted static function node, where
    ///     the key is the source node and the value is the emitted static function node. The
    ///     name of the function to call is create_function_name(<emitted static function node>)
    /// \param emitted_functions - string to contain the emitted code for all of the static
    ///     functions.
    CommonFunctionCollection(std::function<std::string(Node&, std::string)> function_emitter,
                             std::unordered_map<Node*, Node*>& result_map,
                             std::string& emitted_functions);

    virtual ~CommonFunctionCollection() override;

    bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

    /// \brief Construct the name of the function to call for this op
    /// \param node - Node used to construct the function name. This node is the `value` of the
    ///     result_map passed to the pass's constructor.
    /// \return string containing the name of the function to be called
    static std::string create_function_name(const Node& node);

private:
    std::function<std::string(Node&, std::string)> m_emit_op_as_function;
    std::unordered_map<Node*, Node*>& m_node_function_map;
    std::string& m_emitted_functions;
};
@@ -1,282 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "concat_fusion.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"

using namespace ngraph;

namespace
{
    bool check_self_concat_op(const std::shared_ptr<Node>& op)
    {
        auto input_vals = op->input_values();
        std::set<std::shared_ptr<Node>> input_args_set;
        for (auto val : input_vals)
            input_args_set.emplace(val.get_node_shared_ptr());
        return (input_args_set.size() == 1);
    }

    bool check_concat_axis_dim_value(const std::shared_ptr<Node>& concat_op)
    {
        auto input_shape = concat_op->get_input_shape(0);
        size_t concat_axis =
            std::static_pointer_cast<op::Concat>(concat_op)->get_concatenation_axis();

        return (input_shape[concat_axis] == 1);
    }

    bool check_concat_has_no_fan_out(const std::shared_ptr<Node>& op)
    {
        auto no_fan_out = ngraph::pass::get_no_fan_out_function();
        return no_fan_out(op);
    }

    bool valid_self_concat(const std::shared_ptr<Node>& Op)
    {
        if (!check_self_concat_op(Op))
        {
            NGRAPH_DEBUG << "self_concat_fusion: Matcher matched " << Op->get_name()
                         << " but it is not a self concat\n";
            return false;
        }

        if (!check_concat_axis_dim_value(Op))
        {
            NGRAPH_DEBUG << "self_concat_fusion: Input shape value along concat axis of "
                         << Op->get_name() << " is not equal to 1\n";
            return false;
        }

        return true;
    }

    std::vector<size_t> get_concatenation_axis_vector(const NodeVector& bounded_concat_ops)
    {
        std::vector<size_t> concat_axis_vec;
        for (auto iter : bounded_concat_ops)
        {
            auto concat_op = std::static_pointer_cast<op::Concat>(iter);
            concat_axis_vec.push_back(concat_op->get_concatenation_axis());
        }
        return concat_axis_vec;
    }
}

void pass::ConcatElimination::construct_concat_elimination()
{
    auto op_label = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3});
    auto concat = std::make_shared<op::Concat>(NodeVector{op_label}, 0);
    auto concat_label = std::make_shared<pattern::op::Label>(concat, nullptr, NodeVector{concat});

    auto callback = [op_label](pattern::Matcher& m) {
        NGRAPH_DEBUG
            << "concat_elimination: In callback for construct_concat_elimination against node = "
            << m.get_match_root()->get_name();
        auto pattern_map = m.get_pattern_map();
        auto op = pattern_map[op_label];

        auto root = as_type_ptr<op::Concat>(m.get_match_root());
        if (root && (root->get_input_shape(0) == root->get_output_shape(0)))
        {
            NGRAPH_DEBUG << " eliminated " << m.get_match_root() << "\n";
            replace_node(m.get_match_root(), op);

            return true;
        }
        NGRAPH_DEBUG << " Incorrect match in callback\n";
        return false;
    };

    auto m = std::make_shared<pattern::Matcher>(concat_label, "ConcatElimination");
    this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}

bool ngraph::pass::SelfConcatFusion::run_on_function(std::shared_ptr<Function> function)
{
    bool modify_graph = false;
    auto has_multiple_inputs = [](std::shared_ptr<Node> n) {
        auto input_size = n->get_input_size();
        auto root = as_type_ptr<op::Concat>(n);
        return (root && input_size > 1);
    };

    auto print_state_of_bounded_vectors = [this]() -> std::string {
        std::stringstream ss;
        ss << "-----------------------------------------------------------" << std::endl;
        ss << "State of bounded pattern node vectors: " << std::endl;
        ss << "-----------------------------------------------------------" << std::endl;
        ss << "Number of pattern node vectors: " << this->m_concat_pattern_vectors.size()
           << std::endl;
        size_t c = 0;
        for (auto iter : this->m_concat_pattern_vectors)
        {
            ss << "For vector " << c << std::endl;
            auto iter_node_vec = iter;
            ss << "concat_op_vector: ";
            for (auto it : iter_node_vec)
            {
                ss << it->get_name() << " ";
            }
            ss << std::endl;
            c++;
        }
        ss << "-----------------------------" << std::endl;
        return ss.str();
    };

    auto concat_op_label =
        std::make_shared<pattern::op::Label>(element::f32, Shape{1, 3}, has_multiple_inputs);
    auto matcher = std::make_shared<pattern::Matcher>(concat_op_label);
    for (auto n : function->get_ordered_ops())
    {
        construct_concat_patterns(matcher, concat_op_label, n);
    }

    NGRAPH_DEBUG << print_state_of_bounded_vectors();

    for (auto concat_op_pattern_node_vector : this->m_concat_pattern_vectors)
    {
        modify_graph = replace_patterns(concat_op_pattern_node_vector);
    }

    return modify_graph;
}

void ngraph::pass::SelfConcatFusion::construct_concat_patterns(
    const std::shared_ptr<pattern::Matcher>& matcher,
    const std::shared_ptr<pattern::op::Label>& concat_op_label,
    const std::shared_ptr<Node>& n)
{
    if (matcher->match(n))
    {
        auto concat_op = matcher->get_pattern_map()[concat_op_label];
        if (!is_type<op::Concat>(concat_op))
        {
            NGRAPH_DEBUG << "self_concat_fusion: Pattern matcher matched incorrect op. Matched "
                         << concat_op->get_name() << " instead of a self concat";
            return;
        }
        if (!valid_self_concat(concat_op))
        {
            NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name()
                         << " is not a valid self concat\n";
            return;
        }
        else
        {
            NGRAPH_DEBUG << "self_concat_fusion: " << concat_op->get_name()
                         << " is a VALID self concat\n";
        }

        auto& concat_vectors = this->m_concat_pattern_vectors;
        if (concat_vectors.empty())
        {
            concat_vectors.push_back(NodeVector{concat_op});
        }
        else
        {
            update_concat_pattern_vectors(concat_op);
        }
    }
}

void ngraph::pass::SelfConcatFusion::update_concat_pattern_vectors(
    const std::shared_ptr<Node>& concat_op)
{
    bool concat_source_found = false;
    for (auto& concat_pattern_vec : this->m_concat_pattern_vectors)
    {
        auto last_op_in_pattern_vec = concat_pattern_vec.back();
        if ((concat_op->input_value(0).get_node_shared_ptr() == last_op_in_pattern_vec) &&
            (check_concat_has_no_fan_out(last_op_in_pattern_vec)))
        {
            concat_pattern_vec.push_back(concat_op);
            concat_source_found = true;
            break;
        }
    }

    if (!concat_source_found)
    {
        this->m_concat_pattern_vectors.push_back(NodeVector{concat_op});
    }
}

void ngraph::pass::SelfConcatFusion::remove_single_concat_op_pattern()
{
    auto iter = m_concat_pattern_vectors.begin();
    while (iter != m_concat_pattern_vectors.end())
    {
        if (iter->size() == 1)
        {
            iter = m_concat_pattern_vectors.erase(iter);
        }
        else
        {
            iter++;
        }
    }
}

bool ngraph::pass::SelfConcatFusion::replace_patterns(const NodeVector& bounded_concat_ops)
{
    auto scalarize_dim = [](std::vector<size_t> concat_axis_vector,
                            const Shape& input_shape) -> Shape {

        Shape scalarized_shape;
        for (size_t i = 0; i < input_shape.size(); i++)
        {
            auto it = std::find(concat_axis_vector.begin(), concat_axis_vector.end(), i);
            if (it == concat_axis_vector.end())
            {
                scalarized_shape.push_back(input_shape[i]);
            }
        }
        return scalarized_shape;
    };

    auto concat_axis_vector = get_concatenation_axis_vector(bounded_concat_ops);

    auto& first_bounded_concat = (*bounded_concat_ops.begin());
    auto driver_op = first_bounded_concat->input_value(0);
    const Shape& input_shape = first_bounded_concat->get_input_shape(0);

    auto scalarized_shape = scalarize_dim(concat_axis_vector, input_shape);
    AxisVector axis_order = get_default_order(input_shape);
    auto reshape = std::make_shared<op::Reshape>(driver_op, axis_order, scalarized_shape);
    auto last_bounded_concat_op = bounded_concat_ops.back();
    auto broadcast_out_shape = last_bounded_concat_op->get_shape();
    auto broadcast =
        std::make_shared<op::Broadcast>(reshape, broadcast_out_shape, concat_axis_vector);

    replace_node(last_bounded_concat_op, broadcast);
    return true;
}
@@ -1,61 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/pass_util.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"

namespace ngraph
{
    namespace pass
    {
        class ConcatElimination;
        class SelfConcatFusion;
    }
}

class NGRAPH_API ngraph::pass::ConcatElimination : public ngraph::pass::GraphRewrite
{
public:
    ConcatElimination()
        : GraphRewrite()
    {
        construct_concat_elimination();
    }

private:
    void construct_concat_elimination();
};

class NGRAPH_API ngraph::pass::SelfConcatFusion : public ngraph::pass::FunctionPass
{
public:
    SelfConcatFusion() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
    virtual bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

private:
    void update_concat_pattern_vectors(const std::shared_ptr<Node>&);
    void remove_single_concat_op_pattern();
    void construct_concat_patterns(const std::shared_ptr<pattern::Matcher>&,
                                   const std::shared_ptr<pattern::op::Label>&,
                                   const std::shared_ptr<Node>&);
    bool replace_patterns(const NodeVector&);
    std::vector<NodeVector> m_concat_pattern_vectors;
};
@@ -1,51 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"

using namespace std;
using namespace ngraph;

bool pass::ConstantToBroadcast::run_on_node(shared_ptr<Node> node)
{
    const size_t minimum_size_of_interest = 32;
    bool modified = false;
    if (node->description() == "Constant")
    {
        auto constant = static_pointer_cast<op::Constant>(node);
        size_t size = shape_size(constant->get_shape());
        if (size > minimum_size_of_interest)
        {
            if (constant->get_all_data_elements_bitwise_identical())
            {
                auto scalar_constant = make_shared<op::Constant>(
                    constant->get_element_type(), Shape{}, constant->get_data_ptr());
                AxisSet broadcast_axes;
                for (size_t i = 0; i < constant->get_output_shape(0).size(); i++)
                {
                    broadcast_axes.insert(i);
                }
                auto broadcast = make_shared<op::Broadcast>(
                    scalar_constant, constant->get_output_shape(0), broadcast_axes);
                replace_node(constant, broadcast);
            }
        }
    }
    return modified;
}
@@ -1,33 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class ConstantToBroadcast;
    }
}

class NGRAPH_API ngraph::pass::ConstantToBroadcast : public NodePass
{
public:
    bool run_on_node(std::shared_ptr<ngraph::Node>) override;
};
@ -1,722 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ngraph/pass/core_fusion.hpp"
|
||||
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/batch_norm.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/convert.hpp"
|
||||
#include "ngraph/op/convolution.hpp"
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/exp.hpp"
|
||||
#include "ngraph/op/log.hpp"
|
||||
#include "ngraph/op/max.hpp"
|
||||
#include "ngraph/op/max_pool.hpp"
|
||||
#include "ngraph/op/maximum.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/negative.hpp"
|
||||
#include "ngraph/op/not_equal.hpp"
|
||||
#include "ngraph/op/one_hot.hpp"
|
||||
#include "ngraph/op/pad.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/relu.hpp"
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/op/sigmoid.hpp"
|
||||
#include "ngraph/op/softmax.hpp"
|
||||
#include "ngraph/op/sqrt.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
#include "ngraph/pattern/op/label.hpp"
|
||||
#include "ngraph/pattern/op/skip.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
static shared_ptr<Node> construct_constant_node(int n)
|
||||
{
|
||||
return op::Constant::create(element::f32, Shape{}, {n});
|
||||
}
|
||||
|
||||
void pass::CoreFusion::construct_relu()
|
||||
{
|
||||
auto iconst0 = construct_constant_node(0);
|
||||
auto val = make_shared<pattern::op::Label>(iconst0);
|
||||
auto zero = make_shared<pattern::op::Label>(iconst0, nullptr, NodeVector{iconst0});
|
||||
|
||||
auto skip_broadcast = make_shared<pattern::op::Skip>(zero, pattern::has_class<op::Broadcast>());
|
||||
auto max = make_shared<op::Maximum>(skip_broadcast, val);
|
||||
|
||||
auto callback = [val, zero](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In a callback for construct_relu against "
|
||||
<< m.get_match_root()->get_name();
|
||||
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
auto mzero = m.get_pattern_map()[zero];
|
||||
if (!is_zero(mzero))
|
||||
{
|
||||
NGRAPH_DEBUG << "zero constant = " << mzero->get_name() << " not equal to 0\n";
|
||||
return false;
|
||||
}
|
||||
auto mpattern = m.get_match_root();
|
||||
|
||||
auto cg = shared_ptr<Node>(new op::Relu(pattern_map[val]));
|
||||
replace_node(m.get_match_root(), cg);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(max, "CoreFusion.Relu");
|
||||
this->add_matcher(m, callback, all_pass_property_off);
|
||||
}
|
||||
|
||||
void pass::CoreFusion::construct_sigmoid()
|
||||
{
|
||||
// construct variance
|
||||
auto input = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
|
||||
auto neg_input = make_shared<op::Negative>(input);
|
||||
auto exp_neg_input = make_shared<op::Exp>(neg_input);
|
||||
|
||||
auto constant = make_shared<pattern::op::Label>(element::f32, Shape{3, 4});
|
||||
auto skip_broadcast =
|
||||
make_shared<pattern::op::Skip>(constant, pattern::has_class<op::Broadcast>());
|
||||
|
||||
auto add_exp = make_shared<op::Add>(exp_neg_input, skip_broadcast);
|
||||
auto divide_1_over_exp = make_shared<op::Divide>(skip_broadcast, add_exp);
|
||||
|
||||
// Define a call back that needs to called once the DFG matches the pattern
|
||||
auto callback = [input, constant](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In a callback for construct_fprop_sigmoid pattern against "
|
||||
<< m.get_match_root()->get_name();
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
if (m.get_match_root()->get_element_type() != element::f32)
|
||||
{
|
||||
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
|
||||
<< " type is not float!";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m.get_match_root()->get_output_size() != pattern_map[input]->get_output_size())
|
||||
{
|
||||
NGRAPH_DEBUG << "mpattern = " << m.get_match_root()->get_name()
|
||||
<< "input= " << pattern_map[input]->get_name() << "size dont match!";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!is_one(pattern_map[constant]))
|
||||
{
|
||||
NGRAPH_DEBUG << "Node not constant or not 1";
|
||||
return false;
|
||||
}
|
||||
auto sigmoid_node = make_shared<op::Sigmoid>(pattern_map[input]);
|
||||
replace_node(m.get_match_root(), sigmoid_node);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(divide_1_over_exp, "CoreFusion.Sigmoid");
|
||||
this->add_matcher(m, callback, all_pass_property_off);
|
||||
}
|
||||
|
||||
void pass::CoreFusion::construct_folded_batch_norm()
|
||||
{
|
||||
Shape shape{2, 2, 1, 1};
|
||||
auto input = make_shared<pattern::op::Label>(element::f32, shape);
|
||||
auto filters = make_shared<pattern::op::Label>(element::f32, shape);
|
||||
|
||||
auto pconv = make_shared<op::Convolution>(input,
|
||||
filters,
|
||||
Strides{1, 1},
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
auto mean_shape = Shape{2};
|
||||
auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
|
||||
auto var_shape = Shape{2};
|
||||
auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
|
||||
auto gamma_shape = Shape{2};
|
||||
auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
|
||||
auto beta_shape = Shape{2};
|
||||
auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
|
||||
double eps = 0.001;
|
||||
auto shape_r = Shape{1, 2, 2, 2};
|
||||
auto bn = make_shared<op::BatchNormInference>(pconv, gamma, beta, mean, var, eps);
|
||||
|
||||
auto callback = [input, filters, mean, var, gamma, beta](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In callback for folded batch norm against node = "
|
||||
<< m.get_match_root()->get_name();
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
auto m_bn = static_pointer_cast<op::BatchNormInference>(m.get_match_root());
|
||||
auto m_conv =
|
||||
static_pointer_cast<op::Convolution>(m_bn->input_value(2).get_node_shared_ptr());
|
||||
|
||||
if (m_conv->get_users().size() > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_conv->get_shape().size() != 4)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// new weights = old weights * gamma / sqrt(variance + epsilon)
|
||||
// new biases = -mean * gamma / sqrt(variance + epsilon) + beta
|
||||
|
||||
auto bn_eps = op::Constant::create(element::f32, Shape{}, {m_bn->get_eps_value()});
|
||||
auto var_eps = make_shared<op::Add>(
|
||||
pattern_map[var],
|
||||
make_shared<op::Broadcast>(bn_eps, pattern_map[var]->get_shape(), AxisSet{0}));
|
||||
auto sqrt_var_eps = make_shared<op::Sqrt>(var_eps);
|
||||
|
||||
auto mean_gamma = make_shared<op::Multiply>(pattern_map[mean], pattern_map[gamma]);
|
||||
auto new_biases = make_shared<op::Subtract>(
|
||||
pattern_map[beta], make_shared<op::Divide>(mean_gamma, sqrt_var_eps));
|
||||
auto weight_scaling = make_shared<op::Divide>(pattern_map[gamma], sqrt_var_eps);
|
||||
auto new_weights = make_shared<op::Multiply>(
|
||||
pattern_map[filters],
|
||||
make_shared<op::Broadcast>(
|
||||
weight_scaling, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));
|
||||
|
||||
auto conv = make_shared<op::Convolution>(pattern_map[input],
|
||||
new_weights,
|
||||
m_conv->get_window_movement_strides(),
|
||||
m_conv->get_window_dilation_strides(),
|
||||
m_conv->get_padding_below(),
|
||||
m_conv->get_padding_above(),
|
||||
m_conv->get_data_dilation_strides());
|
||||
auto conv_bias =
|
||||
conv + make_shared<op::Broadcast>(new_biases, conv->get_shape(), AxisSet{0, 2, 3});
|
||||
replace_node(m.get_match_root(), conv_bias);
|
||||
|
||||
return true;
|
||||
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(bn, "CoreFusion.FoldedBatchNorm");
|
||||
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
}
|
||||
|
||||
void pass::CoreFusion::construct_conv_affine_folding()
|
||||
{
|
||||
// A * Conv (input, filters) + B -> ConvBias (input, filters * A_c, B_c)
|
||||
Shape shape{2, 2, 1, 1};
|
||||
auto input = make_shared<pattern::op::Label>(element::f32, shape);
|
||||
auto filters = make_shared<pattern::op::Label>(element::f32, shape);
|
||||
|
||||
auto conv = make_shared<op::Convolution>(input,
|
||||
filters,
|
||||
Strides{1, 1},
|
||||
Strides{1, 1},
|
||||
CoordinateDiff{0, 0},
|
||||
CoordinateDiff{0, 0},
|
||||
Strides{1, 1});
|
||||
auto conv_label = make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});
|
||||
|
||||
auto Ac = make_shared<pattern::op::Label>(element::f32, Shape{2});
|
||||
auto A = make_shared<op::Broadcast>(Ac, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
|
||||
auto A_label = make_shared<pattern::op::Label>(A, nullptr, NodeVector{A});
|
||||
auto Bc = make_shared<pattern::op::Label>(element::f32, Shape{2});
|
||||
auto B = make_shared<op::Broadcast>(Bc, Shape{2, 2, 1, 1}, AxisSet{0, 2, 3});
|
||||
auto B_label = make_shared<pattern::op::Label>(B, nullptr, NodeVector{B});
|
||||
auto multiply = make_shared<op::Multiply>(conv_label, A_label);
|
||||
auto add = make_shared<op::Add>(multiply, B_label);
|
||||
|
||||
auto callback = [input, filters, conv_label, A_label, B_label](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In callback for conv affine folding against node = "
|
||||
<< m.get_match_root()->get_name();
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
|
||||
auto conv_m = static_pointer_cast<op::Convolution>(pattern_map[conv_label]);
|
||||
|
||||
if (conv_m->get_users().size() > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (conv_m->get_shape().size() != 4)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto A_m = static_pointer_cast<op::Broadcast>(pattern_map[A_label]);
|
||||
auto B_m = static_pointer_cast<op::Broadcast>(pattern_map[B_label]);
|
||||
|
||||
// Check if values are being broadcast along channel (2nd) dimension
|
||||
auto is_channel_bcast = [](const shared_ptr<op::Broadcast>& bcast) {
|
||||
|
||||
if (bcast->get_input_shape(0).size() == 1 &&
|
||||
bcast->get_broadcast_axes() == AxisSet{0, 2, 3})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (bcast->get_input_shape(0).size() == 2)
|
||||
{
|
||||
auto input_shape = bcast->get_input_shape(0);
|
||||
if (input_shape[0] == 1 && bcast->get_broadcast_axes() == AxisSet{2, 3})
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
if (!is_channel_bcast(A_m) || !is_channel_bcast(B_m))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
auto get_bcast_input = [](const shared_ptr<op::Broadcast>& bcast) {
|
||||
if (bcast->get_input_shape(0).size() == 1)
|
||||
{
|
||||
return bcast->input_value(0).get_node_shared_ptr();
|
||||
}
|
||||
if (bcast->get_input_shape(0).size() == 2)
|
||||
{
|
||||
Shape bshape{bcast->get_input_shape(0)[1]};
|
||||
return static_pointer_cast<Node>(
|
||||
make_shared<op::Reshape>(bcast->input_value(0), AxisVector{0, 1}, bshape));
|
||||
}
|
||||
throw ngraph_error("Unexpected shape for bcast input");
|
||||
};
|
||||
|
||||
auto Ac_m = get_bcast_input(A_m);
|
||||
|
||||
// new weights = old weights * Ac_m
|
||||
// new biases = Bc_m
|
||||
|
||||
auto filters_n = make_shared<op::Multiply>(
|
||||
pattern_map[filters],
|
||||
make_shared<op::Broadcast>(Ac_m, pattern_map[filters]->get_shape(), AxisSet{1, 2, 3}));
|
||||
|
||||
auto conv_n = make_shared<op::Convolution>(pattern_map[input],
|
||||
filters_n,
|
||||
conv_m->get_window_movement_strides(),
|
||||
conv_m->get_window_dilation_strides(),
|
||||
conv_m->get_padding_below(),
|
||||
conv_m->get_padding_above(),
|
||||
conv_m->get_data_dilation_strides());
|
||||
auto convbias_n = conv_n + B_m;
|
||||
replace_node(m.get_match_root(), convbias_n);
|
||||
|
||||
return true;
|
||||
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(add, "CoreFusion.ConvAffineFolding");
|
||||
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
}
|
||||
|
||||
static bool is_trivial_convolution(shared_ptr<op::Convolution> conv, bool skip_pad_checks = false)
|
||||
{
|
||||
Strides stride_1{1, 1};
|
||||
CoordinateDiff pad_0{0, 0};
|
||||
|
||||
return conv->get_window_dilation_strides() == stride_1 &&
|
||||
conv->get_data_dilation_strides() == stride_1 &&
|
||||
(skip_pad_checks ||
|
||||
(conv->get_padding_above() == pad_0 && conv->get_padding_below() == pad_0));
|
||||
}
|
||||
|
||||
static bool are_img_dims_equal(Shape conv_shape, Shape image_shape)
|
||||
{
|
||||
if (conv_shape.size() != 4)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return conv_shape[2] == image_shape[0] && conv_shape[3] == image_shape[1];
|
||||
}
|
||||
|
||||
static shared_ptr<Node> reduce_broadcast(shared_ptr<Node> broadcast)
|
||||
{
|
||||
const size_t H = 2;
|
||||
const size_t W = 3;
|
||||
auto matched_broadcast_w1 = static_pointer_cast<op::Broadcast>(broadcast);
|
||||
Shape shape_w1{matched_broadcast_w1->get_shape()};
|
||||
shape_w1[H] /= 2;
|
||||
shape_w1[W] /= 2;
|
||||
auto new_broadcast_w1 = std::make_shared<op::Broadcast>(
|
||||
matched_broadcast_w1->input_value(0), shape_w1, matched_broadcast_w1->get_broadcast_axes());
|
||||
return move(new_broadcast_w1);
|
||||
}
|
||||
|
||||
static size_t shape_to_index(Shape shape)
|
||||
{
|
||||
if (shape.size() != 4)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
const size_t HEIGHT_DIM = 2;
|
||||
const size_t WIDTH_DIM = 3;
|
||||
|
||||
if (shape.at(HEIGHT_DIM) != shape.at(WIDTH_DIM))
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
switch (shape.at(HEIGHT_DIM))
|
||||
{
|
||||
case 28: return 1;
|
||||
case 14: return 2;
|
||||
case 7: return 3;
|
||||
default: return 0;
|
||||
}
|
||||
}
|
||||
|
||||
void pass::CoreFusion::construct_reshape_broadcast()
|
||||
{
|
||||
Shape input_shape{10};
|
||||
auto input = make_shared<pattern::op::Label>(element::f32, input_shape);
|
||||
auto reshape1 = make_shared<op::Reshape>(input, AxisVector{0}, Shape{10, 1});
|
||||
auto broadcast = make_shared<op::Broadcast>(reshape1, Shape{10, 1, 20}, AxisSet{2});
|
||||
|
||||
auto callback = [input](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In a callback for construct_reshape_broadcast against "
|
||||
<< m.get_match_root()->get_name();
|
||||
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
auto broadcast_m = static_pointer_cast<op::Broadcast>(m.get_match_root());
|
||||
auto reshape1_m =
|
||||
static_pointer_cast<op::Reshape>(broadcast_m->input_value(0).get_node_shared_ptr());
|
||||
auto input_m = m.get_pattern_value_map()[input];
|
||||
|
||||
// it doesn't seem to make sense to support shapes : [0] or [1]
|
||||
if (input_m.get_shape().size() != 1 || input_m.get_shape().at(0) < 2)
|
||||
{
|
||||
NGRAPH_DEBUG << "input_m isn't a scalar or contains zero dimension";
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t dim = input_m.get_shape().at(0);
|
||||
|
||||
// We are going to support the most common case where broadcast doesn't add 1-dimensions
|
||||
// since it's also very simple to implement
|
||||
size_t dim_one_count = 0;
|
||||
for (auto d : reshape1_m->get_shape())
|
||||
{
|
||||
if (d != 1 && d != dim)
|
||||
{
|
||||
NGRAPH_DEBUG << "Input is reshaped in a way we can't directly broadcast ( shape = "
|
||||
<< vector_to_string(reshape1_m->get_shape()) << ")";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (d == 1)
|
||||
{
|
||||
dim_one_count++;
|
||||
}
|
||||
}
|
||||
|
||||
AxisSet new_axes = broadcast_m->get_broadcast_axes();
|
||||
auto broadcast_shape = broadcast_m->get_shape();
|
||||
for (size_t i = 0; i < broadcast_shape.size(); i++)
|
||||
{
|
||||
if (broadcast_shape[i] == 1)
|
||||
{
|
||||
dim_one_count--;
|
||||
new_axes.insert(i);
|
||||
}
|
||||
}
|
||||
|
||||
if (dim_one_count != 0)
|
||||
{
|
||||
NGRAPH_DEBUG << "Broadcast adds 1-dimensions";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto new_broadcast =
|
||||
make_shared<op::Broadcast>(input_m, broadcast_m->get_shape(), new_axes);
|
||||
replace_node(m.get_match_root(), new_broadcast);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(broadcast, "CoreFusion.ReshapeBroadcast");
|
||||
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
}
|
||||
|
||||
void pass::CoreFusion::construct_reshape_softmax_reshape()
|
||||
{
|
||||
Shape input_shape{10, 20};
|
||||
AxisVector io{1, 0};
|
||||
auto input = make_shared<pattern::op::Label>(element::f32, input_shape);
|
||||
auto reshape1 = make_shared<op::Reshape>(input, io, Shape{20, 10});
|
||||
auto softmax = make_shared<op::Softmax>(reshape1, AxisSet{1});
|
||||
auto reshape2 = make_shared<op::Reshape>(softmax, io, input_shape);
|
||||
|
||||
auto callback = [input](pattern::Matcher& m) {
|
||||
NGRAPH_DEBUG << "In a callback for construct_reshape_softmax_reshape against "
|
||||
<< m.get_match_root()->get_name();
|
||||
|
||||
auto pattern_map = m.get_pattern_map();
|
||||
auto reshape2_m = static_pointer_cast<op::Reshape>(m.get_match_root());
|
||||
auto softmax_m =
|
||||
static_pointer_cast<op::Softmax>(reshape2_m->input_value(0).get_node_shared_ptr());
|
||||
auto reshape1_m =
|
||||
static_pointer_cast<op::Reshape>(softmax_m->input_value(0).get_node_shared_ptr());
|
||||
auto input_m = m.get_pattern_map()[input];
|
||||
|
||||
if (!reshape2_m->get_is_transpose() || !reshape1_m->get_is_transpose())
|
||||
{
|
||||
NGRAPH_DEBUG << "we expect reshape2 and reshape1 both be dimshuffles";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input_m->get_shape() != reshape2_m->get_shape())
|
||||
{
|
||||
NGRAPH_DEBUG << "input and reshape2's shape are different";
|
||||
return false;
|
||||
}
|
||||
|
||||
AxisSet new_axes;
|
||||
const auto& axis_order = reshape2_m->get_input_order();
|
||||
for (auto axis : softmax_m->get_axes())
|
||||
{
|
||||
new_axes.insert(axis_order.at(axis));
|
||||
}
|
||||
|
||||
auto new_softmax = make_shared<op::Softmax>(input_m, new_axes);
|
||||
replace_node(m.get_match_root(), new_softmax);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = make_shared<pattern::Matcher>(reshape2, "CoreFusion.ReshapeSoftmaxReshape");
|
||||
this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
|
||||
}

static bool
    zero_padded_conv_consistency_check(const std::shared_ptr<ngraph::Node>& match_root,
                                       const std::shared_ptr<ngraph::op::Constant>& pad_value_op,
                                       const std::shared_ptr<ngraph::Node>& pad_input,
                                       const std::shared_ptr<ngraph::op::Pad>& matched_pad,
                                       const ngraph::CoordinateDiff& padding_below,
                                       const ngraph::CoordinateDiff& padding_above,
                                       size_t batch_index,
                                       size_t channel_index)
{
    // Only match float32 convolutions
    if (match_root->get_element_type() != ngraph::element::f32)
    {
        return false;
    }

    // Only match zero padding
    if (pad_value_op->get_vector<float>().at(0) != 0.0f)
    {
        return false;
    }

    // Only match 4D tensors
    if (pad_input->get_shape().size() != 4)
    {
        return false;
    }

    // Only match convolutions with no padding specification
    if (padding_below != ngraph::CoordinateDiff(2) || padding_above != ngraph::CoordinateDiff(2))
    {
        return false;
    }

    // Only match constant padding
    if (matched_pad->get_pad_mode() != ngraph::op::PadMode::CONSTANT)
    {
        return false;
    }

    // Only match no padding in the batch dimension
    if (matched_pad->get_padding_above().at(batch_index) != 0 ||
        matched_pad->get_padding_below().at(batch_index) != 0)
    {
        return false;
    }

    // Only match no padding in the channel dimension
    if (matched_pad->get_padding_above().at(channel_index) != 0 ||
        matched_pad->get_padding_below().at(channel_index) != 0)
    {
        return false;
    }

    return true;
}
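// E.g., a CONSTANT-mode zero Pad of a 4D f32 tensor with padding {0, 0, 1, 1} feeding
// a Convolution that declares no padding of its own satisfies every check above.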

void pass::CoreFusion::construct_zero_padded_reshaped_conv()
{
    auto pad_input = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto pad_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto pad =
        std::make_shared<ngraph::op::Pad>(pad_input, pad_value, CoordinateDiff{}, CoordinateDiff{});
    auto pad_label = std::make_shared<pattern::op::Label>(pad, nullptr, NodeVector{pad});

    auto reshape =
        std::make_shared<ngraph::op::Reshape>(pad_label, AxisVector{}, Shape{1, 1, 1, 1});
    auto reshape_label =
        std::make_shared<pattern::op::Label>(reshape, nullptr, NodeVector{reshape});

    auto conv_filter = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});

    auto conv = std::make_shared<ngraph::op::Convolution>(reshape_label,
                                                          conv_filter,
                                                          Strides{1, 1},
                                                          Strides{1, 1},
                                                          CoordinateDiff{1, 1},
                                                          CoordinateDiff{1, 1},
                                                          Strides{1, 1});
    auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});

    auto callback = [pad_input, pad_value, pad_label, reshape_label, conv_filter, conv_label](
        pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_map();

        auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
        if (!pad_value_op)
        {
            NGRAPH_DEBUG << "Pad value must be a constant";
            return false;
        }

        const auto& matched_conv = as_type_ptr<ngraph::op::Convolution>(pattern_map[conv_label]);
        const auto& matched_pad = as_type_ptr<ngraph::op::Pad>(pattern_map[pad_label]);
        const auto& matched_reshape =
            std::static_pointer_cast<ngraph::op::Reshape>(pattern_map[reshape_label]);

        const auto& input_order = matched_reshape->get_input_order();
        auto hoisted_reshape_output_shape =
            ngraph::apply_permutation<Shape>(pattern_map[pad_input]->get_shape(), input_order);

        auto hoisted_reshape = std::make_shared<ngraph::op::Reshape>(
            pattern_map[pad_input],
            input_order,
            Shape(hoisted_reshape_output_shape.begin(), hoisted_reshape_output_shape.end()));

        if (!zero_padded_conv_consistency_check(m.get_match_root(),
                                                pad_value_op,
                                                pattern_map[pad_input],
                                                matched_pad,
                                                matched_conv->get_padding_below(),
                                                matched_conv->get_padding_above(),
                                                input_order[0],
                                                input_order[1]))
        {
            return false;
        }

        CoordinateDiff padding_below{static_cast<CoordinateDiff::value_type>(
                                         matched_pad->get_padding_below().at(input_order[2])),
                                     static_cast<CoordinateDiff::value_type>(
                                         matched_pad->get_padding_below().at(input_order[3]))};
        CoordinateDiff padding_above{static_cast<CoordinateDiff::value_type>(
                                         matched_pad->get_padding_above().at(input_order[2])),
                                     static_cast<CoordinateDiff::value_type>(
                                         matched_pad->get_padding_above().at(input_order[3]))};

        auto zero_padded_conv =
            std::make_shared<ngraph::op::Convolution>(hoisted_reshape,
                                                      pattern_map[conv_filter],
                                                      matched_conv->get_window_movement_strides(),
                                                      matched_conv->get_window_dilation_strides(),
                                                      padding_below,
                                                      padding_above,
                                                      matched_conv->get_data_dilation_strides());

        ngraph::replace_node(m.get_match_root(), zero_padded_conv);
        return true;
    };

    auto m =
        std::make_shared<ngraph::pattern::Matcher>(conv_label, "CoreFusion.ZeroPaddedReshapedConv");
    this->add_matcher(m, callback);
}

void pass::CoreFusion::construct_zero_padded_conv()
{
    auto pad_input = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
    auto pad_value = std::make_shared<pattern::op::Label>(element::f32, Shape{});
    auto pad = std::make_shared<ngraph::op::Pad>(
        pad_input, pad_value, CoordinateDiff{0, 0, 0, 0}, CoordinateDiff{0, 0, 0, 0});
    auto pad_label = std::make_shared<pattern::op::Label>(pad, nullptr, NodeVector{pad});

    auto conv_filter = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});

    auto conv = std::make_shared<ngraph::op::Convolution>(pad_label,
                                                          conv_filter,
                                                          Strides{1, 1},
                                                          Strides{1, 1},
                                                          CoordinateDiff{1, 1},
                                                          CoordinateDiff{1, 1},
                                                          Strides{1, 1});
    auto conv_label = std::make_shared<pattern::op::Label>(conv, nullptr, NodeVector{conv});

    auto callback = [pad_input, pad_value, pad_label, conv_filter, conv_label](
        pattern::Matcher& m) {
        auto pattern_map = m.get_pattern_map();

        auto pad_value_op = as_type_ptr<ngraph::op::Constant>(pattern_map[pad_value]);
        if (!pad_value_op)
        {
            NGRAPH_DEBUG << "Pad value must be a constant";
            return false;
        }

        const auto& matched_conv =
            std::static_pointer_cast<ngraph::op::Convolution>(pattern_map[conv_label]);
        const auto& matched_pad = std::static_pointer_cast<ngraph::op::Pad>(pattern_map[pad_label]);

        if (!zero_padded_conv_consistency_check(m.get_match_root(),
                                                pad_value_op,
                                                pattern_map[pad_input],
                                                matched_pad,
                                                matched_conv->get_padding_below(),
                                                matched_conv->get_padding_above(),
                                                0,
                                                1))
        {
            return false;
        }

        CoordinateDiff padding_below{
            static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_below().at(2)),
            static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_below().at(3))};
        CoordinateDiff padding_above{
            static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_above().at(2)),
            static_cast<CoordinateDiff::value_type>(matched_pad->get_padding_above().at(3))};

        auto zero_padded_conv =
            std::make_shared<ngraph::op::Convolution>(pattern_map[pad_input],
                                                      pattern_map[conv_filter],
                                                      matched_conv->get_window_movement_strides(),
                                                      matched_conv->get_window_dilation_strides(),
                                                      padding_below,
                                                      padding_above,
                                                      matched_conv->get_data_dilation_strides());

        ngraph::replace_node(m.get_match_root(), zero_padded_conv);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(conv_label, "CoreFusion.ZeroPaddedConv");
    this->add_matcher(m, callback);
}
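// Net effect of the two matchers above: an explicit zero Pad (optionally behind a
// transposing Reshape) feeding a Convolution is folded into the Convolution's own
// padding_below/padding_above attributes, and the Pad node disappears from the graph.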
@ -1,55 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"

namespace ngraph
{
    namespace pass
    {
        class CoreFusion;
    }
}

class NGRAPH_API ngraph::pass::CoreFusion : public ngraph::pass::GraphRewrite
{
public:
    CoreFusion(FusionTypeMask fusions = FusionType::REGULAR_FUSIONS)
        : GraphRewrite()
    {
        if (fusions.is_set(FusionType::REGULAR_FUSIONS))
        {
            construct_relu();
            construct_folded_batch_norm();
            construct_conv_affine_folding();
            construct_sigmoid();
            construct_reshape_broadcast();
            construct_reshape_softmax_reshape();
            construct_zero_padded_reshaped_conv();
            construct_zero_padded_conv();
        }
    }
    void construct_relu();
    void construct_folded_batch_norm();
    void construct_conv_affine_folding();
    void construct_sigmoid();
    void construct_reshape_broadcast();
    void construct_reshape_softmax_reshape();
    void construct_zero_padded_reshaped_conv();
    void construct_zero_padded_conv();
};
@ -1,323 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>
#include <set>
#include <typeinfo>
#include <unordered_map>

#include "cse.hpp"
#include "ngraph/axis_vector.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/acos.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/asin.hpp"
#include "ngraph/op/atan.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/ceiling.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/cos.hpp"
#include "ngraph/op/cosh.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/exp.hpp"
#include "ngraph/op/floor.hpp"
#include "ngraph/op/log.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/one_hot.hpp"
#include "ngraph/op/power.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/relu.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/sigmoid.hpp"
#include "ngraph/op/sign.hpp"
#include "ngraph/op/sin.hpp"
#include "ngraph/op/sinh.hpp"
#include "ngraph/op/softmax.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pattern/matcher.hpp"

using namespace std;
using namespace ngraph;

#define TI(x) type_index(typeid(x))

static bool cse_constant(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_constant for " << a->get_name() << " and " << b->get_name();

    if (a->get_shape() != b->get_shape() || a->get_element_type() != b->get_element_type())
    {
        return false;
    }

    const op::Constant* ca = static_cast<op::Constant*>(a.get());
    const op::Constant* cb = static_cast<op::Constant*>(b.get());

    size_t size = shape_size(a->get_shape()) * a->get_element_type().size();

    if (ca->get_all_data_elements_bitwise_identical() ||
        cb->get_all_data_elements_bitwise_identical())
    {
        if (ca->get_all_data_elements_bitwise_identical() &&
            cb->get_all_data_elements_bitwise_identical())
        {
            // Since both Constants are uniform we only need to compare a single element
            return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), a->get_element_type().size());
        }
        else
        {
            return false;
        }
    }
    else
    {
        // Neither Constant is uniform so compare all elements
        return !memcmp(ca->get_data_ptr(), cb->get_data_ptr(), size);
    }
}

static bool cse_reshape(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_reshape for " << a->get_name() << " and " << b->get_name();

    const op::Reshape* reshape_a = static_cast<ngraph::op::Reshape*>(a.get());
    const op::Reshape* reshape_b = static_cast<ngraph::op::Reshape*>(b.get());

    return (a->input_value(0) == b->input_value(0)) &&
           (reshape_a->get_input_order() == reshape_b->get_input_order()) &&
           (reshape_a->get_output_shape(0) == reshape_b->get_output_shape(0));
}

static bool cse_broadcast(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_broadcast for " << a->get_name() << " and " << b->get_name();

    const op::Broadcast* broadcast_a = static_cast<ngraph::op::Broadcast*>(a.get());
    const op::Broadcast* broadcast_b = static_cast<ngraph::op::Broadcast*>(b.get());

    return (a->input_value(0) == b->input_value(0)) &&
           (broadcast_a->get_broadcast_axes() == broadcast_b->get_broadcast_axes()) &&
           (broadcast_a->get_broadcast_shape() == broadcast_b->get_broadcast_shape());
}

static bool cse_unarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_unarywise for " << a->get_name() << " and " << b->get_name();

    return a->input_value(0) == b->input_value(0);
}

static bool cse_binarywise(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_binary for " << a->get_name() << " and " << b->get_name();

    return (a->input_value(0) == b->input_value(0) && a->input_value(1) == b->input_value(1)) ||
           (a->input_value(1) == b->input_value(0) && a->input_value(0) == b->input_value(1));
}

static bool cse_reduction(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_reduction for " << a->get_name() << " and " << b->get_name();

    const op::util::ArithmeticReduction* ar_a =
        static_cast<op::util::ArithmeticReduction*>(a.get());
    const op::util::ArithmeticReduction* ar_b =
        static_cast<op::util::ArithmeticReduction*>(b.get());

    return ar_a->input_value(0) == ar_b->input_value(0) &&
           ar_a->get_reduction_axes() == ar_b->get_reduction_axes();
}

static bool cse_one_hot(shared_ptr<Node> a, shared_ptr<Node> b)
{
    NGRAPH_DEBUG << "In cse_one_hot for " << a->get_name() << " and " << b->get_name();

    const op::OneHot* one_hot_a = static_cast<ngraph::op::OneHot*>(a.get());
    const op::OneHot* one_hot_b = static_cast<ngraph::op::OneHot*>(b.get());

    return (a->input_value(0) == b->input_value(0)) &&
           (one_hot_a->get_one_hot_axis() == one_hot_b->get_one_hot_axis()) &&
           (a->get_shape() == b->get_shape());
}

// To enable CSE for a new op, add a mapping between the op and a cse handler function to the map
// below. If the op doesn't map to an existing handler, create a new handler to check if
// all inputs and attributes for two nodes are exactly same.
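// For example, enabling CSE for a hypothetical unary op Foo that carries no extra
// attributes would simply reuse the generic handler: {TI(op::Foo), cse_unarywise}.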
static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
    initialize_ops_to_cse_handlers()
{
    return unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>(
        {{TI(op::Abs), cse_unarywise},
         {TI(op::Acos), cse_unarywise},
         {TI(op::Asin), cse_unarywise},
         {TI(op::Atan), cse_unarywise},
         {TI(op::Ceiling), cse_unarywise},
         {TI(op::Constant), cse_constant},
         {TI(op::Cos), cse_unarywise},
         {TI(op::Cosh), cse_unarywise},
         {TI(op::Exp), cse_unarywise},
         {TI(op::Floor), cse_unarywise},
         {TI(op::Log), cse_unarywise},
         {TI(op::Negative), cse_unarywise},
         {TI(op::OneHot), cse_one_hot},
         {TI(op::Relu), cse_unarywise},
         {TI(op::Sigmoid), cse_unarywise},
         {TI(op::Sign), cse_unarywise},
         {TI(op::Sin), cse_unarywise},
         {TI(op::Sinh), cse_unarywise},
         //{TI(op::Softmax), cse_unarywise},
         {TI(op::Sqrt), cse_unarywise},
         {TI(op::Tan), cse_unarywise},
         {TI(op::Tanh), cse_unarywise},
         {TI(op::Add), cse_binarywise},
         {TI(op::Divide), cse_binarywise},
         {TI(op::Maximum), cse_binarywise},
         {TI(op::Minimum), cse_binarywise},
         {TI(op::Multiply), cse_binarywise},
         {TI(op::Power), cse_binarywise},
         {TI(op::Subtract), cse_binarywise},
         {TI(op::Sum), cse_reduction},
         {TI(op::Product), cse_reduction},
         {TI(op::Reshape), cse_reshape},
         {TI(op::Broadcast), cse_broadcast}});
}

static unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>
    ops_to_cse_handlers = initialize_ops_to_cse_handlers();

class NodeKey
{
public:
    NodeKey(const shared_ptr<Node>& n,
            unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
                backend_handlers)
        : m_node(n)
        , m_node_ref(*n)
        , m_ti(TI(m_node_ref))
        , m_backend_handlers(backend_handlers)
    {
    }

    shared_ptr<Node> get_node() const { return m_node; }
    bool operator==(const NodeKey& other) const
    {
        if (m_ti == other.m_ti)
        {
            auto eh = ops_to_cse_handlers.find(m_ti);
            if (eh != ops_to_cse_handlers.end())
            {
                return eh->second(m_node, other.m_node);
            }

            eh = m_backend_handlers.find(m_ti);
            if (eh != m_backend_handlers.end())
            {
                return eh->second(m_node, other.m_node);
            }
        }

        return false;
    }

private:
    shared_ptr<Node> m_node;
    // m_node_ref is only to allow getting the type_index in the ctor
    Node& m_node_ref;
    std::type_index m_ti;
    unordered_map<type_index, function<bool(shared_ptr<Node>, shared_ptr<Node>)>>&
        m_backend_handlers;
};

namespace std
{
    template <>
    struct hash<NodeKey>
    {
        size_t operator()(const NodeKey& k) const
        {
            Node& p_this = *k.get_node().get();
            auto ti = TI(p_this);

            hash<type_index> type_hash_compute{};
            auto type_hash = type_hash_compute(ti);

            vector<size_t> arg_ids;

            arg_ids.push_back(type_hash);

            std::vector<Output<Node>> cargs;
            for (auto input : k.get_node()->inputs())
            {
                cargs.push_back(input.get_source_output());
            }

            // TODO: Do we need another map, so we could
            // specify how to compute hash for each op?
            if (ngraph::op::is_commutative(&p_this))
            {
                sort(begin(cargs), end(cargs));
            }
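            // e.g., Add(a, b) and Add(b, a) now hash identically, mirroring
            // cse_binarywise, which also treats its operands as unordered.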

            for (auto arg : cargs)
            {
                arg_ids.push_back(arg.get_node_shared_ptr()->get_instance_id());
                arg_ids.push_back(arg.get_index());
            }

            auto hashc = ngraph::hash_combine(arg_ids);
            return hashc;
        }
    };
}

bool ngraph::pass::CommonSubexpressionElimination::run_on_function(shared_ptr<ngraph::Function> f)
{
    bool replaced = false;
    unordered_map<NodeKey, shared_ptr<Node>> expressions{};

    for (auto n : f->get_ordered_ops())
    {
        if (op::is_output(n) || op::is_parameter(n))
        {
            continue;
        }

        NodeKey n_key(n, m_backend_cse_handlers);
        if (expressions.count(n_key))
        {
            ngraph::replace_node(n, expressions.at(n_key));
            replaced = true;
        }
        else
        {
            expressions.insert(make_pair(n_key, n));
        }
    }

    return replaced;
}
@ -1,70 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class CommonSubexpressionElimination;
    }
}

/// \brief The Common Subexpression Elimination pass removes duplicate computations in a given
///        computation graph.
///
/// Two computations are considered to be duplicates of each other if both apply the same operation
/// to the same set of inputs, with the same attributes.
///
/// In the example shown below, the original graph has duplicate Add computations.
/// After applying this pass, the graph is optimized to have only one Add computation.
/// <table>
/// <tr><th>Before the pass:</th>
///     <th> After the pass</th>
/// </tr>
/// <tr>
///     <td> \image html add_commutative_precse.svg </td>
///     <td> \image html add_commutative_postcse.svg </td>
/// </tr>
/// </table>
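///
/// A minimal usage sketch (assuming the usual pass::Manager flow; the handler-map
/// constructor below is only needed when a backend registers its own CSE handlers):
///
///     ngraph::pass::Manager manager;
///     manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
///     manager.run_passes(f);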
class NGRAPH_API ngraph::pass::CommonSubexpressionElimination : public FunctionPass
{
public:
    CommonSubexpressionElimination()
        : FunctionPass()
    {
        set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
    }

    CommonSubexpressionElimination(
        const std::unordered_map<std::type_index,
                                 std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>&
            backend_cse_handlers)
        : FunctionPass()
        , m_backend_cse_handlers(backend_cse_handlers)
    {
        set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
    }

    std::unordered_map<std::type_index,
                       std::function<bool(std::shared_ptr<Node>, std::shared_ptr<Node>)>>
        m_backend_cse_handlers;

    virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
@ -1,77 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <fstream>

#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

pass::DumpSorted::DumpSorted(const string& output_file)
    : m_output_file{output_file}
{
}

bool pass::DumpSorted::run_on_module(vector<shared_ptr<Function>>& functions)
{
    ofstream out{m_output_file};
    if (out)
    {
        for (shared_ptr<Function> f : functions)
        {
            out << "=====================================================================\n";
            out << f->get_name() << " start\n";
            out << "=====================================================================\n";
            for (const shared_ptr<Node>& node : f->get_ordered_ops())
            {
                out << node->get_name() << "(";
                vector<string> inputs;
                for (auto& input : node->inputs())
                {
                    inputs.push_back(input.get_tensor().get_name());
                }
                out << join(inputs);
                out << ") -> ";

                vector<string> outputs;
                for (auto& output : node->outputs())
                {
                    outputs.push_back(output.get_tensor().get_name());
                }
                out << join(outputs);
                out << "\n";

                for (const descriptor::Tensor* tensor : node->liveness_new_list)
                {
                    out << "    N " << tensor->get_name() << "\n";
                }
                for (const descriptor::Tensor* tensor : node->liveness_free_list)
                {
                    out << "    F " << tensor->get_name() << "\n";
                }
            }
            out << "=====================================================================\n";
            out << f->get_name() << " end\n";
            out << "=====================================================================\n";
        }
    }

    return false;
}
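// Each op is dumped as "Node(input_tensors) -> output_tensors", followed by one
// "    N <tensor>" line per tensor that becomes live at the node and one
// "    F <tensor>" line per tensor freed there, so the file doubles as a liveness trace.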
@ -1,40 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <string>

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class DumpSorted;
    }
}

class NGRAPH_API ngraph::pass::DumpSorted : public ModulePass
{
public:
    DumpSorted(const std::string& output_file);

    virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
    const std::string m_output_file;
};
@ -27,7 +27,6 @@
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/serialize.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"

@ -36,8 +35,6 @@ using namespace ngraph;

pass::Manager::Manager()
    : m_visualize(getenv_bool("NGRAPH_ENABLE_VISUALIZE_TRACING"))
    , m_serialize(getenv_bool("NGRAPH_ENABLE_SERIALIZE_TRACING"))

{
}

@ -139,7 +136,7 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool /* transitive */)
        function_changed = call_graph_pass->run_on_call_graph(func->get_ordered_ops());
    }

    if (m_visualize || m_serialize)
    if (m_visualize)
    {
        // visualizations and serializations will be named after the outermost function
        const size_t num_digits_in_pass_index = 3;
@ -156,12 +153,6 @@ void pass::Manager::run_passes(shared_ptr<Function> func, bool /* transitive */)
            vt.set_ops_to_details(get_state().get_visualize_tree_ops_map());
            vt.run_on_module(f_array);
        }

        if (m_serialize)
        {
            pass::Serialization st(base_filename + ".json");
            st.run_on_module(f_array);
        }
    }
    index++;
    pass_timer.stop();

@ -58,7 +58,6 @@ public:
    PassConfig& get_pass_config() { return m_pass_config; }
    void set_pass_config(const PassConfig& pass_config) { m_pass_config = pass_config; }
    void set_pass_visualization(bool new_state) { m_visualize = new_state; }
    void set_pass_serialization(bool new_state) { m_serialize = new_state; }
    /// \brief Set flag to enable/disable running Validate pass after executing
    ///        each registered pass
    /// \param new_state Value "true" enables Validate pass run; "false", otherwise
@ -106,6 +105,5 @@ private:
    ManagerState m_state;
    PassConfig m_pass_config;
    bool m_visualize = false;
    bool m_serialize = false;
    bool m_per_pass_validation = true;
};

@ -1,293 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <exception>
#include <sstream>

#include "ngraph/log.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

pass::MemoryLayout::MemoryLayout(size_t alignment, bool disable_memory_sharing)
    : m_alignment(alignment)
    , m_disable_memory_sharing(disable_memory_sharing)
{
    if (m_alignment == 0)
    {
        throw invalid_argument("Memory alignment must be > 0");
    }
}

bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
{
    MemoryManager mm(m_alignment, m_disable_memory_sharing);
    for (shared_ptr<Node> node : function->get_ordered_ops())
    {
        std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
        std::set<const descriptor::Tensor*> reused_inputs;

        if (op::is_op(node))
        {
            auto op = std::static_pointer_cast<op::Op>(node);
            // concat and slice in_place_oi should be treated differently
            if (!is_type<op::Concat>(node) && !is_type<op::Slice>(node))
            {
                if (auto op_annotations = op->get_op_annotations())
                {
                    for (auto oi_pair : op_annotations->get_in_place_oi_pairs())
                    {
                        auto output = &node->output(oi_pair.output).get_tensor();
                        auto input = &node->get_input_tensor(oi_pair.input);
                        auto input_node = node->get_input_node_ptr(oi_pair.input);

                        // For destructive kernel, this should be the last use
                        // Non-destructive kernels can pass through if memory sharing is disabled
                        if ((node->liveness_free_list.count(input) != 0 ||
                             is_type<op::GetOutputElement>(node) ||
                             (m_disable_memory_sharing && !oi_pair.destructive &&
                              !op::is_parameter(input_node) && !op::is_constant(input_node))) &&
                            node->liveness_new_list.count(output) != 0)

                        {
                            NGRAPH_DEBUG << "Reusing " << input->get_name() << " for "
                                         << output->get_name();
                            in_place_outputs.insert({output, input});
                            reused_inputs.insert(input);
                        }
                    }
                }
            }
        }

        for (descriptor::Tensor* tensor : node->liveness_new_list)
        {
            size_t offset = in_place_outputs.count(tensor)
                                ? in_place_outputs.at(tensor)->get_pool_offset()
                                : mm.allocate(tensor->size());
            tensor->set_pool_offset(offset);
        }

        if (!m_disable_memory_sharing)
        {
            for (const descriptor::Tensor* tensor : node->liveness_free_list)
            {
                if (reused_inputs.count(tensor) == 0)
                {
                    mm.free(tensor->get_pool_offset());
                }
            }
        }
    }
    function->set_temporary_pool_size(mm.max_allocated());

    return false;
}
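// In short: a tensor first live at a node either inherits the pool offset of the
// input it reuses in place, or gets a fresh offset from the MemoryManager; offsets
// of tensors that die at the node return to the pool unless they were just reused.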

pass::MemoryManager::node::node(size_t size, block_state state)
    : m_size{size}
    , m_state{state}
{
}

pass::MemoryManager::MemoryManager(size_t alignment, bool disable_memory_reuse)
    : m_alignment{alignment}
    , m_scheme{disable_memory_reuse ? allocation_scheme::NO_REUSE : allocation_scheme::FIRST_FIT}
    , m_max_allocated{0}
{
    if (m_alignment == 0)
    {
        throw invalid_argument("Memory alignment must be > 0");
    }
    m_node_list.emplace_back(numeric_limits<size_t>::max(), block_state::FREE);
}

size_t pass::MemoryManager::allocate(size_t size)
{
    size_t rc = 0;
    switch (m_scheme)
    {
    case allocation_scheme::FIRST_FIT: rc = first_fit(size); break;
    case allocation_scheme::BEST_FIT: rc = best_fit(size); break;
    case allocation_scheme::NO_REUSE: rc = no_reuse_allocator(size); break;
    }
    return rc;
}

size_t pass::MemoryManager::no_reuse_allocator(size_t size)
{
    size_t offset = m_max_allocated;
    m_max_allocated += align(size, m_alignment);
    return offset;
}

size_t pass::MemoryManager::best_fit(size_t size)
{
    size = align(size, m_alignment);
    size_t offset = 0;
    size_t min_delta = numeric_limits<size_t>::max();
    auto best_fit = m_node_list.end();
    size_t best_offset = offset;
    for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
    {
        if (it->m_state == block_state::FREE && it->m_size >= size)
        {
            size_t delta = it->m_size - size;
            if (delta < min_delta)
            {
                min_delta = delta;
                best_fit = it;
                best_offset = offset;
            }
        }
        offset += it->m_size;
    }

    if (best_fit == m_node_list.end())
    {
        throw bad_alloc();
    }

    if (min_delta == 0)
    {
        // exact fit
        best_fit->m_state = block_state::ALLOCATED;
    }
    else
    {
        m_node_list.insert(best_fit, node{size, block_state::ALLOCATED});
        best_fit->m_size -= size;
    }
    m_max_allocated = max(m_max_allocated, best_offset + size);

    return best_offset;
}

size_t pass::MemoryManager::first_fit(size_t size)
{
    size = align(size, m_alignment);
    size_t offset = 0;
    bool found = false;
    for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
    {
        if (it->m_state == block_state::FREE && it->m_size >= size)
        {
            if (it->m_size > size)
            {
                m_node_list.insert(it, node{size, block_state::ALLOCATED});
                it->m_size -= size;
            }
            else
            {
                // exact fit
                it->m_state = block_state::ALLOCATED;
            }

            found = true;
            break;
        }
        offset += it->m_size;
    }
    if (!found)
    {
        throw bad_alloc();
    }
    m_max_allocated = max(m_max_allocated, offset + size);

    return offset;
}
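// Example of the difference between the two schemes: with a block list
// [FREE 64][ALLOCATED 16][FREE 8], an 8-byte request lands at offset 0 under
// first_fit (splitting the 64-byte block) but takes the trailing 8-byte hole
// exactly under best_fit.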

void pass::MemoryManager::free(size_t offset)
{
    size_t search_offset = 0;
    bool found = false;
    for (auto it = m_node_list.begin(); it != m_node_list.end(); ++it)
    {
        if (offset == search_offset)
        {
            list<node>::iterator it_next = next(it);
            if (it == m_node_list.begin())
            {
                // free the first node in the list
                it->m_state = block_state::FREE;
            }
            else
            {
                // node has predecessor
                list<node>::iterator it_prev = prev(it);
                if (it_prev->m_state == block_state::FREE)
                {
                    it->m_size += it_prev->m_size;
                    m_node_list.erase(it_prev);
                }
            }
            if (it_next != m_node_list.end() && it_next->m_state == block_state::FREE)
            {
                // join this node with next
                it->m_size += it_next->m_size;
                m_node_list.erase(it_next);
            }
            it->m_state = block_state::FREE;
            found = true;
            break;
        }
        search_offset += it->m_size;
    }
    if (!found)
    {
        throw runtime_error("bad free");
    }
}

void pass::MemoryManager::dump(ostream& out)
{
    for (const node& n : m_node_list)
    {
        out << "size=" << n.m_size << ", ";
        out << (n.m_state == block_state::FREE ? "FREE" : "ALLOCATED");
        out << "\n";
    }
}
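// align() rounds a request up to the next multiple of the alignment; for example,
// align(10, 8) returns 16, and align(0, 8) returns 8 (a zero-byte request still
// occupies one aligned block).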

size_t pass::MemoryManager::align(size_t size, size_t alignment)
{
    if (alignment == 0)
    {
        throw invalid_argument("alignment must be > 0");
    }
    if (size == 0)
    {
        size = alignment;
    }
    else
    {
        auto remainder = size % alignment;
        if (remainder > 0)
        {
            size += (alignment - remainder);
        }
    }
    return size;
}
@ -1,97 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <limits>
#include <list>
#include <sstream>

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class MemoryLayout;
        class MemoryNode;
        class MemoryManager;
    }
}

class NGRAPH_API ngraph::pass::MemoryLayout : public FunctionPass
{
public:
    MemoryLayout(size_t alignment = 1, bool disable_memory_sharing = false);
    bool run_on_function(std::shared_ptr<ngraph::Function>) override;

private:
    size_t m_alignment;
    bool m_disable_memory_sharing;
};

class NGRAPH_API ngraph::pass::MemoryManager
{
public:
    enum class block_state
    {
        FREE,
        ALLOCATED
    };

    enum class allocation_scheme
    {
        FIRST_FIT,
        BEST_FIT,
        NO_REUSE
    };

    class node
    {
    public:
        node(size_t size, block_state state);

        bool is_free() const { return m_state == block_state::FREE; }
        size_t m_size;
        block_state m_state;
    };

    MemoryManager(size_t alignment = 1, bool disable_reuse = false);
    // memory_manager& alignment(size_t a);

    size_t allocate(size_t size);
    void free(size_t offset);

    void dump(std::ostream&);

    static size_t align(size_t x, size_t alignment);

    std::list<node>::iterator begin() { return m_node_list.begin(); }
    std::list<node>::iterator end() { return m_node_list.end(); }
    std::list<node>::const_iterator begin() const { return m_node_list.cbegin(); }
    std::list<node>::const_iterator end() const { return m_node_list.cend(); }
    const std::list<node>& get_node_list() const { return m_node_list; }
    size_t max_allocated() const { return m_max_allocated; }
private:
    size_t first_fit(size_t size);
    size_t best_fit(size_t size);
    size_t no_reuse_allocator(size_t size);

    std::list<node> m_node_list;
    size_t m_alignment;
    allocation_scheme m_scheme;
    size_t m_max_allocated;
};
@ -1,266 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <fstream>
#include <unordered_map>
#include <unordered_set>

#include "memory_visualize.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

pass::MemoryVisualize::MemoryVisualize(const string& filename)
    : m_filename{filename}
{
}

bool pass::MemoryVisualize::run_on_module(vector<shared_ptr<Function>>& functions)
{
    ofstream file(m_filename);
    {
        for (shared_ptr<Function> f : functions)
        {
            vector<shared_ptr<Node>> nodes = f->get_ordered_ops();
            file << "<!DOCTYPE html>\n<html>\n";
            file << "<head>\n";
            file << "    <style>\n";
            file << "        th, td {\n";
            file << "            border-bottom: 1px solid #ddd;\n";
            file << "            width: 200px;\n";
            file << "        }\n";
            file << "        table, td, th {\n";
            // file << "            border: 1px solid #ddd;\n";
            // file << "            text-align: left;\n";
            file << "        }\n";
            file << "        table {\n";
            file << "            border-collapse: collapse;\n";
            // file << "            width: 100%;\n";
            file << "        }\n";
            // file << "        tr:hover {background-color: #f5f5f5}\n";
            file << "        tr:nth-child(even) {background-color: #f2f2f2}\n";
            file << "    </style>\n";
            file << "</head>\n";

            file << "<body>\n";
            unordered_set<descriptor::Tensor*> tensors;
            size_t temp_max_size = 0;
            for (shared_ptr<Node> node : nodes)
            {
                tensors.insert(node->liveness_new_list.begin(), node->liveness_new_list.end());
            }
            for (descriptor::Tensor* tensor : tensors)
            {
                temp_max_size += tensor->size();
            }

            // file << "<table>\n";
            // file << "<tr><td>Persistent Memory Footprint</td><td align=\"right\">";
            // file << computation_decl.exop_block.persistent_size() << "</td></tr>\n";
            // file << "<tr><td>Temporary Memory Footprint</td><td align=\"right\">";
            // file << computation_decl.exop_block.memory_footprint() << "</td></tr>\n";
            // file << "<tr><td>Max temporary Memory Footprint</td><td align=\"right\">";
            // file << temp_max_size << "</td></tr>\n";
            // file << "</table>\n";

            file << "<hr>\n";
            draw_tensor_weight(file, nodes);
            // file << "<hr>\n";
            // draw_op_influence(file);
            file << "<hr>\n";
            draw_histogram(file, nodes);
            // file << "<hr>\n";
            file << "</body>\n</html>\n";
        }
    }
    return false;
}

unordered_set<const descriptor::Tensor*>
    pass::MemoryVisualize::find_largest_op(const vector<shared_ptr<Node>>& nodes)
{
    size_t largest_size = 0;
    unordered_set<const descriptor::Tensor*> liveness_list;
    unordered_set<const descriptor::Tensor*> largest_live_list;
    for (shared_ptr<Node> exop : nodes)
    {
        size_t size = 0;
        for (const descriptor::Tensor* tensor : exop->liveness_new_list)
        {
            liveness_list.insert(tensor);
            size += tensor->size();
        }
        for (const descriptor::Tensor* tensor : liveness_list)
        {
            size += tensor->size();
        }
        if (size > largest_size)
        {
            largest_size = size;
            largest_live_list = liveness_list;
        }
    }
    return largest_live_list;
}

void pass::MemoryVisualize::draw_tensor_weight(ostream& file, const vector<shared_ptr<Node>>& nodes)
{
    unordered_set<const descriptor::Tensor*> largest_live_list = find_largest_op(nodes);

    unordered_map<const descriptor::Tensor*, size_t> age_list;
    vector<const descriptor::Tensor*> tensor_set;
    unordered_map<const descriptor::Tensor*, shared_ptr<Node>> generator_op;
    file << "<table>\n";
    file << "    <tr>";
    file << "<th align=\"left\">tensor</th>";
    file << "<th align=\"right\">size</th>";
    file << "<th align=\"right\">age</th>";
    file << "<th align=\"right\">generator weight</th>";
    file << "</tr>\n";
    size_t i = 0;
    for (shared_ptr<Node> exop : nodes)
    {
        for (const descriptor::Tensor* tensor : exop->liveness_new_list)
        {
            age_list[tensor] = i;
            generator_op[tensor] = exop;
        }
        for (const descriptor::Tensor* tensor : exop->liveness_free_list)
        {
            size_t start = age_list[tensor];
            age_list[tensor] = (i - start);
            tensor_set.push_back(tensor);
        }
        i++;
    }
    sort(tensor_set.begin(),
         tensor_set.end(),
         [](const descriptor::Tensor* t1, const descriptor::Tensor* t2) {
             return t1->size() < t2->size();
         });
    for (const descriptor::Tensor* tensor : tensor_set)
    {
        int generator_weight = compute_op_weight(generator_op[tensor]);
        if (largest_live_list.find(tensor) != largest_live_list.end())
        {
            file << "    <tr style=\"background-color: #f0c0f0\">";
        }
        else
        {
            file << "    <tr>";
        }
        file << "<td>" << tensor->get_name() << "</td>";
        file << "<td align=\"right\">" << tensor->size() << "</td>";
        file << "<td align=\"right\">" << age_list[tensor] << "</td>";
        file << "<td align=\"right\">" << generator_weight << "</td>";
        file << "</tr>\n";
    }

    file << "</table>\n";
}

void pass::MemoryVisualize::draw_histogram(ostream& file, const vector<shared_ptr<Node>>& nodes)
{
    size_t stroke_width = 14;
    size_t text_offset = 4;
    size_t offset = 200;
    size_t width = 1000;
    size_t scale = width - offset;
    size_t line_spacing = static_cast<size_t>(stroke_width * 1.5);
    size_t line_count = 0;
    for (shared_ptr<Node> node : nodes)
    {
        (void)node;
        line_count += 1;
    }
    size_t height = line_count * line_spacing + stroke_width;
    size_t memory_footprint = max<size_t>(1, MemoryVisualize::memory_footprint(nodes));

    file << "<svg viewBox=\"0 0 " << width << " " << height << "\">\n";
    size_t y = 0;
    for (shared_ptr<Node> node : nodes)
    {
        float usage = float(MemoryVisualize::memory_usage(node));
        float footprint = float(MemoryVisualize::memory_footprint(node));
        y += line_spacing;
        size_t x1 = offset;
        size_t x2 = static_cast<size_t>(((usage / memory_footprint) * scale) + offset);
        file << "<text x=\"" << 0 << "\" y=\"" << y + text_offset << "\" fill=\""
             << "black"
             << "\">" << node->get_name() << "</text>\n";
        file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
             << "\"";
        file << " style=\"stroke:forestgreen;stroke-width:" << stroke_width << "\" />\n";
        x1 = x2;
        x2 = static_cast<size_t>(((footprint / memory_footprint) * scale) + offset);
        file << "<line x1=\"" << x1 << "\" y1=\"" << y << "\" x2=\"" << x2 << "\" y2=\"" << y
             << "\"";
        file << " style=\"stroke:firebrick;stroke-width:" << stroke_width << "\" />\n";
    }
    file << "</svg>\n";
}

void pass::MemoryVisualize::draw_op_influence(ostream& file, const vector<shared_ptr<Node>>& nodes)
{
    file << "<table>\n";
    file << "    <tr>";
    file << "<th align=\"left\">op</th>";
    file << "<th align=\"right\">influence</th>";
    file << "</tr>\n";
    for (shared_ptr<Node> exop : nodes)
    {
        int weight = compute_op_weight(exop);
        file << "    <tr>";
        file << "<td>" << exop->get_name() << "</td>";
        file << "<td align=\"right\">" << weight << "</td>";
        file << "</tr>\n";
    }
}

int pass::MemoryVisualize::compute_op_weight(const shared_ptr<Node> exop)
{
    int mass = 0;
    for (const descriptor::Tensor* tensor : exop->liveness_new_list)
    {
        mass += static_cast<int>(tensor->size());
    }
    for (const descriptor::Tensor* tensor : exop->liveness_free_list)
    {
        mass -= static_cast<int>(tensor->size());
    }
    return mass;
}
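// The weight is the net number of bytes an op brings live: positive for ops that
// allocate more than they free, negative for ops that mostly release memory.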

size_t pass::MemoryVisualize::memory_usage(shared_ptr<Node> /* node */)
{
    return 0;
}

size_t pass::MemoryVisualize::memory_footprint(shared_ptr<Node> /* node */)
{
    return 0;
}

size_t pass::MemoryVisualize::memory_footprint(const std::vector<shared_ptr<Node>>& /* nodes */)
{
    return 0;
}
@ -1,52 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <iostream>
#include <limits>
#include <list>

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class MemoryVisualize;
    }
}

class NGRAPH_API ngraph::pass::MemoryVisualize : public ModulePass
{
public:
    MemoryVisualize(const std::string& filename);
    virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
    std::unordered_set<const descriptor::Tensor*>
        find_largest_op(const std::vector<std::shared_ptr<Node>>& nodes);
    void draw_tensor_weight(std::ostream& file, const std::vector<std::shared_ptr<Node>>& nodes);
    void draw_histogram(std::ostream& file, const std::vector<std::shared_ptr<Node>>& nodes);
    void draw_op_influence(std::ostream& file, const std::vector<std::shared_ptr<Node>>& nodes);
    int compute_op_weight(std::shared_ptr<Node> exop);

    static size_t memory_usage(std::shared_ptr<Node>);
    static size_t memory_footprint(std::shared_ptr<Node>);
    static size_t memory_footprint(const std::vector<std::shared_ptr<Node>>&);

    const std::string m_filename;
};
@ -1,76 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/pass/propagate_cacheability.hpp"

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/op/util/op_types.hpp"

using namespace std;
using namespace ngraph;
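
// Cacheability propagates forward through the graph: a Parameter keeps its own
// flag, and any other op is cacheable only if every producer feeding it is. For
// example, with parameters P0 (cacheable) and P1 (not cacheable), Abs(P0) stays
// cacheable while Add(Abs(P0), P1) does not.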

bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
{
    for (auto& node : function->get_ordered_ops())
    {
        if (op::is_op(node))
        {
            auto op = static_pointer_cast<op::Op>(node);
            NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
            auto op_annotations = op->get_op_annotations();
            if (!op_annotations)
            {
                NGRAPH_DEBUG << "propagate cacheability: create op_annotations";
                op_annotations = op_annotations_factory();
                op->set_op_annotations(op_annotations);
            }
            if (op::is_parameter(node))
            {
                auto parameter = static_pointer_cast<op::Parameter>(node);
                op_annotations->set_cacheable(parameter->get_cacheable());
                NGRAPH_DEBUG << "propagate cacheability: cacheability is "
                             << parameter->get_cacheable();
            }
            else
            {
                bool cacheable = true;
                for (auto input : node->inputs())
                {
                    auto input_value_node = input.get_source_output().get_node_shared_ptr();
                    NGRAPH_DEBUG << "propagate cacheability: arg is " << *input_value_node;
                    if (op::is_op(input_value_node))
                    {
                        auto arg_op = static_pointer_cast<op::Op>(input_value_node);
                        auto arg_op_annotations = arg_op->get_op_annotations();
                        NGRAPH_CHECK(arg_op_annotations);
                        if (!arg_op_annotations->is_cacheable())
                        {
                            cacheable = false;
                            break;
                        }
                    }
                }
                NGRAPH_DEBUG << "propagate cacheability: cacheability is " << cacheable;
                op_annotations->set_cacheable(cacheable);
            }
        }
    }
    return false;
}
|
@ -1,52 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class PropagateCacheability;
    }
}

class NGRAPH_API ngraph::pass::PropagateCacheability : public FunctionPass
{
public:
    PropagateCacheability()
        : FunctionPass()
    {
    }

    PropagateCacheability(
        std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)> func)
        : FunctionPass()
        , op_annotations_factory(func)
    {
    }

    virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);

private:
    std::function<std::shared_ptr<ngraph::op::util::OpAnnotations>(void)> op_annotations_factory =
        []() -> std::shared_ptr<ngraph::op::util::OpAnnotations> {
            auto op_annotations = std::make_shared<ngraph::op::util::OpAnnotations>();
            return op_annotations;
        };
};
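To illustrate the factory constructor above, a hypothetical wiring sketch; CpuOpAnnotations is an invented type standing in for a backend-specific OpAnnotations subclass:

// Hypothetical: a backend could derive from OpAnnotations so that the pass
// creates the backend's own annotation objects while propagating cacheability.
auto factory = []() -> std::shared_ptr<ngraph::op::util::OpAnnotations> {
    return std::make_shared<CpuOpAnnotations>(); // CpuOpAnnotations is invented here
};
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::PropagateCacheability>(factory);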
@ -1,304 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "reshape_elimination.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <unordered_set>

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

void pass::ReshapeElimination::construct_identity_reshape_pattern()
{
    Shape shape_op{3};
    Shape shape_r1{1, 3};

    auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
    auto reshape1 = make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);

    auto callback = [op](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_identity_reshape_pattern against node = "
                     << m.get_match_root()->get_name();
        auto pattern_map = m.get_pattern_value_map();
        auto gop = pattern_map[op];

        auto r1 = as_type_ptr<op::Reshape>(m.get_match_root());

        if (r1->get_shape() != gop.get_shape())
        {
            NGRAPH_DEBUG << "Not a no-op; Shapes are different!";
            return false;
        }

        auto do_r1 = get_default_order(r1->get_shape());

        if (do_r1 != r1->get_input_order())
        {
            NGRAPH_DEBUG << "Not a no-op; Not in default input order!";
            return false;
        }

        m.get_match_value().replace(gop);
        return true;
    };

    auto m = make_shared<pattern::Matcher>(reshape1);
    this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}

void pass::ReshapeElimination::construct_reshapex2_pattern()
{
    Shape shape_op{3};
    Shape shape_r1{1, 3};

    auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
    auto reshape1 = make_shared<op::Reshape>(op, AxisVector{0}, shape_r1);
    auto reshape2 = make_shared<op::Reshape>(reshape1, AxisVector{0, 1}, shape_op);

    auto callback = [op](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_reshapex2_pattern against node = "
                     << m.get_match_root()->get_name();
        auto pattern_map = m.get_pattern_map();

        auto gop = pattern_map[op];

        auto r2 = static_pointer_cast<op::Reshape>(m.get_match_root());
        auto r1 = static_pointer_cast<op::Reshape>(r2->input_value(0).get_node_shared_ptr());

        if (gop->get_shape() != m.get_match_root()->get_shape())
        {
            // First reshape transposes and second reshape only changes shape
            // Replace with a transpose that changes shape
            if (apply_permutation(gop->get_shape(), r1->get_input_order()) == r2->get_shape() &&
                r2->get_input_order() == get_default_order(r1->get_shape()) &&
                r1->get_users().size() == 1)
            {
                replace_node(m.get_match_root(),
                             make_shared<op::Reshape>(gop, r1->get_input_order(), r2->get_shape()));
                return true;
            }
            else
            {
                NGRAPH_DEBUG << "Operand shape doesn't match the shape of the second reshape!";
                NGRAPH_DEBUG << "gop " << gop->get_name()
                             << "shape = " << vector_to_string(gop->get_shape());
                NGRAPH_DEBUG << "match_root " << m.get_match_root()->get_name()
                             << "shape = " << vector_to_string(m.get_match_root()->get_shape());
                return false;
            }
        }

        // Check for sequence of reshapes/transposes that cancel out.
        auto do_r2 = get_default_order(r1->get_shape());
        auto do_r1 = get_default_order(gop->get_shape());

        NGRAPH_DEBUG << "r1's i/o = " << vector_to_string(r1->get_input_order())
                     << "do_r1 = " << vector_to_string(do_r1);
        NGRAPH_DEBUG << "r2's i/o = " << vector_to_string(r2->get_input_order())
                     << "do_r2 = " << vector_to_string(do_r2);

        if (r1->get_input_order() == do_r1 && r2->get_input_order() == do_r2)
        {
            NGRAPH_DEBUG << "Two reshapes were removed!";
            replace_node(m.get_match_root(), gop);
            return true;
        }

        auto perm1 = apply_permutation(do_r1, r1->get_input_order());
        auto perm2 = apply_permutation(perm1, r2->get_input_order());
        if (perm2 == do_r1)
        {
            NGRAPH_DEBUG << "Two transposes were removed!";
            replace_node(m.get_match_root(), gop);
            return true;
        }

        return false;
    };
    auto m = make_shared<pattern::Matcher>(reshape2);
    this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}

void pass::ReshapeElimination::construct_dot_transpose_pattern()
{
    // dot(A,B).T = dot (B.T, A.T)
    auto dot_pred = [](shared_ptr<Node> n) { return is_type<op::Dot>(n); };

    auto pdot = make_shared<pattern::op::Label>(element::f32, Shape{2, 1}, dot_pred);
    auto preshape = make_shared<op::Reshape>(pdot, AxisVector{1, 0}, Shape{1, 2});

    auto callback = [](pattern::Matcher& m) {
        NGRAPH_DEBUG << "In callback for construct_dot_transpose_pattern against node = "
                     << m.get_match_root()->get_name();

        auto mtranspose = static_pointer_cast<op::Reshape>(m.get_match_root());
        // this also checks the rank
        if (mtranspose->get_input_order() != AxisVector{1, 0})
        {
            NGRAPH_DEBUG << "Reshape isn't transpose. "
                         << vector_to_string(mtranspose->get_input_order());
            return false;
        }

        auto mdot = mtranspose->input_value(0).get_node_shared_ptr();
        if (mdot->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << "Dot has the wrong shape. " << vector_to_string(mdot->get_shape());
            return false;
        }

        auto arg0 = mdot->input_value(0).get_node_shared_ptr();
        if (arg0->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << "Arg0 has the wrong shape. " << vector_to_string(arg0->get_shape());
            return false;
        }
        auto reshape0_shape = Shape{arg0->get_shape().at(1), arg0->get_shape().at(0)};
        auto reshape0 = make_shared<op::Reshape>(arg0, AxisVector{1, 0}, reshape0_shape);

        auto arg1 = mdot->input_value(1).get_node_shared_ptr();
        if (arg1->get_shape().size() != 2)
        {
            NGRAPH_DEBUG << "Arg1 has the wrong shape. " << vector_to_string(arg1->get_shape());
            return false;
        }
        auto reshape1_shape = Shape{arg1->get_shape().at(1), arg1->get_shape().at(0)};
        auto reshape1 = make_shared<op::Reshape>(arg1, AxisVector{1, 0}, reshape1_shape);

        auto tdot = shared_ptr<Node>(new op::Dot(reshape1, reshape0));
        replace_node(m.get_match_root(), tdot);
        return true;
    };

    auto m = make_shared<pattern::Matcher>(preshape);
    this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}

void pass::RecurrentReshapeElimination::construct_recurrent_reshape()
{
    Shape shape_op{3};
    Shape shape_r{1, 3};

    auto op = make_shared<pattern::op::Label>(element::f32, shape_op);
    auto reshape = make_shared<op::Reshape>(op, AxisVector{0}, shape_r);
    auto reshape_label =
        make_shared<pattern::op::Label>(reshape, get_no_fan_out_function(), NodeVector{reshape});

    auto callback = [op, reshape_label](pattern::RecurrentMatcher& m) {
        NGRAPH_DEBUG << "In callback for construct_recurrent_reshape against node = "
                     << reshape_label->input_value(0).get_node_shared_ptr()->get_name();
        auto reshape_node_vector = m.get_bound_nodes_for_pattern(reshape_label);

        // The bound node vector is in reverse order. It is convenient to have the
        // bound node vector in the correct order
        std::reverse(std::begin(reshape_node_vector), std::end(reshape_node_vector));

        auto first_bound_reshape_op = reshape_node_vector.front();
        auto driver_op = first_bound_reshape_op->input_value(0);
        auto last_bound_reshape_op = reshape_node_vector.back();

        // Need to check if the user of the last bound op is a reshape since the last reshape is
        // allowed to have fan-out but the matcher will discard any reshape if it has fan-out
        auto user_of_last_bound_reshape_op = last_bound_reshape_op->get_users(true)[0];
        if (is_type<op::Reshape>(user_of_last_bound_reshape_op))
        {
            reshape_node_vector.push_back(user_of_last_bound_reshape_op);
            last_bound_reshape_op = reshape_node_vector.back();
        }

        // Return if the recurrent matcher matches only one reshape
        if (reshape_node_vector.size() == 1)
        {
            return false;
        }

        // The complete reshape node vector may not contain contiguous reshapes that can be
        // fused. Only the subset of reshapes with a reshape(any axis order) followed by reshapes
        // with default axis order can be fused. Creating such subpatterns here:
        std::vector<NodeVector> sub_patterns{NodeVector{first_bound_reshape_op}};
        for (auto it = std::next(reshape_node_vector.begin()); it != reshape_node_vector.end();
             it++)
        {
            auto r = as_type_ptr<op::Reshape>(*it);

            // Check that the input to r is the last reshape stored in the
            // subpattern vector
            if (!r)
            {
                NGRAPH_DEBUG
                    << "Incorrect match. Something went wrong. Non-reshape op has been matched";
                return false;
            }

            auto default_order_r = get_default_order(r->get_input_shape(0));
            if (r->get_input_order() == default_order_r)
            {
                sub_patterns.back().push_back(r);
            }
            else
            {
                NGRAPH_DEBUG << r->get_name() << "does not have default axis order. "
                             << "It might be part of a different subpattern";
                sub_patterns.push_back(NodeVector{r});
            }
        }

        bool modify_graph = false;

        // Replace the patterns
        for (auto sub_pattern : sub_patterns)
        {
            // Do not consider subpatterns with just one reshape in them
            if (sub_pattern.size() == 1)
            {
                continue;
            }

            auto first_reshape = as_type_ptr<op::Reshape>(sub_pattern.front());
            auto input_to_first_reshape = first_reshape->input_value(0);
            auto last_reshape = as_type_ptr<op::Reshape>(sub_pattern.back());

            auto new_input_order = first_reshape->get_input_order();
            auto new_out_shape = last_reshape->get_shape();

            auto new_reshape = std::make_shared<op::Reshape>(
                input_to_first_reshape, new_input_order, new_out_shape);

            replace_node(last_reshape, new_reshape);
            modify_graph = true;
        }

        return modify_graph;
    };
    std::set<std::shared_ptr<pattern::op::Label>> empty_correlated_matches;
    auto m =
        std::make_shared<pattern::RecurrentMatcher>(reshape_label, op, empty_correlated_matches);
    this->add_matcher(m, callback, PassProperty::REQUIRE_STATIC_SHAPE);
}
@ -1,60 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/pass_util.hpp"

namespace ngraph
{
    namespace pass
    {
        class ReshapeElimination;
        class RecurrentReshapeElimination;
    }
}

class NGRAPH_API ngraph::pass::ReshapeElimination : public ngraph::pass::GraphRewrite
{
public:
    ReshapeElimination()
        : GraphRewrite()
    {
        construct_dot_transpose_pattern();
        construct_identity_reshape_pattern();
        construct_reshapex2_pattern();
    }

private:
    void construct_dot_transpose_pattern();
    void construct_identity_reshape_pattern();
    void construct_reshapex2_pattern();
};

class NGRAPH_API ngraph::pass::RecurrentReshapeElimination
    : public ngraph::pass::RecurrentGraphRewrite
{
public:
    RecurrentReshapeElimination()
        : RecurrentGraphRewrite()
    {
        construct_recurrent_reshape();
    }

private:
    void construct_recurrent_reshape();
};
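For orientation, a usage sketch (pre-removal API) showing how these rewrites were typically scheduled; the pass names are real, the pipeline itself is invented:

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/reshape_elimination.hpp>

void eliminate_reshapes(std::shared_ptr<ngraph::Function> f)
{
    ngraph::pass::Manager manager;
    // Removes identity reshapes and cancelling reshape/transpose pairs, and
    // rewrites dot(A, B)^T as dot(B^T, A^T) so the transpose can fold away.
    manager.register_pass<ngraph::pass::ReshapeElimination>();
    manager.register_pass<ngraph::pass::RecurrentReshapeElimination>();
    manager.run_passes(f);
}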
@ -1,642 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "reshape_sinking.hpp"
#include <algorithm>
#include <iostream>
#include <numeric>
#include <set>
#include <unordered_set>

#include "ngraph/descriptor/input.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/quantize.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

using ReshapeMap = unordered_map<shared_ptr<Node>, shared_ptr<op::Reshape>>;

static string describe_reshape(shared_ptr<Node> node)
{
    stringstream ss;
    auto reshape = as_type_ptr<op::Reshape>(node);
    ss << reshape->get_name()
       << " ( axis order = " << ngraph::vector_to_string(reshape->get_input_order())
       << " , shape = " << vector_to_string(reshape->get_shape()) << " ) "
       << " , child = " << reshape->input_value(0).get_node_shared_ptr()->get_name();

    return ss.str();
}

static shared_ptr<op::Reshape>
    make_reshape(shared_ptr<Node> arg, const AxisVector& input_order, const Shape& output_shape)
{
    auto reshape = make_shared<op::Reshape>(arg, input_order, output_shape);
    NGRAPH_DEBUG << "Make Reshape " << describe_reshape(reshape);
    return reshape;
}

static void
    write_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target, shared_ptr<op::Reshape> reshape)
{
    NGRAPH_DEBUG << "Write ReshapeMap[" << target->get_name()
                 << "] = " << describe_reshape(reshape);
    reorders[target] = reshape;
}

static shared_ptr<op::Reshape> read_reshapemap(ReshapeMap& reorders, shared_ptr<Node> target)
{
    auto reorder = reorders.at(target);
    NGRAPH_DEBUG << "Read ReshapeMap[" << target->get_name() << "] -> "
                 << describe_reshape(reorder);
    return reorder;
}

static shared_ptr<op::Reshape> combine_reshapes(shared_ptr<op::Reshape> r1,
                                                shared_ptr<op::Reshape> r2)
{
    auto default_order = ngraph::get_default_order(r1->get_shape());
    auto perm_r1 = apply_permutation(default_order, r1->get_input_order());
    auto perm_r2 = apply_permutation(perm_r1, r2->get_input_order());
    auto rreshape =
        make_reshape(r2->input_value(0).get_node_shared_ptr(), perm_r2, r2->get_shape());
    NGRAPH_DEBUG << "Combining " << describe_reshape(r1) << " and " << describe_reshape(r2)
                 << " into " << describe_reshape(rreshape);
    return rreshape;
}

static void insert_reshape(shared_ptr<Node> target, shared_ptr<Node> reshape, size_t input_index)
{
    NGRAPH_DEBUG << "Inserting reshape at input " << target->get_name() << " input index "
                 << input_index;
    auto arg = target->input(input_index).get_source_output();
    NGRAPH_DEBUG << "Arg shape: " << arg.get_shape();
    auto new_reshape = reshape->copy_with_new_inputs({arg});
    NGRAPH_DEBUG << "Inserting reshape " << describe_reshape(new_reshape) << " at input "
                 << target->get_name() << " input index " << input_index;
    target->input(input_index).replace_source_output(new_reshape->output(0));
}

static void delete_reshape(shared_ptr<Node> reshape)
{
    NGRAPH_DEBUG << "Removing reshape " << reshape->get_name();
    if (!reshape->get_users().empty())
    {
        ngraph::replace_node(reshape, reshape->input_value(0).get_node_shared_ptr());
    }
}

static void mark_reshape_for_deletion(shared_ptr<Node> reshape,
                                      set<shared_ptr<Node>>& reshapes_to_delete)
{
    NGRAPH_DEBUG << "Marking reshape " << reshape->get_name() << " for deletion";
    reshapes_to_delete.insert(reshape);
}

static shared_ptr<op::Reshape> create_default_reshape(shared_ptr<Node> n)
{
    auto default_order = ngraph::get_default_order(n->get_shape());
    auto default_reshape = make_reshape(n, default_order, n->get_shape());
    NGRAPH_DEBUG << "Default reshape: " << describe_reshape(default_reshape);
    return default_reshape;
}

// compute an axis order that converts the given axis order to default
static AxisSet get_quantization_axes_in_default_order(shared_ptr<op::Reshape> arg_reshape,
                                                      const AxisSet& old_axis_set)
{
    auto perm_to_def = ngraph::get_permutation_to_default_order(arg_reshape->get_input_order());
    AxisSet axis_set;
    for (auto axis : old_axis_set)
    {
        axis_set.insert(perm_to_def.at(axis));
    }
    return axis_set;
}

struct Swimmer
{
    Input<Node> input;
    shared_ptr<op::Reshape> reshape;
};

// Swim is used to push ("swim") reshapes towards parameters.
// This is typically done for binary ops when one operand is in nchw
// while the other one is in nhwc. We prefer nchw, since a lot of
// ngraph ops require this format, so keeping things in nchw allows us
// to eliminate as many reshapes as possible.
void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
{
    Swimmer sw{input, reshape};
    list<Swimmer> work_queue;
    work_queue.push_back(sw);

    // TODO: if we support more ops (especially, with >1 args)
    // we will need to keep track of nodes we visited and their reshapes
    while (work_queue.size() > 0)
    {
        auto csw = work_queue.front();
        work_queue.pop_front();
        auto n_output = csw.input.get_source_output();
        auto n = n_output.get_node_shared_ptr();
        auto materialize = [csw, n_output]() {
            auto n = n_output.get_node_shared_ptr();
            auto new_reshape = csw.reshape->clone_with_new_inputs({n});
            new_reshape->merge_provenance_tags_from(n);
            NGRAPH_DEBUG << "Materializing new reshape " << describe_reshape(new_reshape);
            csw.input.replace_source_output(new_reshape->output(0));
        };
        // Only swim past nodes which have a single user
        if (n->get_users().size() > 1)
        {
            materialize();
            continue;
        }
        NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
        if (op::is_unary_elementwise_arithmetic(n))
        {
            Swimmer nsw{n->input(0), csw.reshape};
            work_queue.push_back(nsw);
            NGRAPH_DEBUG << "Propagating reshape " << describe_reshape(csw.reshape) << " for "
                         << n->get_name() << " to " << n->input_value(0).get_node_shared_ptr();
        }
        else if (is_type<op::Broadcast>(n))
        {
            auto old_broadcast = static_pointer_cast<op::Broadcast>(n);
            auto broadcast_axes = old_broadcast->get_broadcast_axes();
            auto broadcast_reshape = csw.reshape;
            // swimming can only handle 1 dim change
            if (broadcast_reshape->get_shape().size() - old_broadcast->get_shape().size() > 1)
            {
                materialize();
                continue;
            }
            bool in_order = true;
            AxisSet new_broadcast_axes;
            vector<size_t> new_source_axes;
            auto input_order = broadcast_reshape->get_input_order();
            for (size_t i = 0; i < input_order.size(); i++)
            {
                if (broadcast_axes.count(input_order.at(i)) != 0)
                {
                    new_broadcast_axes.insert(i);
                }
                else
                {
                    if (new_source_axes.size() != 0 && new_source_axes.back() > input_order.at(i))
                    {
                        in_order = false;
                    }
                    new_source_axes.push_back(i);
                }
            }

            auto broadcast_input = old_broadcast->input_value(0).get_node_shared_ptr();
            if (!in_order)
            {
                AxisVector new_source_axes_sorted{new_source_axes};
                sort(new_source_axes_sorted.begin(), new_source_axes_sorted.end());
                map<size_t, size_t> old_new_source_axes;
                // Bug fix: the original loop condition was missing "i <",
                // which made the condition always true for non-empty vectors.
                for (size_t i = 0; i < new_source_axes_sorted.size(); i++)
                {
                    old_new_source_axes.insert({new_source_axes.at(i), i});
                }

                AxisVector new_source_axis_order;
                for (auto axis : new_source_axes_sorted)
                {
                    new_source_axis_order.push_back(old_new_source_axes.at(axis));
                }

                auto new_arg_shape =
                    ngraph::apply_permutation(broadcast_input->get_shape(), new_source_axis_order);
                broadcast_input =
                    make_reshape(broadcast_input, new_source_axis_order, new_arg_shape);
            }

            auto new_broadcast = make_shared<op::Broadcast>(
                broadcast_input, broadcast_reshape->get_shape(), new_broadcast_axes);
            csw.input.replace_source_output(new_broadcast->output(0));
        }
        // TODO: Add cases to push through Reshape and BinaryElementwiseArithmetic
        else
        {
            // materialize
            materialize();
        }
    }
}

// convert_binary_to_default_order is used when one of the arguments
// of a binary op isn't in the default format (i.e. nhwc instead of nchw)
// We have to normalize this other argument to nchw by swimming nchw towards parameters
// as far as we can
static void convert_binary_to_default_order(shared_ptr<Node> binary,
                                            const Input<Node>& input,
                                            shared_ptr<Node> right,
                                            ReshapeMap& reorders,
                                            set<shared_ptr<Node>>& reshapes_to_delete)
{
    auto left = input.get_source_output().get_node_shared_ptr();
    auto perm_to_def =
        ngraph::get_permutation_to_default_order(reorders.at(right)->get_input_order());
    auto new_shape = apply_permutation(left->get_shape(), perm_to_def);
    NGRAPH_DEBUG << "right = " << ngraph::vector_to_string(right->get_shape()) << ", "
                 << right->get_name();
    auto new_reshape = make_reshape(left, perm_to_def, new_shape);
    NGRAPH_DEBUG << "left : About to swim " << describe_reshape(new_reshape) << " up to "
                 << left->get_name();
    // this should now insert and swim reshape on right
    swim(input, new_reshape);
    mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
    write_reshapemap(reorders, binary, read_reshapemap(reorders, right));
}

static void materialize_shapes(shared_ptr<Node> n,
                               ReshapeMap& reorders,
                               set<shared_ptr<Node>>& reshapes_to_delete)
{
    // skip multiple output nodes and deal with GOEs exclusively
    if (n->get_output_size() > 1)
    {
        return;
    }

    for (size_t i = 0; i < n->input_values().size(); i++)
    {
        // materialize all pending reshapes, flush pending reshapes
        auto arg = n->input_value(i).get_node_shared_ptr();
        if (reorders.count(arg) != 0)
        {
            auto arg_reshape = reorders.at(arg);
            NGRAPH_DEBUG << "Materializing " << describe_reshape(arg_reshape) << " for "
                         << arg->get_name();
            mark_reshape_for_deletion(arg_reshape, reshapes_to_delete);
            auto arg_shape = arg->get_shape();
            if (arg_reshape->get_input_order() != get_default_order(arg->get_shape()))
            {
                // Insert if arg needs to be transposed.
                insert_reshape(n, arg_reshape, i);
            }
            // no swimming up
        }
    }
    write_reshapemap(reorders, n, create_default_reshape(n));
}

static void sink_reshape(shared_ptr<op::Reshape> reshape,
                         ReshapeMap& reorders,
                         set<shared_ptr<Node>>& reshapes_to_delete)
{
    NGRAPH_DEBUG << "Sinking Reshape :" << describe_reshape(reshape);
    auto orig_reshape = reorders.at(reshape->input_value(0).get_node_shared_ptr());
    // 1) Not a Transpose or 2) Rank changing operation.
    if ((reshape->get_output_shape(0).size() != reshape->get_input_order().size()) ||
        (!reshape->get_is_transpose()))
    {
        NGRAPH_DEBUG << "Materializing " << describe_reshape(orig_reshape) << " for reshape "
                     << describe_reshape(reshape);
        insert_reshape(reshape, orig_reshape, 0);
        mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
        write_reshapemap(reorders, reshape, create_default_reshape(reshape));
    }
    else
    {
        // combine both reshapes
        auto new_reshape = combine_reshapes(orig_reshape, reshape);
        // remove original reshape now it's combined with a new one
        // should be safe to remove an already detached node
        mark_reshape_for_deletion(orig_reshape, reshapes_to_delete);
        // replace reshape with combined one
        ngraph::replace_node(reshape, new_reshape);
        mark_reshape_for_deletion(new_reshape, reshapes_to_delete);
        write_reshapemap(reorders, new_reshape, new_reshape);
    }
}

static void sink_unary(shared_ptr<Node> n,
                       ReshapeMap& reorders,
                       set<shared_ptr<Node>>& /* reshapes_to_delete */)
{
    auto arg_reshape = read_reshapemap(reorders, n->input_value(0).get_node_shared_ptr());
    NGRAPH_DEBUG << "Propagating " << describe_reshape(arg_reshape) << " for " << n->get_name();
    write_reshapemap(reorders, n, arg_reshape);
}

static void sink_binary(shared_ptr<Node> binary,
                        ReshapeMap& reorders,
                        set<shared_ptr<Node>>& reshapes_to_delete)
{
    auto left = binary->input_value(0).get_node_shared_ptr();
    auto right = binary->input_value(1).get_node_shared_ptr();

    if (reorders.at(left)->get_input_order() == reorders.at(right)->get_input_order())
    {
        NGRAPH_DEBUG << "Propagating " << describe_reshape(reorders.at(left)) << " for "
                     << binary->get_name();
        write_reshapemap(reorders, binary, read_reshapemap(reorders, left));
        // at this point, both reshapes will be eventually removed
        mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
        mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
    }
    else if (reorders.at(left)->get_input_order() == ngraph::get_default_order(left->get_shape()))
    {
        convert_binary_to_default_order(
            binary, binary->input(0), right, reorders, reshapes_to_delete);
    }
    else if (reorders.at(right)->get_input_order() == ngraph::get_default_order(right->get_shape()))
    {
        convert_binary_to_default_order(
            binary, binary->input(1), left, reorders, reshapes_to_delete);
    }
    else
    {
        NGRAPH_DEBUG << "Materializing both reshapes for " << binary->get_name();
        NGRAPH_DEBUG << "Left = " << describe_reshape(reorders.at(left));
        NGRAPH_DEBUG << "Right = " << describe_reshape(reorders.at(right));
        mark_reshape_for_deletion(reorders.at(left), reshapes_to_delete);
        mark_reshape_for_deletion(reorders.at(right), reshapes_to_delete);
        insert_reshape(binary, reorders.at(left), 0);
        insert_reshape(binary, reorders.at(right), 1);
    }
}

static void sink_slice(shared_ptr<op::Slice> n,
                       ReshapeMap& reorders,
                       set<shared_ptr<Node>>& /* reshapes_to_delete */)
{
    auto arg_reshape = reorders.at(n->input_value(0).get_node_shared_ptr());
    auto order = arg_reshape->get_input_order();

    // we need the correct input shape to produce the right output shape
    // we are going to create a label of the right input shape,
    // so a new slice will have the right shape
    auto def_order = ngraph::get_permutation_to_default_order(order);
    auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
    auto dummy_correct_shape =
        make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);

    auto new_lower = ngraph::apply_permutation(n->get_lower_bounds(), def_order);
    auto new_upper = ngraph::apply_permutation(n->get_upper_bounds(), def_order);
    auto new_strides = ngraph::apply_permutation(n->get_strides(), def_order);
    auto new_slice = make_shared<op::Slice>(dummy_correct_shape, new_lower, new_upper, new_strides);
    ngraph::replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
    NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_slice->get_name();
    ngraph::replace_node(n, new_slice);

    auto new_reshape = make_reshape(new_slice, order, n->get_shape());
    NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
    write_reshapemap(reorders, new_slice, new_reshape);
}

static void sink_pad(shared_ptr<op::Pad> n,
                     ReshapeMap& reorders,
                     set<shared_ptr<Node>>& /* reshapes_to_delete */)
{
    auto arg_reshape = reorders.at(n->input_value(0).get_node_shared_ptr());
    auto order = arg_reshape->get_input_order();
    // we need the correct input shape to produce the right output shape
    // we are going to create a label of the right input shape,
    // so a new pad will have the right shape
    auto def_order = ngraph::get_permutation_to_default_order(order);
    auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
    auto dummy_correct_shape =
        make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);

    auto new_lower = ngraph::apply_permutation(n->get_padding_below(), def_order);
    auto new_upper = ngraph::apply_permutation(n->get_padding_above(), def_order);
    auto new_pad = make_shared<op::Pad>(dummy_correct_shape,
                                        n->input_value(1).get_node_shared_ptr(),
                                        new_lower,
                                        new_upper,
                                        n->get_pad_mode());
    ngraph::replace_node(dummy_correct_shape, n->input_value(0).get_node_shared_ptr());
    NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_pad->get_name();
    ngraph::replace_node(n, new_pad);
    auto new_reshape = make_reshape(new_pad, order, n->get_shape());
    NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
    write_reshapemap(reorders, new_pad, new_reshape);
}
static void sink_quantize(shared_ptr<op::Quantize> quantize,
                          ReshapeMap& reorders,
                          set<shared_ptr<Node>>& /* reshapes_to_delete */)
{
    auto arg_reshape = reorders.at(quantize->input_value(0).get_node_shared_ptr());
    AxisSet axes_in_def_order =
        get_quantization_axes_in_default_order(arg_reshape, quantize->get_axes());
    auto new_quantize = make_shared<op::Quantize>(quantize->input_value(0),
                                                  quantize->input_value(1),
                                                  quantize->input_value(2),
                                                  quantize->get_element_type(),
                                                  axes_in_def_order,
                                                  quantize->get_round_mode());

    ngraph::replace_node(quantize, new_quantize);
    write_reshapemap(reorders, new_quantize, arg_reshape);
}

static void sink_concat(shared_ptr<op::Concat> n,
                        ReshapeMap& reorders,
                        set<shared_ptr<Node>>& reshapes_to_delete)
{
    auto arg_reshape = reorders.at(n->input_value(0).get_node_shared_ptr());
    auto order = arg_reshape->get_input_order();
    // we need the correct input shape to produce the right output shape
    // we are going to create a label of the right input shape,
    // so a new concat will have the right shape
    auto def_order = ngraph::get_permutation_to_default_order(order);
    auto input_shape = ngraph::apply_permutation(arg_reshape->get_shape(), def_order);
    auto dummy_correct_shape =
        make_shared<pattern::op::Label>(arg_reshape->get_element_type(), input_shape);

    NodeVector new_args;
    new_args.push_back(dummy_correct_shape);

    for (size_t i = 1; i < n->get_input_size(); i++)
    {
        auto iarg_reshape = reorders.at(n->input_value(i).get_node_shared_ptr());
        auto iorder = iarg_reshape->get_input_order();
        if (iorder != order)
        {
            NGRAPH_DEBUG << " input order at " << i << "-th arg is different from first arg";
            materialize_shapes(n, reorders, reshapes_to_delete);
            return;
        }

        auto iinput_shape = ngraph::apply_permutation(iarg_reshape->get_shape(), def_order);
        auto idummy_correct_shape =
            make_shared<pattern::op::Label>(iarg_reshape->get_element_type(), iinput_shape);
        new_args.push_back(idummy_correct_shape);
    }

    auto new_axis = order.at(n->get_concatenation_axis());
    auto new_concat = make_shared<op::Concat>(new_args, new_axis);
    // put back the original arguments
    for (size_t i = 0; i < new_concat->get_input_size(); i++)
    {
        ngraph::replace_node(new_args.at(i), n->input_value(i).get_node_shared_ptr());
    }
    NGRAPH_DEBUG << "Replacing " << n->get_name() << " with " << new_concat->get_name();
    ngraph::replace_node(n, new_concat);

    auto new_reshape = make_reshape(new_concat, order, n->get_shape());
    NGRAPH_DEBUG << "Propagating " << describe_reshape(new_reshape) << " for " << n->get_name();
    write_reshapemap(reorders, new_concat, new_reshape);
}

static void sink_dequantize(shared_ptr<op::Dequantize> dequantize,
                            ReshapeMap& reorders,
                            set<shared_ptr<Node>>& /* reshapes_to_delete */)
{
    auto arg_reshape = reorders.at(dequantize->input_value(0).get_node_shared_ptr());
    AxisSet axes_in_def_order =
        get_quantization_axes_in_default_order(arg_reshape, dequantize->get_axes());
    auto new_dequantize = make_shared<op::Dequantize>(dequantize->input_value(0),
                                                      dequantize->input_value(1),
                                                      dequantize->input_value(2),
                                                      dequantize->get_element_type(),
                                                      axes_in_def_order);

    ngraph::replace_node(dequantize, new_dequantize);
    write_reshapemap(reorders, new_dequantize, arg_reshape);
}

// The goal of ReshapeSinking is to remove round-trip reshapes
// (i.e. nhwc -> nchw -> (nchw-only op) -> nhwc) around nchw-only ops
// (e.g. Convolution, BatchNorm, Avg/MaxPool).
// This is achieved either by **sinking**, i.e. propagating reshapes
// through ops towards op::Result, or by **swimming** reshapes up
// towards op::Parameter. For each supported op type we can either
// combine two reshapes by replacing the existing Reshape, or
// materialize pending reshapes if they can't be propagated through the op.
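// Illustrative walk-through (invented graph, not from this codebase):
//
//   Parameter(nhwc) -> Reshape_a(to nchw) -> Convolution
//                   -> Reshape_b(to nhwc) -> Abs -> Result
//
// STEP 1 records Reshape_b in the ReshapeMap and propagates it past Abs
// (sinking); in the binary-op cases, swim() instead pushes a reshape up
// towards the Parameter. Reshapes that end up cancelling are folded by
// combine_reshapes() and purged in STEP 2.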
bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function> f)
{
    ReshapeMap reorders;
    NodeVector results;
    set<shared_ptr<Node>> reshapes_to_delete;

    // STEP 1 : Sink or Swim reshapes away for op clusters
    for (auto n : f->get_ordered_ops())
    {
        NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
        // collect all Result nodes for a sanity check
        if (ngraph::op::is_output(n))
        {
            results.push_back(n);
        }

        if (auto reshape = as_type_ptr<op::Reshape>(n))
        {
            sink_reshape(reshape, reorders, reshapes_to_delete);
        }
        else if (op::is_unary_elementwise_arithmetic(n))
        {
            sink_unary(n, reorders, reshapes_to_delete);
        }
        else if (op::is_binary_elementwise_arithmetic(n))
        {
            sink_binary(n, reorders, reshapes_to_delete);
        }
        else if (auto goe = as_type_ptr<op::GetOutputElement>(n))
        {
            write_reshapemap(reorders, goe, create_default_reshape(goe));
        }
        else if (auto quantize = as_type_ptr<op::Quantize>(n))
        {
            sink_quantize(quantize, reorders, reshapes_to_delete);
        }
        else if (auto dequantize = as_type_ptr<op::Dequantize>(n))
        {
            sink_dequantize(dequantize, reorders, reshapes_to_delete);
        }
        else if (auto slice = as_type_ptr<op::Slice>(n))
        {
            // A heuristic: if the Reshape has multiple Slice users and is sunk,
            // it will be replicated once per user.
            // TODO: we should have a pre-pass that looks at this kind of
            // scenario and marks some reshapes as too "toxic" to sink.
            // For now, this heuristic works really well.
            // Note: get_users(*true*) means we only care about live users of
            // the Reshape. However, get_users(*true*) causes a significant
            // time increase on graphs with many slice ops, so for now we drop
            // the "true" check and let the backend handle reshape sinking for
            // the slice operation.
            if (slice->input_value(0).get_node_shared_ptr()->get_users().size() == 1)
            {
                sink_slice(slice, reorders, reshapes_to_delete);
            }
            else
            {
                materialize_shapes(n, reorders, reshapes_to_delete);
            }
        }
        else if (auto pad = as_type_ptr<op::Pad>(n))
        {
            sink_pad(pad, reorders, reshapes_to_delete);
        }
        else if (auto concat = as_type_ptr<op::Concat>(n))
        {
            sink_concat(concat, reorders, reshapes_to_delete);
        }
        else
        {
            materialize_shapes(n, reorders, reshapes_to_delete);
        }
        NGRAPH_DEBUG << "End: Processing node " << n->get_name();
    }

    // STEP 2: purge all the reshapes we either sunk or swam.
    for (auto r : reshapes_to_delete)
    {
        delete_reshape(r);
    }

    // make sure shapes are always materialized before results
    for (auto r : results)
    {
        NGRAPH_CHECK(r->get_shape() == r->get_input_shape(0) &&
                         r->get_element_type() ==
                             r->input_value(0).get_node_shared_ptr()->get_element_type(),
                     " op::Result = ",
                     *r,
                     ", Arg = ",
                     *r->input_value(0).get_node_shared_ptr());
    }

    // STEP 3: fix wrong shape info wholesale
    for (auto n : f->get_ordered_ops())
    {
        n->revalidate_and_infer_types();
    }
    return true;
}
@ -1,33 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"
#include "ngraph/util.hpp"

namespace ngraph
{
    namespace pass
    {
        class NGRAPH_API ReshapeSinking : public pass::FunctionPass
        {
        public:
            ReshapeSinking() { set_property(PassProperty::REQUIRE_STATIC_SHAPE, true); }
            bool run_on_function(std::shared_ptr<Function> function) override;
        };
    }
}
@ -1,43 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <fstream>

#include "ngraph/file_util.hpp"
#include "ngraph/pass/serialize.hpp"
#include "ngraph/util.hpp"
#ifndef NGRAPH_JSON_DISABLE
#include "ngraph/serializer.hpp"
#include "nlohmann/json.hpp"
#endif

using namespace std;
using namespace ngraph;

pass::Serialization::Serialization(const string& name)
    : m_name{name}
{
}

bool pass::Serialization::run_on_module(vector<shared_ptr<Function>>& functions)
{
#ifndef NGRAPH_JSON_DISABLE
    // serializing the outermost function
    // also implicitly serializes any inner functions
    serialize(m_name, functions.at(0), 4);
#endif
    return false;
}
@ -1,40 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <string>

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class Serialization;
    }
}

class NGRAPH_API ngraph::pass::Serialization : public ModulePass
{
public:
    Serialization(const std::string& name);

    virtual bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
    const std::string m_name;
};
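A sketch of how Serialization was typically used to snapshot a graph as JSON between passes; the file name is invented, and this assumes the pre-removal pass::Manager API:

#include <ngraph/pass/manager.hpp>
#include <ngraph/pass/serialize.hpp>

void debug_pipeline(std::shared_ptr<ngraph::Function> f)
{
    ngraph::pass::Manager manager;
    // Dumps the graph to "before_folding.json"; this is a no-op when
    // nGraph is built with NGRAPH_JSON_DISABLE.
    manager.register_pass<ngraph::pass::Serialization>("before_folding.json");
    manager.run_passes(f);
}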
@ -1,48 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/pass/validate_graph.hpp"

using namespace std;
using namespace ngraph;

bool pass::ValidateGraph::run_on_module(vector<shared_ptr<Function>>& functions)
{
    for (shared_ptr<Function> f : functions)
    {
        validate_parameters(*f);
    }
    return false;
}

void pass::ValidateGraph::validate_parameters(const Function& function)
{
    auto parameters = function.get_parameters();
    for (auto node : function.get_ops())
    {
        shared_ptr<op::Parameter> p = as_type_ptr<op::Parameter>(node);
        if (nullptr != p)
        {
            auto it = find_if(parameters.begin(),
                              parameters.end(),
                              [p](shared_ptr<op::Parameter> q) { return (p == q); });
            if (it == parameters.end())
            {
                throw ngraph_error("Function references undeclared parameter");
            }
        }
    }
}
@ -1,38 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <memory>

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class ValidateGraph;
    }
}

class NGRAPH_API ngraph::pass::ValidateGraph : public ModulePass
{
public:
    bool run_on_module(std::vector<std::shared_ptr<ngraph::Function>>&) override;

private:
    void validate_parameters(const Function&);
};
@ -1,178 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>
#include <set>

#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/avg_pool.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convolution.hpp"
#include "ngraph/op/max_pool.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/type.hpp"
#include "zero_dim_tensor_elimination.hpp"

using namespace std;
using namespace ngraph;

static bool has_zero_dim(const Output<Node>& output)
{
    const auto& shape = output.get_shape();
    return find(shape.begin(), shape.end(), 0) != shape.end();
}

static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
{
    set<Output<Node>> zero_length_source_outputs;
    for (auto n : f->get_ordered_ops())
    {
        if (op::is_output(n) || op::is_parameter(n) || op::is_constant(n) ||
            n->get_output_size() > 1)
        {
            continue;
        }

        for (auto& output : n->outputs())
        {
            if (has_zero_dim(output))
            {
                zero_length_source_outputs.insert(output);
            }
        }
    }

    // All zero-length outputs should feed a Result. If we remove every
    // output consumed by a Result from zero_length_source_outputs and
    // there are still entries left, the graph has INTERNAL zero-length
    // nodes, which violates our assumption.
    for (auto r : f->get_results())
    {
        for (auto& input : r->inputs())
        {
            if (zero_length_source_outputs.count(input.get_source_output()) != 0)
            {
                zero_length_source_outputs.erase(input.get_source_output());
            }
        }
    }

    return zero_length_source_outputs.size() > 0;
}

bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
{
    bool replaced = false;
    auto cvals = vector<string>(0);
    // We need to go over all nodes, since we could have a Sum or any other
    // 0-length-tensor-to-scalar op as an internal node (i.e. a node that
    // isn't an argument to `op::Result`)
    for (auto n : f->get_ordered_ops())
    {
        // Don't try to replace `op::Result`. All multi-output nodes feed
        // into `GetOutputElement`; if any `GetOutputElement` is zero-length
        // we replace it with a signalling constant, so we don't have to
        // deal with multi-output nodes directly.
if (op::is_output(n) || op::is_parameter(n) || n->get_output_size() > 1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (has_zero_dim(n))
|
||||
{
|
||||
// we don't have to create constants every time but this is the easiest
|
||||
// and it's CSE's job to eliminate the same ones
|
||||
auto constant = make_shared<op::Constant>(n->get_element_type(), n->get_shape(), cvals);
|
||||
replace_node(n, constant);
|
||||
NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << constant->get_name();
|
||||
replaced = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (n->get_input_size() == 0)
|
||||
{
|
||||
continue;
|
||||
}
        if (auto concat = as_type_ptr<op::Concat>(n))
        {
            OutputVector non_zero_dim_args;
            for (auto arg : concat->input_values())
            {
                if (!has_zero_dim(arg))
                {
                    non_zero_dim_args.push_back(arg);
                }
            }

            if (non_zero_dim_args.size() < concat->get_input_size())
            {
                auto new_concat = concat->clone_with_new_inputs(non_zero_dim_args);
                NGRAPH_DEBUG << " Replacing " << n->get_name() << " with "
                             << new_concat->get_name();
                replace_node(concat, new_concat);
                continue;
            }
        }
        else if (auto replace_slice = as_type_ptr<op::ReplaceSlice>(n))
        {
            const Shape& replacement_shape = replace_slice->get_input_shape(1);
            if (shape_size(replacement_shape) == 0)
            {
                // A zero-size replacement leaves the input unchanged, so the
                // op is a no-op; route its consumers to the input directly.
                Output<Node> source_output = replace_slice->input_value(0);
                Output<Node> output = replace_slice->output(0);
                for (Input<Node> input : output.get_target_inputs())
                {
                    input.replace_source_output(source_output);
                }
            }
        }

        auto source_output = n->input_value(0);

        if (source_output.get_node()->get_output_size() != 1 || !has_zero_dim(source_output))
        {
            continue;
        }
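
        // `get_default_value()` yields the op's identity where one is defined
        // (e.g. a zero constant for Sum over an empty input); ops without a
        // default are simply skipped below.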
        auto new_node = n->get_default_value();

        if (!new_node)
        {
            continue;
        }

        replaced = true;
        NGRAPH_DEBUG << " Replacing " << n->get_name() << " with " << new_node->get_name();
        replace_node(n, new_node);
    }

    if (verify_no_internal_zero_length_ops(f))
    {
        throw ngraph_error("there were internal zero-length nodes in a graph");
    }

    return replaced;
}
@ -1,39 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class ZeroDimTensorElimination;
    }
}

class NGRAPH_API ngraph::pass::ZeroDimTensorElimination : public FunctionPass
{
public:
    ZeroDimTensorElimination()
        : FunctionPass()
    {
        set_property(PassProperty::REQUIRE_STATIC_SHAPE, true);
    }

    virtual bool run_on_function(std::shared_ptr<ngraph::Function> f);
};
@ -42,7 +42,6 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
endif()

set(SRC
    algebraic_simplification.cpp
    aligned_buffer.cpp
    all_close_f.cpp
    assertion.cpp
@ -58,8 +57,6 @@ set(SRC
    coordinate.cpp
    copy.cpp
    cpio.cpp
    cse.cpp
    dyn_elimination.cpp
    element_type.cpp
    eval.cpp
    file_util.cpp
@ -98,16 +95,12 @@ set(SRC
    opset_pass/topk_opset_pass.cpp
    opset_pass/transpose_opset_pass.cpp
    partial_shape.cpp
    pass.cpp
    pass_liveness.cpp
    pass_manager.cpp
    pass_memory_layout.cpp
    pass_shape_relevance.cpp
    pattern.cpp
    provenance.cpp
    replace_node.cpp
    reshape_elimination.cpp
    reshape_sinking.cpp
    shape.cpp
    specialize_function.cpp
    tensor.cpp
@ -195,15 +188,8 @@ set(SRC
    type_prop_benchmark.cpp
    type_prop_layers.cpp
    util.cpp
    zero_dim_tensor_elimination.cpp
)

if(NGRAPH_INTERPRETER_ENABLE)
    list(APPEND SRC
        concat_fusion.cpp
    )
endif()

# This code generates one source file per header file under ngraph/src where the source file
# has just a single #include statement. This checks that each header in the source tree is
# complete and self-contained so it can be included without requiring any other includes.
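#
# A rough sketch of how such a generation step can be written in CMake
# (illustrative only; these variable names are hypothetical, not the ones
# this file actually uses):
#
#   file(GLOB_RECURSE checked_headers ${PROJECT_SOURCE_DIR}/src/ngraph/*.hpp)
#   foreach(header ${checked_headers})
#       get_filename_component(header_name ${header} NAME_WE)
#       set(gen_src ${CMAKE_CURRENT_BINARY_DIR}/include_test_${header_name}.cpp)
#       file(WRITE ${gen_src} "#include \"${header}\"\n")
#       list(APPEND SRC ${gen_src})
#   endforeach()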
@ -1,298 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"

#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/pass/concat_fusion.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;
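
// SelfConcatFusion replaces a Concat whose inputs are all the same value with
// a Broadcast/Reshape pattern; the counts asserted below reflect that each
// fused branch is expected to leave one Reshape and one Broadcast behind.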
TEST(concat_fusion, single_branch)
{
    Shape shape_a{12, 8, 1, 1};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);

        auto concat_1 = make_shared<op::Concat>(NodeVector{A}, 2);
        auto concat_2 = make_shared<op::Concat>(NodeVector{concat_1}, 2);
        auto concat_3 = make_shared<op::Concat>(
            NodeVector{concat_2, concat_2, concat_2, concat_2, concat_2, concat_2, concat_2}, 2);
        auto concat_4 = make_shared<op::Concat>(
            NodeVector{concat_3, concat_3, concat_3, concat_3, concat_3, concat_3, concat_3}, 3);
        auto f_concat_1 = make_shared<Function>(NodeVector{concat_4}, ParameterVector{A});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);

    ASSERT_EQ(num_reshapes_optimized, 1);
    ASSERT_EQ(num_broadcast_optimized, 1);
}

TEST(concat_fusion, multiple_branches_1)
{
    Shape shape_a{16, 8, 1, 1};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);

        auto concat_1 = make_shared<op::Concat>(NodeVector{A}, 2);
        auto concat_2 = make_shared<op::Concat>(NodeVector{concat_1}, 2);
        auto concat_3 = make_shared<op::Concat>(
            NodeVector{concat_2, concat_2, concat_2, concat_2, concat_2, concat_2, concat_2}, 2);
        auto concat_4 = make_shared<op::Concat>(
            NodeVector{concat_3, concat_3, concat_3, concat_3, concat_3, concat_3, concat_3}, 3);

        auto concat_5 = make_shared<op::Concat>(NodeVector{A, A}, 2);
        auto concat_6 = make_shared<op::Concat>(NodeVector{concat_5, concat_5, concat_5}, 3);
        auto f_concat_1 = make_shared<Function>(NodeVector{concat_4, concat_6}, ParameterVector{A});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);

    ASSERT_EQ(num_reshapes_optimized, 2);
    ASSERT_EQ(num_broadcast_optimized, 2);
}

TEST(concat_fusion, multiple_branches_2)
{
    Shape shape_a{16, 8, 1, 1};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto concat_3 = make_shared<op::Concat>(NodeVector{A, A, A, A, A, A, A}, 2);
        auto concat_4 = make_shared<op::Concat>(
            NodeVector{concat_3, concat_3, concat_3, concat_3, concat_3, concat_3, concat_3}, 3);

        auto concat_6 = make_shared<op::Concat>(NodeVector{A, A, A}, 3);
        auto f_concat_1 = make_shared<Function>(NodeVector{concat_4, concat_6}, ParameterVector{A});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);

    ASSERT_EQ(num_reshapes_optimized, 2);
    ASSERT_EQ(num_broadcast_optimized, 2);
}

TEST(concat_fusion, non_fusable_self_concat)
{
    Shape shape_a{32, 1, 1, 1};
    Shape shape_b{32, 1, 1};
    auto generate_func = [shape_a, shape_b]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto B = make_shared<op::Parameter>(element::f32, shape_b);

        auto concat_1 = make_shared<op::Concat>(NodeVector{A, A, A, A}, 1);
        auto concat_2 = make_shared<op::Concat>(
            NodeVector{concat_1, concat_1, concat_1, concat_1, concat_1, concat_1, concat_1}, 2);
        auto concat_3 = make_shared<op::Concat>(NodeVector{concat_2, concat_2}, 1);
        auto concat_4 = make_shared<op::Concat>(NodeVector{concat_3, concat_3, concat_3}, 3);

        auto concat_5 = make_shared<op::Concat>(NodeVector{B, B, B, B, B, B, B}, 1);
        auto concat_6 = make_shared<op::Concat>(NodeVector{concat_5, concat_5, concat_5}, 2);
        auto broadcast = make_shared<op::Broadcast>(concat_6, Shape{32, 8, 7, 3}, AxisSet{1});
        auto add = make_shared<op::Add>(concat_4, broadcast);
        auto f_concat_1 = make_shared<Function>(NodeVector{add}, ParameterVector{A, B});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape_1 = baseline_f->get_parameters().at(0)->get_shape();
    auto baseline_input_shape_2 = baseline_f->get_parameters().at(1)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val_1(shape_size(baseline_input_shape_1));
    vector<float> tensor_val_2(shape_size(baseline_input_shape_2));
    rng.initialize(tensor_val_1);
    rng.initialize(tensor_val_2);
    args.push_back(tensor_val_1);
    args.push_back(tensor_val_2);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);

    ASSERT_EQ(num_reshapes_optimized, 3);
    ASSERT_EQ(num_broadcast_optimized, 4);
}

TEST(concat_fusion, self_concat_with_fan_out)
{
    Shape shape_a{8, 1, 1, 1};
    Shape shape_b{8, 4, 1, 1};
    auto generate_func = [shape_a, shape_b]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        auto B = make_shared<op::Parameter>(element::f32, shape_b);

        auto concat_1 = make_shared<op::Concat>(NodeVector{A, A, A, A, A, A, A}, 2);
        auto concat_2 =
            make_shared<op::Concat>(NodeVector{concat_1, concat_1, concat_1, concat_1}, 1);
        auto concat_3 =
            make_shared<op::Concat>(NodeVector{concat_2, concat_2, concat_2, concat_2}, 3);

        auto concat_4 = make_shared<op::Concat>(NodeVector{B, B, B, B, B, B, B}, 2);
        auto concat_5 = make_shared<op::Concat>(NodeVector{concat_4, concat_4, concat_4}, 3);
        auto concat_6 = make_shared<op::Concat>(NodeVector{concat_2, concat_4}, 3);
        auto f_concat_1 =
            make_shared<Function>(NodeVector{concat_3, concat_6}, ParameterVector{A, B});
        return f_concat_1;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape_1 = baseline_f->get_parameters().at(0)->get_shape();
    auto baseline_input_shape_2 = baseline_f->get_parameters().at(1)->get_shape();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ConcatElimination>();
    pass_manager.register_pass<pass::SelfConcatFusion>();
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val_1(shape_size(baseline_input_shape_1));
    vector<float> tensor_val_2(shape_size(baseline_input_shape_2));
    rng.initialize(tensor_val_1);
    rng.initialize(tensor_val_2);
    args.push_back(tensor_val_1);
    args.push_back(tensor_val_2);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    size_t num_broadcast_optimized = count_ops_of_type<op::Broadcast>(optimized_f);

    ASSERT_EQ(num_reshapes_optimized, 3);
    ASSERT_EQ(num_broadcast_optimized, 3);
}

TEST(concat_fusion, pass_property)
{
    {
        auto pass = std::make_shared<ngraph::pass::ConcatElimination>();
        ASSERT_FALSE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
        ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
    }

    {
        auto pass = std::make_shared<ngraph::pass::SelfConcatFusion>();
        ASSERT_TRUE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
        ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
    }
}
@ -1,362 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/abs.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

TEST(CSE, abs_abs)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs1 = std::make_shared<op::Abs>(A);
    auto abs2 = std::make_shared<op::Abs>(A);
    auto f = std::make_shared<Function>(NodeVector{abs1, abs2}, ParameterVector{A});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
              f->get_results().at(1)->input_value(0).get_node_shared_ptr());
}

TEST(CSE, abs_abs_negative)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs1 = std::make_shared<op::Abs>(A);
    auto abs2 = std::make_shared<op::Abs>(B);
    auto f = std::make_shared<Function>(NodeVector{abs1, abs2}, ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), abs1);
    ASSERT_EQ(f->get_results().at(1)->input_value(0).get_node_shared_ptr(), abs2);
}

TEST(CSE, add_add)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto add1 = std::make_shared<op::Add>(A, B);
    auto add2 = std::make_shared<op::Add>(A, B);
    auto f = std::make_shared<Function>(NodeVector{add1, add2}, ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
              f->get_results().at(1)->input_value(0).get_node_shared_ptr());
}
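
// Add is commutative, so CSE should treat Add(A, B) and Add(B, A) as the same
// expression and collapse the two results below into a single node.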
TEST(CSE, add_add_commutative)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto add1 = std::make_shared<op::Add>(A, B);
    auto add2 = std::make_shared<op::Add>(B, A);
    auto f = std::make_shared<Function>(NodeVector{add1, add2}, ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
              f->get_results().at(1)->input_value(0).get_node_shared_ptr());
}

TEST(CSE, add_add_negative)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto C = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto D = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto add1 = std::make_shared<op::Add>(A, B);
    auto add2 = std::make_shared<op::Add>(C, D);
    auto f = std::make_shared<Function>(NodeVector{add1, add2}, ParameterVector{A, B, C, D});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), add1);
    ASSERT_EQ(f->get_results().at(1)->input_value(0).get_node_shared_ptr(), add2);
}

TEST(CSE, abs_add)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_a1 = std::make_shared<op::Abs>(A);
    auto abs_b1 = std::make_shared<op::Abs>(B);
    auto abs_a2 = std::make_shared<op::Abs>(A);
    auto abs_b2 = std::make_shared<op::Abs>(B);
    auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
    auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
    auto f = std::make_shared<Function>(NodeVector{add1, add2}, ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
              f->get_results().at(1)->input_value(0).get_node_shared_ptr());
}

TEST(CSE, abs_add_reshape_broadcast)
{
    Shape zero_shape{1};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_a1 = std::make_shared<op::Abs>(A);
    auto abs_b1 = std::make_shared<op::Abs>(B);
    auto abs_a2 = std::make_shared<op::Abs>(A);
    auto abs_b2 = std::make_shared<op::Abs>(B);
    auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
    auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
    {
        // success case
        auto reshape1 = std::make_shared<op::Reshape>(add1, AxisVector{0}, Shape{1, 1});
        auto reshape2 = std::make_shared<op::Reshape>(add2, AxisVector{0}, Shape{1, 1});
        auto broadcast1 = std::make_shared<op::Broadcast>(reshape1, Shape{1, 1, 3}, AxisSet{2});
        auto broadcast2 = std::make_shared<op::Broadcast>(reshape2, Shape{1, 1, 3}, AxisSet{2});
        auto f =
            std::make_shared<Function>(NodeVector{broadcast1, broadcast2}, ParameterVector{A, B});
        pass::Manager pass_manager;

        pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
        pass_manager.run_passes(f);
        ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
                  f->get_results().at(1)->input_value(0).get_node_shared_ptr());
    }
    {
        // fail case
        auto reshape1 = std::make_shared<op::Reshape>(add1, AxisVector{0}, Shape{1});
        auto reshape2 = std::make_shared<op::Reshape>(add2, AxisVector{0}, Shape{1, 1});
        auto f = std::make_shared<Function>(NodeVector{reshape1, reshape2}, ParameterVector{A, B});
        pass::Manager pass_manager;

        pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
        pass_manager.run_passes(f);
        ASSERT_NE(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
                  f->get_results().at(1)->input_value(0).get_node_shared_ptr());
    }
    {
        // fail case
        auto broadcast1 = std::make_shared<op::Broadcast>(add1, Shape{1, 2}, AxisSet{1});
        auto broadcast2 = std::make_shared<op::Broadcast>(add2, Shape{1, 1, 2}, AxisSet{1, 2});
        auto f =
            std::make_shared<Function>(NodeVector{broadcast1, broadcast2}, ParameterVector{A, B});
        pass::Manager pass_manager;

        pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
        pass_manager.run_passes(f);
        ASSERT_NE(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
                  f->get_results().at(1)->input_value(0).get_node_shared_ptr());
    }
}

TEST(CSE, abs_add_abs_add)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_a1 = std::make_shared<op::Abs>(A);
    auto abs_b1 = std::make_shared<op::Abs>(B);
    auto abs_a2 = std::make_shared<op::Abs>(A);
    auto abs_b2 = std::make_shared<op::Abs>(B);
    auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
    auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
    auto abs_add1 = std::make_shared<op::Abs>(add1);
    auto abs_add2 = std::make_shared<op::Abs>(add2);
    auto C = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto add3 = std::make_shared<op::Add>(abs_add1, C);
    auto add4 = std::make_shared<op::Add>(abs_add2, C);
    auto f = std::make_shared<Function>(NodeVector{add3, add4}, ParameterVector{A, B, C});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
              f->get_results().at(1)->input_value(0).get_node_shared_ptr());
}

TEST(CSE, abs_add_abs_add_negative)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_a1 = std::make_shared<op::Abs>(A);
    auto abs_b1 = std::make_shared<op::Abs>(B);
    auto abs_a2 = std::make_shared<op::Abs>(A);
    auto abs_b2 = std::make_shared<op::Abs>(B);
    auto add1 = std::make_shared<op::Add>(abs_a1, abs_b1);
    auto add2 = std::make_shared<op::Add>(abs_a2, abs_b2);
    auto abs_add1 = std::make_shared<op::Abs>(add1);
    auto abs_add2 = std::make_shared<op::Abs>(add2);
    auto C = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto D = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto add3 = std::make_shared<op::Add>(abs_add1, C);
    auto add4 = std::make_shared<op::Add>(abs_add2, D);
    auto f = std::make_shared<Function>(NodeVector{add3, add4}, ParameterVector{A, B, C, D});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    auto oadd3 = f->get_results().at(0)->input_value(0).get_node_shared_ptr();
    auto oadd4 = f->get_results().at(1)->input_value(0).get_node_shared_ptr();
    ASSERT_EQ(oadd3, add3);
    ASSERT_EQ(oadd4, add4);
    ASSERT_EQ(oadd3->input_value(1).get_node_shared_ptr(), C);
    ASSERT_EQ(oadd4->input_value(1).get_node_shared_ptr(), D);
    ASSERT_EQ(oadd3->input_value(0).get_node_shared_ptr(),
              oadd4->input_value(0).get_node_shared_ptr());
}

template <typename T>
static void execute_cse_reduction_test()
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});
    auto a_reduction_op = std::make_shared<T>(A, AxisSet{0, 1});
    auto a_reduction_op2 = std::make_shared<T>(A, AxisSet{0, 1});
    auto a_reduction_op3 = std::make_shared<T>(A, AxisSet{0});
    auto sub_aa = a_reduction_op - a_reduction_op2;

    auto B = std::make_shared<op::Parameter>(element::i32, Shape{3, 5});
    auto b_reduction_op = std::make_shared<T>(B, AxisSet{0, 1});

    auto sub_ab = a_reduction_op - b_reduction_op;
    auto f = std::make_shared<Function>(NodeVector{sub_aa, sub_ab, a_reduction_op3},
                                        ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);

    ASSERT_EQ(sub_aa->input_value(0).get_node_shared_ptr(),
              sub_aa->input_value(1).get_node_shared_ptr());
    ASSERT_NE(sub_ab->input_value(0).get_node_shared_ptr(),
              sub_ab->input_value(1).get_node_shared_ptr());
    ASSERT_NE(f->get_results().at(2)->input_value(0).get_node_shared_ptr(),
              sub_aa->input_value(0).get_node_shared_ptr());
}

TEST(CSE, reduction_ops)
{
    execute_cse_reduction_test<op::Sum>();
    execute_cse_reduction_test<op::Product>();
}

TEST(CSE, constant)
{
    Shape zero_shape{0};
    auto iconst0 = op::Constant::create(element::i32, Shape{}, {0});
    auto iconst0_1 = op::Constant::create(element::i32, Shape{}, {0});
    auto iconst1 = op::Constant::create(element::i32, Shape{}, {1});
    auto iconst1_1 = op::Constant::create(element::i32, Shape{}, {1});
    auto fconst0 = op::Constant::create(element::f32, Shape{}, {0});
    auto iconst111 = op::Constant::create(element::i32, Shape{3}, {1, 1, 1});
    auto iconst112 = op::Constant::create(element::i32, Shape{3}, {1, 1, 2});

    auto abs0 = std::make_shared<op::Abs>(iconst0);
    auto abs0_1 = std::make_shared<op::Abs>(iconst0_1);

    auto abs1 = std::make_shared<op::Abs>(iconst1);
    auto abs1_1 = std::make_shared<op::Abs>(iconst1_1);

    auto absf = std::make_shared<op::Abs>(fconst0);

    auto abs111 = std::make_shared<op::Abs>(iconst111);
    auto abs112 = std::make_shared<op::Abs>(iconst112);

    auto f = std::make_shared<Function>(
        NodeVector{abs0, abs0_1, abs1, abs1_1, absf, abs111, abs112}, ParameterVector{});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);

    ASSERT_EQ(abs0->input_value(0).get_node_shared_ptr(),
              abs0_1->input_value(0).get_node_shared_ptr());
    ASSERT_EQ(abs1->input_value(0).get_node_shared_ptr(),
              abs1_1->input_value(0).get_node_shared_ptr());
    ASSERT_NE(abs0->input_value(0).get_node_shared_ptr(),
              abs1->input_value(0).get_node_shared_ptr());
    ASSERT_NE(abs0->input_value(0).get_node_shared_ptr(),
              absf->input_value(0).get_node_shared_ptr());
    ASSERT_NE(abs111->input_value(0).get_node_shared_ptr(),
              abs112->input_value(0).get_node_shared_ptr());
}

TEST(CSE, one_hot)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<ngraph::pass::CommonSubexpressionElimination>();
    {
        Shape param_shape{8};
        Shape out_shape{8, 16};
        auto A = std::make_shared<op::Parameter>(element::i32, param_shape);
        auto onehot1 = std::make_shared<op::OneHot>(A, out_shape, 1);
        auto onehot2 = std::make_shared<op::OneHot>(A, out_shape, 1);
        auto f = std::make_shared<Function>(NodeVector{onehot1, onehot2}, ParameterVector{A});
        pass_manager.run_passes(f);
        ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
                  f->get_results().at(1)->input_value(0).get_node_shared_ptr());
    }
    {
        Shape param_shape{8, 1};
        Shape out_shape{8, 16};
        auto A = std::make_shared<op::Parameter>(element::i32, param_shape);
        auto reshape1 = std::make_shared<op::Reshape>(A, AxisVector{0, 1}, Shape{8});
        auto reshape2 = std::make_shared<op::Reshape>(A, AxisVector{0, 1}, Shape{8});
        auto onehot1 = std::make_shared<op::OneHot>(reshape1, out_shape, 1);
        auto onehot2 = std::make_shared<op::OneHot>(reshape2, out_shape, 1);
        auto f = std::make_shared<Function>(NodeVector{onehot1, onehot2}, ParameterVector{A});
        pass_manager.run_passes(f);
        ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(),
                  f->get_results().at(1)->input_value(0).get_node_shared_ptr());
    }
}

TEST(CSE, pass_property)
{
    auto pass = std::make_shared<ngraph::pass::CommonSubexpressionElimination>();
    ASSERT_TRUE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
    ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}
@ -1,52 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
#include "ngraph/op/add.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/constant_to_broadcast.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

TEST(pass, constant_to_broadcast)
{
    Shape shape{128, 256, 1, 1};
    vector<float> v = {3};
    auto c = make_shared<op::Constant>(element::f32, shape, v);
    auto f = make_shared<Function>(c, ParameterVector{});

    {
        ngraph::pass::Manager pm;
        pm.register_pass<pass::ConstantToBroadcast>();
        EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
        pm.run_passes(f);
        EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
    }
}
@ -23,11 +23,9 @@

#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "pass/liveness.hpp"

#include "util/test_tools.hpp"

@ -54,16 +54,3 @@ namespace
    bool run_on_function(std::shared_ptr<ngraph::Function> /* f */) override { return false; }
};
}

// Regression test: We've had an issue in the past where enabling per-pass validation and
// per-pass serialization at the same time causes a crash.
TEST(pass_manager, serialize_with_revalidate_does_not_crash)
{
    pass::Manager pass_manager;
    pass_manager.set_per_pass_validation(true);
    pass_manager.set_pass_serialization(true);
    shared_ptr<DummyPass> dummy = pass_manager.register_pass<DummyPass>();

    auto graph = make_test_graph();
    pass_manager.run_passes(graph);
}

@ -1,234 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/dump_sorted.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

static vector<pass::MemoryManager::node> get_node_list(const pass::MemoryManager& mm)
{
    vector<pass::MemoryManager::node> rc;
    rc.insert(rc.end(), mm.begin(), mm.end());
    return rc;
}

TEST(memory_manager, allocate)
{
    pass::MemoryManager mm{1};

    // Special case, allocating size zero bumps the size of the alloc up to the alignment size
    EXPECT_EQ(0, mm.allocate(0));
    EXPECT_EQ(1, mm.allocate(10));
    EXPECT_EQ(11, mm.allocate(10));
    EXPECT_EQ(21, mm.allocate(10));
}

TEST(memory_manager, free_first_allocated)
{
    pass::MemoryManager mm{1};

    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(3, mm.get_node_list().size());

    mm.free(0);

    auto node_list = get_node_list(mm);
    EXPECT_EQ(3, node_list.size());
    EXPECT_TRUE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_TRUE(node_list[2].is_free());
}

TEST(memory_manager, free_middle_allocated)
{
    pass::MemoryManager mm{1};

    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());

    mm.free(10);

    auto node_list = get_node_list(mm);
    EXPECT_EQ(6, node_list.size());
    EXPECT_FALSE(node_list[0].is_free());
    EXPECT_TRUE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
    EXPECT_FALSE(node_list[3].is_free());
    EXPECT_FALSE(node_list[4].is_free());
}

TEST(memory_manager, free_last_allocated)
{
    pass::MemoryManager mm{1};

    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());

    mm.free(40);

    auto node_list = get_node_list(mm);
    EXPECT_EQ(5, node_list.size());
    EXPECT_FALSE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
    EXPECT_FALSE(node_list[3].is_free());
    EXPECT_TRUE(node_list[4].is_free());
}

TEST(memory_manager, free_first_free)
{
    pass::MemoryManager mm{1};

    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());

    mm.free(10);
    mm.free(0);

    auto node_list = get_node_list(mm);
    EXPECT_EQ(5, node_list.size());
    EXPECT_TRUE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
    EXPECT_FALSE(node_list[3].is_free());
}

TEST(memory_manager, free_middle_free)
{
    pass::MemoryManager mm{1};

    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());

    mm.free(0);
    mm.free(20);
    mm.free(10);

    auto node_list = get_node_list(mm);
    EXPECT_EQ(4, node_list.size());
    EXPECT_TRUE(node_list[0].is_free());
    EXPECT_FALSE(node_list[1].is_free());
    EXPECT_FALSE(node_list[2].is_free());
}

TEST(memory_manager, max_allocated)
{
    pass::MemoryManager mm{1};

    EXPECT_EQ(0, mm.allocate(10));
    EXPECT_EQ(10, mm.allocate(10));
    EXPECT_EQ(20, mm.allocate(10));
    EXPECT_EQ(30, mm.allocate(10));
    EXPECT_EQ(40, mm.allocate(10));
    EXPECT_EQ(6, mm.get_node_list().size());

    mm.free(0);
    mm.free(20);
    mm.free(10);

    EXPECT_EQ(mm.max_allocated(), 50);
}

TEST(memory_manager, bad_free)
{
    pass::MemoryManager mm{1};

    EXPECT_THROW(mm.free(10), std::runtime_error);
}
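
// MemoryManager::align(n, alignment) rounds n up to the next multiple of the
// alignment; note that n == 0 is bumped up to one full alignment unit rather
// than staying at zero.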
TEST(memory_manager, align)
{
    EXPECT_EQ(8, pass::MemoryManager::align(0, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(1, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(2, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(3, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(4, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(5, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(6, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(7, 8));
    EXPECT_EQ(8, pass::MemoryManager::align(8, 8));
    EXPECT_EQ(16, pass::MemoryManager::align(9, 8));
}

TEST(memory_manager, memory_align)
{
    pass::MemoryManager mm{64};

    EXPECT_EQ(0, mm.allocate(4));
    EXPECT_EQ(64, mm.allocate(4));
    EXPECT_EQ(128, mm.allocate(4));
}

TEST(memory_layout, basic)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::Liveness>();
    pass_manager.register_pass<pass::MemoryLayout>();

    auto graph = make_test_graph();
    pass_manager.run_passes(graph);
    auto sorted = graph->get_ordered_ops();
    size_t temporary_pool_size = graph->get_temporary_pool_size();
    EXPECT_EQ(12, temporary_pool_size);
}

TEST(memory_layout, constant)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::Liveness>();
    pass_manager.register_pass<pass::MemoryLayout>();

    Shape shape{1};
    auto c = op::Constant::create(element::i32, shape, {5});
    auto f = make_shared<Function>(make_shared<op::Negative>(c), ParameterVector{});

    pass_manager.run_passes(f);
    auto sorted = f->get_ordered_ops();
    size_t temporary_pool_size = f->get_temporary_pool_size();
    EXPECT_EQ(4, temporary_pool_size);
}
@ -23,7 +23,7 @@

#include "ngraph/ngraph.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/shape_relevance.hpp"
#include "pass/shape_relevance.hpp"

using namespace ngraph;
using namespace std;

@ -25,11 +25,11 @@
#include "ngraph/builder/norm.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/provenance.hpp"
#include "opset0_downgrade.hpp"
#include "opset1_upgrade.hpp"
#include "pass/fused_op_decomposition.hpp"
#include "util/provenance_enabler.hpp"

using namespace std;

@ -1,497 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/pattern/op/skip.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/matcher.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

#ifndef NGRAPH_JSON_DISABLE
TEST(reshape_elimination, remove_reshape)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_fprop.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    size_t count_before = count_ops_of_type<op::Reshape>(func);
    pass_manager.run_passes(func);
    size_t count_after = count_ops_of_type<op::Reshape>(func);
    ASSERT_TRUE(count_after < count_before);
}

TEST(reshape_elimination, remove_transpose)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/tranpose.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    size_t count_before = count_ops_of_type<op::Reshape>(func);
    pass_manager.run_passes(func);
    size_t count_after = count_ops_of_type<op::Reshape>(func);
    ASSERT_TRUE(count_after < count_before);
}

TEST(reshape_elimination, bn_bprop_rewrite)
{
    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    const string json_path = file_util::path_join(SERIALIZED_ZOO, "mxnet/bn_bprop.json");
    const string json_string = file_util::read_file_to_string(json_path);
    stringstream ss(json_string);
    shared_ptr<Function> func = ngraph::deserialize(ss);
    size_t count_before = count_ops_of_type<op::Reshape>(func);
    pass_manager.run_passes(func);
    size_t count_after = count_ops_of_type<op::Reshape>(func);
    ASSERT_TRUE(count_after < count_before);
}
#endif

#ifdef NGRAPH_INTERPRETER_ENABLE
TEST(reshape_elimination, transpose_reshape_pattern_fuse)
{
    auto generate_func = []() {
        auto input = make_shared<op::Parameter>(element::f32, Shape{8, 2, 4, 6});
        auto transpose = make_shared<op::Reshape>(input, AxisVector{0, 2, 1, 3}, Shape{8, 2, 4, 6});
        auto reshape =
            make_shared<op::Reshape>(transpose, AxisVector{0, 1, 2, 3}, Shape{8, 4, 2, 6});
        return make_shared<Function>(reshape, ParameterVector{input});
    };

    auto fuse_func = generate_func();
    auto nofuse_func = generate_func();

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.run_passes(fuse_func);
    ASSERT_TRUE(count_ops_of_type<op::Reshape>(fuse_func) == 1);
    ASSERT_TRUE(count_ops_of_type<op::Reshape>(nofuse_func) == 2);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(Shape{8, 2, 4, 6}));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(fuse_func, args, "INTERPRETER");
    auto optimized_results = execute(nofuse_func, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
}
#endif

TEST(reshape_elimination, transpose_reshape_pattern_nofuse)
{
    auto input = make_shared<op::Parameter>(element::f32, Shape{8, 2, 4, 6});
    auto transpose = make_shared<op::Reshape>(input, AxisVector{0, 2, 1, 3}, Shape{8, 2, 4, 6});
    auto reshape = make_shared<op::Reshape>(transpose, AxisVector{2, 1, 0, 3}, Shape{8, 4, 2, 6});
    auto f = make_shared<Function>(reshape, ParameterVector{input});

    pass::Manager pass_manager;
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.run_passes(f);
    ASSERT_TRUE(count_ops_of_type<op::Reshape>(f) == 2);
}
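
// The rewrite checked below relies on the identity (W·x)^T == x^T·W^T: the
// transpose-style Reshape after the Dot is folded into transposes of the
// Dot's arguments instead.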
|
||||
|
||||
TEST(reshape_elimination, dot_transpose_to_dot_w_transpose_args)
|
||||
{
|
||||
Shape shape_w{2, 4};
|
||||
Shape shape_x{4, 1};
|
||||
auto W = make_shared<op::Parameter>(element::f32, shape_w);
|
||||
auto x = make_shared<op::Parameter>(element::f32, shape_x);
|
||||
|
||||
auto dot = make_shared<op::Dot>(W, x);
|
||||
auto reshape_dot = std::make_shared<op::Reshape>(dot, AxisVector{1, 0}, Shape{1, 2});
|
||||
auto graph = make_shared<op::Abs>(reshape_dot);
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::ReshapeElimination>();
|
||||
auto func = make_shared<Function>(graph, ParameterVector{W, x});
|
||||
pass_manager.run_passes(func);
|
||||
auto gdot = graph->input_value(0).get_node_shared_ptr();
|
||||
ASSERT_TRUE(as_type_ptr<op::Dot>(gdot));
|
||||
ASSERT_TRUE(as_type_ptr<op::Reshape>(gdot->input_value(0).get_node_shared_ptr()));
|
||||
ASSERT_TRUE(as_type_ptr<op::Reshape>(gdot->input_value(1).get_node_shared_ptr()));
|
||||
ASSERT_EQ(gdot->input_value(0).get_node_shared_ptr()->input_value(0).get_node_shared_ptr(), x);
|
||||
ASSERT_EQ(gdot->input_value(1).get_node_shared_ptr()->input_value(0).get_node_shared_ptr(), W);
|
||||
ASSERT_EQ(gdot->get_shape(), (Shape{1, 2}));
|
||||
}
|
||||
|
||||
#ifdef NGRAPH_INTERPRETER_ENABLE
|
||||
TEST(reshape_elimination, recurrent_reshapes)
|
||||
{
|
||||
Shape shape_a{2, 2, 3, 3, 2, 4};
|
||||
auto generate_func = [shape_a]() {
|
||||
auto A = make_shared<op::Parameter>(element::f32, shape_a);
|
||||
Shape shape_r_1{3, 2, 2, 4, 6};
|
||||
Shape shape_r_2{6, 8, 3, 2};
|
||||
Shape shape_r_3{6, 8, 6};
|
||||
Shape shape_r_4{6, 2, 2, 2, 6};
|
||||
Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
|
||||
Shape shape_r_6{48, 6};
|
||||
|
||||
auto r_1 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_1);
|
||||
auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
|
||||
auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
|
||||
auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{0, 1, 2}, shape_r_4);
|
||||
auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
|
||||
auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);
|
||||
|
||||
auto f = make_shared<Function>(r_6, ParameterVector{A});
|
||||
return f;
|
||||
};
|
||||
|
||||
auto baseline_f = generate_func();
|
||||
auto optimized_f = generate_func();
|
||||
auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();
|
||||
|
||||
pass::Manager pass_manager;
|
||||
// pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes.png");
|
||||
pass_manager.register_pass<pass::RecurrentReshapeElimination>();
|
||||
// pass_manager.register_pass<pass::VisualizeTree>("after_recurrent_reshapes.png");
|
||||
pass_manager.run_passes(optimized_f);
|
||||
|
||||
test::Uniform<float> rng(0.0f, 100.0f);
|
||||
vector<vector<float>> args;
|
||||
vector<float> tensor_val(shape_size(baseline_input_shape));
|
||||
rng.initialize(tensor_val);
|
||||
args.push_back(tensor_val);
|
||||
|
||||
auto baseline_results = execute(baseline_f, args, "INTERPRETER");
|
||||
auto optimized_results = execute(optimized_f, args, "INTERPRETER");
|
||||
|
||||
EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));
|
||||
|
||||
size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
|
||||
ASSERT_EQ(num_reshapes_optimized, 1);
|
||||
}
|
||||
|
||||
TEST(reshape_elimination, recurrent_reshapes_elimination)
{
    Shape shape_a{2, 2, 3, 3, 2, 4};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        Shape shape_r_1{3, 2, 2, 4, 6};
        Shape shape_r_2{6, 8, 3, 2};
        Shape shape_r_3{6, 8, 6};
        Shape shape_r_4{6, 2, 2, 2, 6};
        Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
        Shape shape_r_6{48, 6};
        Shape shape_r_7{2, 2, 3, 3, 2, 4};

        auto r_1 = make_shared<op::Reshape>(A, AxisVector{0, 1, 2, 3, 4, 5}, shape_r_1);
        auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
        auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
        auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{0, 1, 2}, shape_r_4);
        auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
        auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);
        auto r_7 = make_shared<op::Reshape>(r_6, AxisVector{0, 1}, shape_r_7);
        auto f = make_shared<Function>(r_7, ParameterVector{A});
        return f;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    // pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes_elimination.png");
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>("after_1_recurrent_reshapes_elimination.png");
    pass_manager.register_pass<pass::ReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>("after_2_recurrent_reshapes_elimination.png");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 0);
}

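// reshape_2 fans out (it is both a function result and the input of reshape_3),
// so the chain cannot fuse past it and two reshapes survive.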
TEST(reshape_elimination, recurrent_reshapes_fan_out)
{
    Shape shape_a{4, 6, 10, 2};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        Shape shape_r_1{6, 4, 5, 4};
        Shape shape_r_2{24, 20};
        auto reshape_1 = make_shared<op::Reshape>(A, AxisVector{0, 3, 2, 1}, shape_r_1);
        auto reshape_2 = make_shared<op::Reshape>(reshape_1, AxisVector{0, 1, 2, 3}, shape_r_2);
        auto reshape_3 = make_shared<op::Reshape>(reshape_2, AxisVector{0, 1}, shape_a);
        auto f_ = make_shared<Function>(NodeVector{reshape_2, reshape_3}, ParameterVector{A});
        return f_;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    // pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes_fan_out.png");
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>("after_recurrent_reshapes_fan_out.png");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 2);
}

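// The fan-out sits at the end of the chain, so the reshapes above it still
// collapse into a single reshape.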
TEST(reshape_elimination, recurrent_reshapes_fan_out_at_end)
{
    Shape shape_a{12, 8, 1, 1};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);

        auto reshape_1 = make_shared<op::Reshape>(A, AxisVector{0, 3, 2, 1}, Shape{4, 3, 8, 1});
        auto reshape_2 = make_shared<op::Reshape>(reshape_1, AxisVector{0, 1, 2, 3}, shape_a);
        auto reshape_3 =
            make_shared<op::Reshape>(reshape_2, AxisVector{0, 1, 2, 3}, Shape{4, 3, 8, 1});
        auto abs_1 = make_shared<op::Abs>(reshape_3);
        auto f_ = make_shared<Function>(NodeVector{abs_1, reshape_3}, ParameterVector{A});
        return f_;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    // pass_manager.register_pass<pass::VisualizeTree>("before_recurrent_reshapes_fan_out_at_end.png");
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>("after_recurrent_reshapes_fan_out_at_end.png");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 1);
}

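// r_4 permutes axes here, splitting the chain into two fusable sub-chains;
// each fuses separately, leaving two reshapes.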
TEST(reshape_elimination, recurrent_reshapes_multiple_fusions)
{
    Shape shape_a{2, 2, 3, 3, 2, 4};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        Shape shape_r_1{3, 2, 2, 4, 6};
        Shape shape_r_2{6, 8, 3, 2};
        Shape shape_r_3{6, 8, 6};
        Shape shape_r_4{6, 2, 2, 2, 6};
        Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
        Shape shape_r_6{48, 6};

        auto r_1 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_1);
        auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
        auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
        auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{1, 0, 2}, shape_r_4);
        auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
        auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);

        auto f = make_shared<Function>(r_6, ParameterVector{A});
        return f;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    // pass_manager.register_pass<pass::VisualizeTree>(
    //     "before_recurrent_reshapes_multiple_fusions.png");
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>(
    //     "after_recurrent_reshapes_multiple_fusions.png");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 2);
}

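// Abs ops between the reshapes break the chain, so nothing fuses and all three
// reshapes remain.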
TEST(reshape_elimination, nonrecurrent_reshapes)
{
    Shape shape_a{8, 6, 1, 1};
    Shape shape_r{2, 24};
    auto generate_func = [shape_a, shape_r]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);

        auto reshape_1 = make_shared<op::Reshape>(A, AxisVector{3, 0, 2, 1}, shape_r);
        auto abs_1 = make_shared<op::Abs>(reshape_1);
        auto reshape_2 = make_shared<op::Reshape>(abs_1, AxisVector{0, 1}, shape_a);
        auto abs_2 = make_shared<op::Abs>(reshape_2);
        auto reshape_3 = make_shared<op::Reshape>(abs_2, AxisVector{0, 1, 2, 3}, shape_a);
        auto f_ = make_shared<Function>(NodeVector{reshape_3}, ParameterVector{A});
        return f_;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    // pass_manager.register_pass<pass::VisualizeTree>("before_nonrecurrent_reshapes.png");
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>("after_nonrecurrent_reshapes.png");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 3);
}

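// Two independent reshape chains start from A; each fuses into a single reshape.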
TEST(reshape_elimination, recurrent_reshapes_multiple_branches)
{
    Shape shape_a{2, 2, 3, 3, 2, 4};
    auto generate_func = [shape_a]() {
        auto A = make_shared<op::Parameter>(element::f32, shape_a);
        Shape shape_r_1{3, 2, 2, 4, 6};
        Shape shape_r_2{6, 8, 3, 2};
        Shape shape_r_3{6, 8, 6};
        Shape shape_r_4{6, 2, 2, 2, 6};
        Shape shape_r_5{2, 3, 2, 2, 2, 3, 2};
        Shape shape_r_6{48, 6};

        auto r_1 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_1);
        auto r_2 = make_shared<op::Reshape>(r_1, AxisVector{0, 1, 2, 3, 4}, shape_r_2);
        auto r_3 = make_shared<op::Reshape>(r_2, AxisVector{0, 1, 2, 3}, shape_r_3);
        auto r_4 = make_shared<op::Reshape>(r_3, AxisVector{0, 1, 2}, shape_r_4);
        auto r_5 = make_shared<op::Reshape>(r_4, AxisVector{0, 1, 2, 3, 4}, shape_r_5);
        auto r_6 = make_shared<op::Reshape>(r_5, AxisVector{0, 1, 2, 3, 4, 5, 6}, shape_r_6);

        auto r_7 = make_shared<op::Reshape>(A, AxisVector{2, 4, 0, 5, 3, 1}, shape_r_2);
        auto r_8 = make_shared<op::Reshape>(r_7, AxisVector{0, 1, 2, 3}, shape_r_3);

        auto f = make_shared<Function>(NodeVector{r_6, r_8}, ParameterVector{A});
        return f;
    };

    auto baseline_f = generate_func();
    auto optimized_f = generate_func();
    auto baseline_input_shape = baseline_f->get_parameters().at(0)->get_shape();

    pass::Manager pass_manager;
    // pass_manager.register_pass<pass::VisualizeTree>(
    //     "before_recurrent_reshapes_multiple_branches.png");
    pass_manager.register_pass<pass::RecurrentReshapeElimination>();
    // pass_manager.register_pass<pass::VisualizeTree>(
    //     "after_recurrent_reshapes_multiple_branches.png");
    pass_manager.run_passes(optimized_f);

    test::Uniform<float> rng(0.0f, 100.0f);
    vector<vector<float>> args;
    vector<float> tensor_val(shape_size(baseline_input_shape));
    rng.initialize(tensor_val);
    args.push_back(tensor_val);

    auto baseline_results = execute(baseline_f, args, "INTERPRETER");
    auto optimized_results = execute(optimized_f, args, "INTERPRETER");

    EXPECT_TRUE(test::all_close(baseline_results.at(0), optimized_results.at(0)));

    size_t num_reshapes_optimized = count_ops_of_type<op::Reshape>(optimized_f);
    ASSERT_EQ(num_reshapes_optimized, 2);
}

#endif

TEST(reshape_elimination, pass_property)
{
    {
        auto pass = std::make_shared<ngraph::pass::ReshapeElimination>();
        ASSERT_FALSE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
        ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
    }
    {
        auto pass = std::make_shared<ngraph::pass::RecurrentReshapeElimination>();
        ASSERT_FALSE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
        ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
    }
}

@ -1,184 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <list>
#include <memory>

#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/batch_norm.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/cse.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/reshape_elimination.hpp"
#include "ngraph/pass/reshape_sinking.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "util/all_close.hpp"
#include "util/ndarray.hpp"
#include "util/random.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

TEST(reshape_sinking, edge_splitting)
{
    // checks if Reshapes are pushed through op::Abs, but stopped by Sum
    Shape shape_nhwc{16, 28, 28, 1};
    Shape shape_nchw{16, 1, 28, 28};
    auto a = make_shared<op::Parameter>(element::i32, shape_nhwc);
    auto reshape = make_shared<op::Reshape>(a, AxisVector{0, 3, 1, 2}, shape_nchw);
    auto absn = make_shared<op::Abs>(reshape);
    auto absn2 = make_shared<op::Abs>(absn);
    auto sum = make_shared<op::Sum>(reshape, AxisSet{0, 1, 2, 3});
    auto func = make_shared<Function>(NodeVector{absn2, sum}, ParameterVector{a});
    pass::Manager pass_manager;
    // size_t before_count = count_ops_of_type<op::Reshape>(func);
    pass_manager.register_pass<pass::ReshapeSinking>();
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.register_pass<pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(func);
    ASSERT_EQ(func->get_results().at(1)->input_value(0).get_node_shared_ptr(), sum);
    auto new_reshape =
        as_type_ptr<op::Reshape>(func->get_results().at(0)->input_value(0).get_node_shared_ptr());
    ASSERT_TRUE(new_reshape);
    ASSERT_EQ(new_reshape->get_shape(), shape_nchw);
}

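// After sinking, the add should operate in nchw and consume the convolution
// output directly.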
TEST(reshape_sinking, broadcast_swimming)
{
    Shape shape_nchw{1, 32, 536, 536};
    Shape shape_nhwc{1, 536, 536, 32};
    Shape shape_weights{16, 32, 3, 3};
    Shape conv_nhwc{1, 534, 534, 16};
    Shape conv_nchw{1, 16, 534, 534};
    AxisVector to_nhwc{0, 2, 3, 1};
    AxisVector to_nchw{0, 3, 1, 2};

    size_t channel = 16;
    auto bias = make_shared<op::Parameter>(element::i32, Shape{channel});
    auto bias_reshape = make_shared<op::Reshape>(bias, AxisVector{0}, Shape{1, channel});
    auto bias_broadcast = make_shared<op::Broadcast>(bias_reshape, conv_nhwc, AxisSet{1, 2});

    auto input = make_shared<op::Parameter>(element::i32, shape_nhwc);
    auto reshape_input = make_shared<op::Reshape>(input, to_nchw, shape_nchw);

    auto weights = make_shared<op::Parameter>(element::i32, shape_weights);
    auto conv = make_shared<op::Convolution>(reshape_input, weights);
    auto conv_reshape = make_shared<op::Reshape>(conv, to_nhwc, conv_nhwc);
    auto add = bias_broadcast + conv_reshape;
    auto relu = make_shared<op::Relu>(add);

    auto func = make_shared<Function>(NodeVector{relu}, ParameterVector{bias, input, weights});
    pass::Manager pass_manager;

    pass_manager.register_pass<pass::ReshapeSinking>();
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.register_pass<pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(func);

    ASSERT_EQ(add->get_shape(), conv_nchw);
    ASSERT_EQ(add->get_input_shape(0), conv_nchw);
    ASSERT_EQ(add->input_value(1).get_node_shared_ptr(), conv);
}

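// Sinking reshapes through two convolution branches and a concat must not
// increase the overall reshape count.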
TEST(reshape_sinking, concat)
{
    Shape shape{};
    Shape shape_w{1, 1, 1, 1};
    Shape shape_x{1, 3, 3, 1};
    Shape shape_b{1, 3, 3, 1};
    Shape r_shape{1, 3, 3, 2};

    auto B_ = op::Constant::create(element::f32, shape_w, {3});
    auto B = make_shared<op::Reshape>(B_, AxisVector{3, 2, 0, 1}, Shape{1, 1, 1, 1}); /* nchw */
    auto A_ = make_shared<op::Parameter>(element::f32, shape_x);
    auto A = make_shared<op::Reshape>(A_, AxisVector{0, 3, 1, 2}, Shape{1, 1, 3, 3}); /* nchw */
    auto C = op::Constant::create(element::f32, Shape{1}, {2});
    auto R = make_shared<op::Parameter>(element::f32, r_shape);

    auto conv = make_shared<op::Convolution>(A,
                                             B,
                                             Strides{1, 1},
                                             Strides{1, 1},
                                             CoordinateDiff{0, 0},
                                             CoordinateDiff{0, 0},
                                             Strides{1, 1});
    auto reshape_conv =
        make_shared<op::Reshape>(conv, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1}); /* nhwc */
    auto broadcast = make_shared<op::Broadcast>(C, reshape_conv->get_shape(), AxisSet{0, 1, 2});
    auto add = broadcast + reshape_conv;

    auto B1_ = op::Constant::create(element::f32, shape_w, {3});
    auto B1 = make_shared<op::Reshape>(B1_, AxisVector{3, 2, 0, 1}, Shape{1, 1, 1, 1});
    auto A1_ = make_shared<op::Parameter>(element::f32, shape_x);
    auto A1 = make_shared<op::Reshape>(A1_, AxisVector{0, 3, 1, 2}, Shape{1, 1, 3, 3});
    auto C1 = op::Constant::create(element::f32, Shape{1}, {2});
    auto R1 = make_shared<op::Parameter>(element::f32, r_shape);

    auto conv1 = make_shared<op::Convolution>(A1,
                                              B1,
                                              Strides{1, 1},
                                              Strides{1, 1},
                                              CoordinateDiff{0, 0},
                                              CoordinateDiff{0, 0},
                                              Strides{1, 1});
    auto reshape_conv1 = make_shared<op::Reshape>(conv1, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1});
    auto broadcast1 = make_shared<op::Broadcast>(C1, reshape_conv->get_shape(), AxisSet{0, 1, 2});
    auto add1 = broadcast1 + reshape_conv1;

    auto concat = make_shared<op::Concat>(NodeVector{add, add1}, 3);
    auto relu = make_shared<op::Relu>(concat);
    auto reshape_relu =
        make_shared<op::Reshape>(relu, AxisVector{0, 3, 1, 2}, Shape{1, 2, 3, 3}); /* nchw */
    auto B2_ = op::Constant::create(element::f32, Shape{1, 1, 2, 1}, {2});
    auto B2 = make_shared<op::Reshape>(B2_, AxisVector{3, 2, 0, 1}, Shape{1, 2, 1, 1});
    auto conv2 = make_shared<op::Convolution>(reshape_relu,
                                              B2,
                                              Strides{1, 1},
                                              Strides{1, 1},
                                              CoordinateDiff{0, 0},
                                              CoordinateDiff{0, 0},
                                              Strides{1, 1});
    auto reshape_conv2 =
        make_shared<op::Reshape>(conv2, AxisVector{0, 2, 3, 1}, Shape{1, 3, 3, 1}); /* nhwc */
    auto f = make_shared<Function>(reshape_conv2, ParameterVector{A_, A1_});
    pass::Manager pass_manager;
    size_t before_count = count_ops_of_type<op::Reshape>(f);
    pass_manager.register_pass<pass::ReshapeSinking>();
    pass_manager.register_pass<pass::ReshapeElimination>();
    pass_manager.register_pass<pass::CommonSubexpressionElimination>();
    pass_manager.run_passes(f);
    size_t after_count = count_ops_of_type<op::Reshape>(f);
    ASSERT_LE(after_count, before_count);
}

TEST(reshape_sinking, pass_property)
{
    auto pass = std::make_shared<ngraph::pass::ReshapeSinking>();
    ASSERT_TRUE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
    ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}

@ -33,6 +33,18 @@ set (SRC
dynamic/dynamic_backend.hpp
op/avg_pool.cpp
op/avg_pool.hpp
pass/dyn_elimination.cpp
pass/dyn_elimination.hpp
pass/fused_op_decomposition.cpp
pass/fused_op_decomposition.hpp
pass/implicit_broadcast_elimination.cpp
pass/implicit_broadcast_elimination.hpp
pass/like_replacement.cpp
pass/like_replacement.hpp
pass/liveness.cpp
pass/liveness.hpp
pass/shape_relevance.cpp
pass/shape_relevance.hpp
)

add_library(ngraph_backend SHARED ${SRC})

@ -23,13 +23,13 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/dyn_elimination.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/specialize_function.hpp"
#include "ngraph/util.hpp"
#include "opset0_downgrade.hpp"
#include "opset1_downgrade.hpp"
#include "pass/dyn_elimination.hpp"
#include "pass/shape_relevance.hpp"

using namespace std;
using namespace ngraph;

@ -22,16 +22,14 @@
#include "ngraph/except.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/pass/like_replacement.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/serializer.hpp"
#include "ngraph/util.hpp"
#include "opset0_downgrade.hpp"
#include "opset1_downgrade.hpp"
#include "pass/fused_op_decomposition.hpp"
#include "pass/like_replacement.hpp"
#include "pass/liveness.hpp"

using namespace std;
using namespace ngraph;

@ -26,13 +26,13 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/provenance.hpp"
#include "ngraph/slice_plan.hpp"
#include "ngraph/type.hpp"
#include "ngraph/validation_util.hpp"
#include "op/avg_pool.hpp"
#include "opset0_downgrade.hpp"
#include "pass/implicit_broadcast_elimination.hpp"

using namespace std;
using namespace ngraph;

@ -16,6 +16,7 @@

#pragma once

#include "backend_visibility.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/util.hpp"

@ -50,7 +51,7 @@ namespace ngraph
        /// <td> \image html dyn_broadcast_post_dyneliminate.svg </td>
        /// </tr>
        /// </table>
        class NGRAPH_API DynElimination : public GraphRewrite
        class BACKEND_API DynElimination : public GraphRewrite
        {
        public:
            DynElimination();

@ -13,7 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/op_types.hpp"

@ -18,6 +18,7 @@

#include <memory>

#include "backend_visibility.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/pass/pass.hpp"

@ -50,7 +51,7 @@
        /// <td> \image html decompose_gelu_post.svg </td>
        /// </tr>
        /// </table>
        class NGRAPH_API FusedOpDecomposition : public NodePass
        class BACKEND_API FusedOpDecomposition : public NodePass
        {
        public:
            /// \brief Function signature type for callback used to check whether provided node

@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************

#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "implicit_broadcast_elimination.hpp"

#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/graph_util.hpp"

@ -16,6 +16,7 @@

#pragma once

#include "backend_visibility.hpp"
#include "ngraph/node.hpp"
#include "ngraph/pass/pass.hpp"

@ -28,7 +29,7 @@ namespace ngraph
    }
}

class NGRAPH_API ngraph::pass::ImplicitBroadcastElimination : public ngraph::pass::NodePass
class BACKEND_API ngraph::pass::ImplicitBroadcastElimination : public ngraph::pass::NodePass
{
public:
    bool run_on_node(std::shared_ptr<ngraph::Node> node) override;

@ -16,13 +16,14 @@

#pragma once

#include "backend_visibility.hpp"
#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class NGRAPH_API LikeReplacement : public FunctionPass
        class BACKEND_API LikeReplacement : public FunctionPass
        {
        public:
            bool run_on_function(std::shared_ptr<ngraph::Function> function) override;

@ -18,6 +18,7 @@
#include <sstream>
#include <unordered_set>

#include "liveness.hpp"
#include "ngraph/descriptor/input.hpp"
#include "ngraph/descriptor/output.hpp"
#include "ngraph/function.hpp"

@ -27,7 +28,6 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/util.hpp"

using namespace std;

@ -16,6 +16,7 @@

#pragma once

#include "backend_visibility.hpp"
#include "ngraph/descriptor/tensor.hpp"
#include "ngraph/pass/pass.hpp"

@ -27,7 +28,7 @@ namespace ngraph
    }
}

class NGRAPH_API ngraph::pass::Liveness : public FunctionPass
class BACKEND_API ngraph::pass::Liveness : public FunctionPass
{
public:
    bool run_on_function(std::shared_ptr<ngraph::Function>) override;

@ -14,7 +14,7 @@
// limitations under the License.
//*****************************************************************************

#include "ngraph/pass/shape_relevance.hpp"
#include "pass/shape_relevance.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"

@ -16,13 +16,14 @@

#pragma once

#include "backend_visibility.hpp"
#include "ngraph/pass/pass.hpp"

namespace ngraph
{
    namespace pass
    {
        class NGRAPH_API ShapeRelevance : public FunctionPass
        class BACKEND_API ShapeRelevance : public FunctionPass
        {
        public:
            ShapeRelevance()

@ -23,8 +23,8 @@
#include "gtest/gtest.h"
#include "ngraph/function.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "pass/liveness.hpp"
#include "util/test_tools.hpp"

using namespace std;

@ -1,166 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <memory>

#include "gtest/gtest.h"
#include "ngraph/file_util.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/divide.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/product.hpp"
#include "ngraph/op/sqrt.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/pass/zero_dim_tensor_elimination.hpp"
#include "util/test_tools.hpp"

using namespace ngraph;
using namespace std;

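// Reductions over zero-element tensors should be eliminated from the graph;
// the next four tests cover Sum, Product, Min, and Max.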
TEST(zero_dim_tensor_elimination, zero_sum)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_node = std::make_shared<op::Abs>(A);
    auto sum_node = std::make_shared<op::Sum>(abs_node, AxisSet{0});
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f = std::make_shared<Function>(NodeVector{sum_node, constant}, ParameterVector{A});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Sum>(f), 1);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Sum>(f), 0);
}

TEST(zero_dim_tensor_elimination, zero_product)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_node = std::make_shared<op::Abs>(A);
    auto product_node = std::make_shared<op::Product>(abs_node, AxisSet{0});
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f = std::make_shared<Function>(NodeVector{product_node, constant}, ParameterVector{A});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Product>(f), 1);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Product>(f), 0);
}

TEST(zero_dim_tensor_elimination, zero_min)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_node = std::make_shared<op::Abs>(A);
    auto min_node = std::make_shared<op::Min>(abs_node, AxisSet{0});
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f = std::make_shared<Function>(NodeVector{min_node, constant}, ParameterVector{A});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Min>(f), 1);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Min>(f), 0);
}

TEST(zero_dim_tensor_elimination, zero_max)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::i32, zero_shape);
    auto abs_node = std::make_shared<op::Abs>(A);
    auto max_node = std::make_shared<op::Max>(abs_node, AxisSet{0});
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f = std::make_shared<Function>(NodeVector{max_node, constant}, ParameterVector{A});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Max>(f), 1);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Max>(f), 0);
}

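// A convolution whose data input has no elements should also be eliminated.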
TEST(zero_dim_tensor_elimination, zero_const_conv)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 0});
    auto weights = std::make_shared<op::Parameter>(element::f32, Shape{1, 1, 4});
    auto convolution = std::make_shared<op::Convolution>(
        A, weights, Strides{1}, Strides{1}, CoordinateDiff{2}, CoordinateDiff{2});
    auto abs_node = std::make_shared<op::Abs>(convolution);
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f =
        std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A, weights});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Convolution>(f), 1);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Convolution>(f), 0);
}

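// Padding a zero-element tensor reduces to broadcasting the pad value.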
TEST(zero_dim_tensor_elimination, zero_const_pad)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::f32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{});

    auto pad = std::make_shared<op::Pad>(A, B, CoordinateDiff{2}, CoordinateDiff{2});
    auto abs_node = std::make_shared<op::Abs>(pad);
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
}

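// The zero-length Slice is never reachable from the function outputs; only the
// Pad is rewritten to a Broadcast.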
TEST(zero_dim_tensor_elimination, zero_const_slice)
{
    Shape zero_shape{0};
    auto A = std::make_shared<op::Parameter>(element::f32, zero_shape);
    auto B = std::make_shared<op::Parameter>(element::f32, Shape{});
    auto slice = make_shared<op::Slice>(A, Coordinate{0}, Coordinate{0});
    auto pad = std::make_shared<op::Pad>(A, B, CoordinateDiff{2}, CoordinateDiff{2});
    auto abs_node = std::make_shared<op::Abs>(pad);
    auto constant = std::make_shared<op::Constant>(element::i32, zero_shape, std::vector<string>{});
    auto f = std::make_shared<Function>(NodeVector{abs_node, constant}, ParameterVector{A, B});
    pass::Manager pass_manager;

    pass_manager.register_pass<ngraph::pass::ZeroDimTensorElimination>();
    EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 0);
    EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
    pass_manager.run_passes(f);
    EXPECT_EQ(count_ops_of_type<op::Broadcast>(f), 1);
    EXPECT_EQ(count_ops_of_type<op::Slice>(f), 0);
}

TEST(zero_dim_tensor_elimination, pass_property)
{
    auto pass = std::make_shared<ngraph::pass::ZeroDimTensorElimination>();
    ASSERT_TRUE(pass->get_property(pass::PassProperty::REQUIRE_STATIC_SHAPE));
    ASSERT_FALSE(pass->get_property(pass::PassProperty::CHANGE_DYNAMIC_STATE));
}