Adding sinks to ngraph::Function (#2741)
* initial draft of adding sinks to ngraph::Function * style fixes * code style fixes * code style fixes * code style fix * review fix+build fix * code style fix * fix build * API changed according to latest discussion * review fixes * review fixes + tests * added 1 more ctor * style fixes * used new api in ir parser * fixed build * review fixes * remove validate_nodes_and_infer_types from remove_sink/remove_result * removed validate_.. after discussion * style fix
This commit is contained in:
parent
28de789993
commit
32732a1f29
@ -28,6 +28,8 @@ void ngraph::op::GenericIE::addExtension(std::shared_ptr<const ngraph::Function>
|
||||
|
||||
for (auto r : func->get_results())
|
||||
nodes.emplace_back(r);
|
||||
for (auto s : func->get_sinks())
|
||||
nodes.emplace_back(s);
|
||||
for (auto param : func->get_parameters())
|
||||
nodes.emplace_back(param);
|
||||
|
||||
|
@ -133,7 +133,7 @@ std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, std::i
|
||||
ngraph::ParameterVector parameter_nodes;
|
||||
ngraph::ResultVector result_nodes;
|
||||
ngraph::NodeVector allNodes;
|
||||
std::vector<std::shared_ptr<ngraph::op::Assign>> assign_nodes;
|
||||
ngraph::SinkVector assign_nodes;
|
||||
std::map<std::string, std::shared_ptr<ngraph::Node>> variable_id_to_read_value;
|
||||
|
||||
// Following topological order create nGraph operations
|
||||
@ -187,15 +187,12 @@ std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, std::i
|
||||
}
|
||||
|
||||
::ngraph::op::GenericIE::DisableReshape noReshape(allNodes);
|
||||
auto function = std::make_shared<ngraph::Function>(result_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
|
||||
if (!result_nodes.empty()) {
|
||||
for (const auto& assign : assign_nodes) {
|
||||
assign->add_control_dependency(variable_id_to_read_value.at(assign->get_variable_id()));
|
||||
// often Assign node is a leaf of the graph, we add control_dependency for one of the results
|
||||
// to make Assign node visible for traversals get_ops(), get_ordered_ops()
|
||||
result_nodes[0]->add_control_dependency(assign);
|
||||
}
|
||||
auto function = std::make_shared<ngraph::Function>(result_nodes, assign_nodes, parameter_nodes, GetStrAttr(root, "name", ""));
|
||||
for (const auto& assign : assign_nodes) {
|
||||
assign->add_control_dependency(
|
||||
variable_id_to_read_value.at(std::dynamic_pointer_cast<ngraph::op::Assign>(assign)->get_variable_id()));
|
||||
}
|
||||
|
||||
CNNNetwork net(function, _exts);
|
||||
parsePreProcess(net, root, binStream);
|
||||
return net;
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/result.hpp"
|
||||
#include "ngraph/op/sink.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -52,6 +53,16 @@ namespace ngraph
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name = "");
|
||||
|
||||
Function(const ResultVector& results,
|
||||
const SinkVector& sinks,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name = "");
|
||||
|
||||
Function(const OutputVector& results,
|
||||
const SinkVector& sinks,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name = "");
|
||||
|
||||
virtual ~Function() {}
|
||||
/// Return the number of outputs for this function.
|
||||
size_t get_output_size() const;
|
||||
@ -138,6 +149,27 @@ namespace ngraph
|
||||
bool evaluate(const HostTensorVector& output_tensors,
|
||||
const HostTensorVector& input_tensors) const;
|
||||
|
||||
/// \brief Return a list of function's sinks.
|
||||
const SinkVector& get_sinks() const { return m_sinks; }
|
||||
/// \brief Add new sink nodes to the list. Method doesn't validate graph, it should be done
|
||||
/// manually after all changes.
|
||||
/// \param sinks new sink nodes
|
||||
void add_sinks(const SinkVector& sinks);
|
||||
|
||||
/// \brief Delete sink node from the list of sinks. Method doesn't delete node from graph.
|
||||
/// \param sink Sink to delete
|
||||
void remove_sink(const std::shared_ptr<op::Sink>& sink);
|
||||
|
||||
/// \brief Add new Result nodes to the list. Method doesn't validate graph, it should be
|
||||
/// done manually after all changes.
|
||||
/// \param results new Result nodes
|
||||
void add_results(const ResultVector& results);
|
||||
|
||||
/// \brief Delete Result node from the list of results. Method will not delete node from
|
||||
/// graph.
|
||||
/// \param result Result node to delete
|
||||
void remove_result(const std::shared_ptr<op::Result>& result);
|
||||
|
||||
private:
|
||||
Function(const Function&) = delete;
|
||||
Function(const Function&&) = delete;
|
||||
@ -150,6 +182,9 @@ namespace ngraph
|
||||
topological_sort_t m_topological_sorter;
|
||||
|
||||
ResultVector m_results;
|
||||
// List of the nodes with side effect in graph.
|
||||
// These nodes are not outputs of graph but should not be removed even if have no children.
|
||||
SinkVector m_sinks;
|
||||
ParameterVector m_parameters;
|
||||
};
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/sink.hpp"
|
||||
#include "ngraph/op/util/variable.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -26,11 +26,10 @@ namespace ngraph
|
||||
namespace v3
|
||||
{
|
||||
/// \brief Assign operation sets an input value to the variable with `variable_id`
|
||||
class NGRAPH_API Assign : public Op
|
||||
class NGRAPH_API Assign : public Sink
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Assign", 3};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
Assign() = default;
|
||||
|
||||
/// \brief Constructs an Assign operation.
|
||||
|
46
ngraph/core/include/ngraph/op/sink.hpp
Normal file
46
ngraph/core/include/ngraph/op/sink.hpp
Normal file
@ -0,0 +1,46 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
/// Root of nodes that can be sink nodes
|
||||
class NGRAPH_API Sink : public Op
|
||||
{
|
||||
public:
|
||||
virtual ~Sink() = 0;
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
protected:
|
||||
Sink()
|
||||
: Op()
|
||||
{
|
||||
}
|
||||
Sink(const OutputVector& arguments)
|
||||
: Op(arguments)
|
||||
{
|
||||
}
|
||||
};
|
||||
}
|
||||
using SinkVector = std::vector<std::shared_ptr<op::Sink>>;
|
||||
}
|
@ -77,6 +77,28 @@ Function::Function(const std::shared_ptr<Node>& result,
|
||||
{
|
||||
}
|
||||
|
||||
Function::Function(const ResultVector& results,
|
||||
const SinkVector& sinks,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name)
|
||||
: m_results(results)
|
||||
, m_sinks(sinks)
|
||||
, m_parameters(parameters)
|
||||
, m_name(name)
|
||||
, m_unique_name("Function_" + to_string(m_next_instance_id.fetch_add(1)))
|
||||
, m_topological_sorter(topological_sort<std::vector<std::shared_ptr<Node>>>)
|
||||
{
|
||||
validate_nodes_and_infer_types();
|
||||
}
|
||||
|
||||
Function::Function(const OutputVector& results,
|
||||
const SinkVector& sinks,
|
||||
const ParameterVector& parameters,
|
||||
const std::string& name)
|
||||
: Function(as_result_vector(results), sinks, parameters, name)
|
||||
{
|
||||
}
|
||||
|
||||
void Function::validate_nodes_and_infer_types()
|
||||
{
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraph, "Function::validate_nodes_and_infer_types");
|
||||
@ -106,6 +128,10 @@ std::vector<shared_ptr<Node>> Function::get_ordered_ops() const
|
||||
{
|
||||
nodes.push_back(r);
|
||||
}
|
||||
for (auto& r : get_sinks())
|
||||
{
|
||||
nodes.emplace_back(r);
|
||||
}
|
||||
for (auto& param : get_parameters())
|
||||
{
|
||||
nodes.push_back(param);
|
||||
@ -122,6 +148,11 @@ void Function::map_unordered_ops(std::function<void(Node*)> f) const
|
||||
{
|
||||
remaining_ops.push(r.get());
|
||||
}
|
||||
for (auto& r : get_sinks())
|
||||
{
|
||||
remaining_ops.push(r.get());
|
||||
}
|
||||
|
||||
for (auto& param : get_parameters())
|
||||
{
|
||||
remaining_ops.push(param.get());
|
||||
@ -347,6 +378,33 @@ bool Function::visit_attributes(AttributeVisitor& visitor)
|
||||
return true;
|
||||
}
|
||||
|
||||
void Function::add_sinks(const SinkVector& sinks)
|
||||
{
|
||||
m_sinks.insert(m_sinks.end(), sinks.begin(), sinks.end());
|
||||
}
|
||||
|
||||
void Function::remove_sink(const std::shared_ptr<op::Sink>& sink)
|
||||
{
|
||||
m_sinks.erase(std::remove_if(m_sinks.begin(),
|
||||
m_sinks.end(),
|
||||
[&sink](std::shared_ptr<op::Sink>& s) { return s == sink; }),
|
||||
m_sinks.end());
|
||||
}
|
||||
|
||||
void Function::add_results(const ResultVector& results)
|
||||
{
|
||||
m_results.insert(m_results.end(), results.begin(), results.end());
|
||||
}
|
||||
|
||||
void Function::remove_result(const std::shared_ptr<op::Result>& result)
|
||||
{
|
||||
m_results.erase(
|
||||
std::remove_if(m_results.begin(),
|
||||
m_results.end(),
|
||||
[&result](std::shared_ptr<op::v0::Result>& r) { return r == result; }),
|
||||
m_results.end());
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
|
||||
|
||||
AttributeAdapter<shared_ptr<Function>>::AttributeAdapter(shared_ptr<Function>& ref)
|
||||
@ -401,6 +459,10 @@ bool AttributeAdapter<shared_ptr<Function>>::visit_attributes(AttributeVisitor&
|
||||
{
|
||||
results.push_back(result);
|
||||
}
|
||||
for (auto sink : m_ref->get_sinks())
|
||||
{
|
||||
results.push_back(sink);
|
||||
}
|
||||
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
|
@ -57,6 +57,10 @@ void ngraph::traverse_nodes(const Function* p, std::function<void(std::shared_pt
|
||||
{
|
||||
nodes.push_back(r);
|
||||
}
|
||||
for (auto s : p->get_sinks())
|
||||
{
|
||||
nodes.emplace_back(s);
|
||||
}
|
||||
|
||||
for (auto param : p->get_parameters())
|
||||
{
|
||||
@ -419,7 +423,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
|
||||
// clone function operations
|
||||
clone_nodes(func.get_ops(), node_map);
|
||||
|
||||
// get cloned function results and parameters
|
||||
// get cloned function results and sinks and parameters
|
||||
ResultVector cloned_results;
|
||||
for (shared_ptr<Node> node : func.get_results())
|
||||
{
|
||||
@ -430,6 +434,12 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
|
||||
}
|
||||
cloned_results.push_back(result);
|
||||
}
|
||||
SinkVector cloned_sinks;
|
||||
for (auto node : func.get_sinks())
|
||||
{
|
||||
cloned_sinks.push_back(static_pointer_cast<op::Sink>(node_map.at(node.get())));
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<op::Parameter>> cloned_params;
|
||||
for (auto param : func.get_parameters())
|
||||
{
|
||||
@ -439,6 +449,7 @@ std::shared_ptr<ngraph::Function> ngraph::clone_function(const ngraph::Function&
|
||||
// create and return cloned function
|
||||
auto result = std::make_shared<ngraph::Function>(cloned_results, cloned_params);
|
||||
result->set_friendly_name(func.get_friendly_name());
|
||||
result->add_sinks(cloned_sinks);
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -863,6 +874,18 @@ bool ngraph::check_for_cycles(const ngraph::Function* func,
|
||||
}
|
||||
}
|
||||
|
||||
for (auto res : func->get_sinks())
|
||||
{
|
||||
std::deque<std::shared_ptr<Node>> path;
|
||||
// mirror of path stack for faster cycle check
|
||||
std::unordered_set<std::shared_ptr<Node>> path_set;
|
||||
if (check_for_cycles_bkwd(res, path, path_set, cycle_nodes))
|
||||
{
|
||||
is_bkwd_cycle = true;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto param : func->get_parameters())
|
||||
{
|
||||
std::deque<std::shared_ptr<Node>> path;
|
||||
|
@ -21,10 +21,10 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v3::Assign::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v3::Assign, "Assign", 3, op::Sink);
|
||||
|
||||
op::v3::Assign::Assign(const Output<Node>& new_value, const std::string& variable_id)
|
||||
: Op({new_value})
|
||||
: Sink({new_value})
|
||||
, m_variable_id(variable_id)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
|
25
ngraph/core/src/op/sink.cpp
Normal file
25
ngraph/core/src/op/sink.cpp
Normal file
@ -0,0 +1,25 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/sink.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::Sink, "Sink", 0);
|
||||
|
||||
op::Sink::~Sink()
|
||||
{
|
||||
}
|
@ -120,6 +120,11 @@ std::shared_ptr<Function>
|
||||
new_results[i] = std::static_pointer_cast<op::Result>(m[new_results[i].get()]);
|
||||
new_results[i]->set_friendly_name(name);
|
||||
}
|
||||
SinkVector new_sinks = f->get_sinks();
|
||||
for (size_t i = 0; i < new_sinks.size(); i++)
|
||||
{
|
||||
new_sinks[i] = std::static_pointer_cast<op::Sink>(m[new_sinks[i].get()]);
|
||||
}
|
||||
|
||||
return std::make_shared<Function>(new_results, new_parameters);
|
||||
return std::make_shared<Function>(new_results, new_sinks, new_parameters);
|
||||
}
|
||||
|
@ -204,3 +204,176 @@ TEST(build_graph, default_output_checks)
|
||||
FAIL() << "nullptr initialization of Output failed";
|
||||
}
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_sink)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto assign = make_shared<op::Assign>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(ResultVector({res}), SinkVector({assign}), ParameterVector{arg});
|
||||
|
||||
SinkVector sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 1);
|
||||
EXPECT_EQ(sinks[0], assign);
|
||||
NodeVector nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 8);
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_sink_output_ctor)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto assign = make_shared<op::Assign>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(
|
||||
OutputVector({pattern->output(0)}), SinkVector({assign}), ParameterVector{arg});
|
||||
|
||||
SinkVector sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 1);
|
||||
EXPECT_EQ(sinks[0], assign);
|
||||
NodeVector nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 8);
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_add_sink)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto assign = make_shared<op::Assign>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(ResultVector({res}), ParameterVector{arg});
|
||||
|
||||
NodeVector nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 5);
|
||||
SinkVector sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 0);
|
||||
|
||||
f->add_sinks(SinkVector({assign}));
|
||||
sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 1);
|
||||
EXPECT_EQ(sinks[0], assign);
|
||||
nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 8);
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_wrong_remove_sink)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto assign = make_shared<op::Assign>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(ResultVector({res}), SinkVector({assign}), ParameterVector{arg});
|
||||
|
||||
SinkVector sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 1);
|
||||
EXPECT_EQ(sinks[0], assign);
|
||||
f->remove_sink(assign);
|
||||
sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 0);
|
||||
auto nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 5);
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_remove_sink)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto assign = make_shared<op::Assign>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(ResultVector({res}), SinkVector({assign}), ParameterVector{arg});
|
||||
|
||||
pattern->input(1).replace_source_output(arg);
|
||||
|
||||
SinkVector sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 1);
|
||||
EXPECT_EQ(sinks[0], assign);
|
||||
f->remove_sink(assign);
|
||||
sinks = f->get_sinks();
|
||||
EXPECT_EQ(sinks.size(), 0);
|
||||
auto nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 3);
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_add_result)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto res2 = make_shared<op::Result>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(ResultVector({res}), ParameterVector{arg});
|
||||
|
||||
NodeVector nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 5);
|
||||
ResultVector results = f->get_results();
|
||||
EXPECT_EQ(results.size(), 1);
|
||||
|
||||
f->add_results(ResultVector({res2}));
|
||||
results = f->get_results();
|
||||
EXPECT_EQ(results.size(), 2);
|
||||
EXPECT_EQ(results[1], res2);
|
||||
nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 8);
|
||||
}
|
||||
|
||||
TEST(build_graph, build_graph_with_remove_result)
|
||||
{
|
||||
auto arg = make_shared<op::Parameter>(element::f32, Shape{2, 4});
|
||||
auto init_const = op::Constant::create(element::f32, Shape{2, 2}, {0, 0, 0, 0});
|
||||
auto read = make_shared<op::ReadValue>(init_const, "v0");
|
||||
std::vector<shared_ptr<Node>> args = {arg, read};
|
||||
auto pattern = make_shared<op::Concat>(args, 1);
|
||||
auto res = make_shared<op::Result>(pattern);
|
||||
const auto axis = op::Constant::create(element::i64, Shape{}, {1});
|
||||
auto crop = make_shared<op::Split>(pattern, axis, 3);
|
||||
auto res2 = make_shared<op::Result>(crop, "v0");
|
||||
|
||||
auto f = make_shared<Function>(ResultVector({res, res2}), ParameterVector{arg});
|
||||
|
||||
NodeVector nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 8);
|
||||
ResultVector results = f->get_results();
|
||||
EXPECT_EQ(results.size(), 2);
|
||||
|
||||
f->remove_result(res2);
|
||||
results = f->get_results();
|
||||
EXPECT_EQ(results.size(), 1);
|
||||
nodes = f->get_ops();
|
||||
EXPECT_EQ(nodes.size(), 5);
|
||||
}
|
Loading…
Reference in New Issue
Block a user