Remove redundant node methods (#1324)

* Remove placement

* Removed validate and infer eltwise

* Remove is eltwise

* Remove support broadcast and decompose

* Removed is_op, is_parameter, is_pattern

* Fixed code style

* Added is_constant and is_output

* Removed is_communicative and is_null

* Fixed code style

* Fixed typo

* Fixed comments

* Fixed typo

* Revert is_parameter, is_output, is_result for OpenCV build
This commit is contained in:
Ilya Churaev 2020-07-21 06:02:00 +03:00 committed by GitHub
parent 898f0626ad
commit 54ae67414e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
100 changed files with 1052 additions and 804 deletions

View File

@ -32,6 +32,7 @@
#include <ngraph/graph_util.hpp>
#include <ngraph/op/result.hpp>
#include <ngraph/op/parameter.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/rt_info.hpp>
using namespace InferenceEngine;
@ -364,7 +365,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
auto orderedOps = function->get_ordered_ops();
orderedOps.erase(
std::remove_if(std::begin(orderedOps), std::end(orderedOps), [] (const std::shared_ptr<ngraph::Node>& node) {
return node->is_constant();
return ngraph::op::is_constant(node);
}),
std::end(orderedOps));
bool allEmpty = true;
@ -401,7 +402,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
auto NoConstants = [] (std::vector<ngraph::Input<ngraph::Node>>&& inputs) {
std::vector<ngraph::Input<ngraph::Node>> result;
for (auto&& input : inputs) {
if (!(input.get_source_output().get_node()->is_constant())) {
if (!(ngraph::op::is_constant(input.get_source_output().get_node()))) {
result.emplace_back(std::move(input));
}
}
@ -478,7 +479,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
InputSet subgraphInputs;
// Get all subgraph inputs using just node affinities. Also collect transitive closure
for (auto&& node : orderedOps) {
if (node->is_parameter()) {
if (ngraph::op::is_parameter(node)) {
graphInputNodes.insert(node.get());
subgraphInputs.insert(Input{node.get(), 0});
nodeInputDependencies[node.get()].insert(Input{node.get(), 0});
@ -550,7 +551,8 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
}
auto& nodeSubgraphCyclicInputDependency = nodeSubgraphCyclicInputDependencies[node.get()];
for (auto&& subgraphInput : allNodeSubgraphInputs) {
if (!subgraphInput.get_node()->is_parameter() && subgraphIds[node.get()] == subgraphIds[InputNode(subgraphInput)]) {
if (!ngraph::op::is_parameter(subgraphInput.get_node()) &&
subgraphIds[node.get()] == subgraphIds[InputNode(subgraphInput)]) {
nodeSubgraphCyclicInputDependency.emplace(subgraphInput);
}
}
@ -585,7 +587,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
NodeMap<ngraph::Node*> subgraphParameterToPrevResult;
std::vector<std::shared_ptr<ngraph::op::Result>> results;
for (auto&& input : subgraphInputs) {
if (!(input.get_node()->is_parameter())) {
if (!ngraph::op::is_parameter(input.get_node())) {
auto output = input.get_source_output();
output.remove_target_input(input);
auto result = std::make_shared<ngraph::op::Result>(output);
@ -614,10 +616,10 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
for (auto&& subgraphIdPtrValue : subgraphIds) {
auto node = subgraphIdPtrValue.first;
auto& subgraph = subgraphs[subgraphIdPtrValue.second];
if (node->is_output()) {
if (ngraph::op::is_output(node)) {
subgraph._results.emplace_back(
std::dynamic_pointer_cast<ngraph::op::v0::Result>(node->shared_from_this()));
} else if (node->is_parameter()) {
} else if (ngraph::op::is_parameter(node)) {
subgraph._parameters.emplace_back(
std::dynamic_pointer_cast<ngraph::op::v0::Parameter>(node->shared_from_this()));
}

View File

@ -26,6 +26,7 @@
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/fused/gelu.hpp>
#include <ngraph/op/util/op_types.hpp>
#include "ngraph_ops/fully_connected.hpp"
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
@ -227,7 +228,7 @@ void Engine::QueryNetwork(const ICNNNetwork& network, const std::map<std::string
if (function != nullptr) {
std::unordered_set<std::string> originalOps;
for (auto&& node : function->get_ops()) {
if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
originalOps.emplace(node->get_friendly_name());
}
}

View File

@ -10,6 +10,7 @@
#include <transformations_visibility.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/validation_util.hpp>
@ -58,7 +59,7 @@ void ngraph::pass::ConvertReduceToPooling::convert_reduce_to_pooling() {
auto input = reduce->input_value(0);
auto axes_node = reduce->input_value(1).get_node_shared_ptr();
if (!axes_node->is_constant()) {
if (!ngraph::op::is_constant(axes_node)) {
return false;
}

View File

@ -15,6 +15,7 @@
#include <graph_tools.hpp>
#include <functional_test_utils/plugin_cache.hpp>
#include <multi-device/multi_device_config.hpp>
#include <ngraph/op/util/op_types.hpp>
#include "common_test_utils/file_utils.hpp"
#include "common_test_utils/unicode_utils.hpp"
@ -1382,7 +1383,7 @@ TEST_P(IEClassLoadNetworkTest, QueryNetworkHETEROWithMULTINoThrow_V10) {
ASSERT_NE(nullptr, function);
std::unordered_set<std::string> expectedLayers;
for (auto &&node : function->get_ops()) {
if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
expectedLayers.emplace(node->get_friendly_name());
}
}
@ -1419,7 +1420,7 @@ TEST_P(IEClassLoadNetworkTest, QueryNetworkMULTIWithHETERONoThrow_V10) {
ASSERT_NE(nullptr, function);
std::unordered_set<std::string> expectedLayers;
for (auto &&node : function->get_ops()) {
if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
expectedLayers.emplace(node->get_friendly_name());
}
}

View File

@ -4,6 +4,7 @@
//
#include "hetero/query_network.hpp"
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/variant.hpp>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
@ -27,7 +28,9 @@ TEST_P(QueryNetworkTest, queryNetworkResultContainAllAndOnlyInputLayers) {
ASSERT_NE(nullptr, cnnNetwork.getFunction());
std::set<std::string> expectedLayers;
for (auto&& node : function->get_ops()) {
if (!node->is_parameter() && !node->is_constant() && !node->is_output()) {
if (!ngraph::op::is_parameter(node) &&
!ngraph::op::is_constant(node) &&
!ngraph::op::is_output(node)) {
expectedLayers.insert(node->get_friendly_name());
}
}
@ -37,4 +40,4 @@ TEST_P(QueryNetworkTest, queryNetworkResultContainAllAndOnlyInputLayers) {
}
ASSERT_EQ(expectedLayers, actualLayers);
}
} // namespace HeteroTests
} // namespace HeteroTests

View File

@ -4,6 +4,7 @@
//
#include "hetero/synthetic.hpp"
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/variant.hpp>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
@ -22,7 +23,9 @@ std::vector<FunctionParameter> HeteroSyntheticTest::_singleMajorNodeFunctions{[]
for (auto&& builder : builders) {
auto function = builder();
for (auto&& node : function->get_ordered_ops()) {
if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output())) {
if (!ngraph::op::is_constant(node) &&
!(ngraph::op::is_parameter(node)) &&
!(ngraph::op::is_output(node))) {
result.push_back(FunctionParameter{{node->get_friendly_name()}, function});
}
}
@ -41,7 +44,9 @@ std::vector<FunctionParameter> HeteroSyntheticTest::_randomMajorNodeFunctions{[]
for (std::size_t i = 0; i < ordered_ops.size(); ++i) {
std::unordered_set<std::string> majorPluginNodeIds;
for (auto&& node : ordered_ops) {
if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output()) && d(e)) {
if (!(ngraph::op::is_constant(node)) &&
!(ngraph::op::is_parameter(node)) &&
!(ngraph::op::is_output(node)) && d(e)) {
majorPluginNodeIds.emplace(node->get_friendly_name());
}
}
@ -117,7 +122,9 @@ std::string HeteroSyntheticTest::SetUpAffinity() {
auto& pluginParameters = std::get<Plugin>(param);
affinities += "\n{\n";
for (auto&& node : std::get<Function>(param)._function->get_ordered_ops()) {
if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output())) {
if (!ngraph::op::is_constant(node) &&
!(ngraph::op::is_parameter(node)) &&
!(ngraph::op::is_output(node))) {
std::string affinity;
if (std::get<Function>(param)._majorPluginNodeIds.end() !=
std::get<Function>(param)._majorPluginNodeIds.find(node->get_friendly_name())) {
@ -140,4 +147,4 @@ TEST_P(HeteroSyntheticTest, someLayersToMajorPluginOthersToFallback) {
ASSERT_NE(nullptr, cnnNetwork.getFunction());
}
} // namespace HeteroTests
} // namespace HeteroTests

View File

@ -8,6 +8,7 @@
#include <assert.h>
#include <ngraph/function.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/pass/visualize_tree.hpp>
std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngraph::Function> & f1, const std::shared_ptr<ngraph::Function> & f2) {
@ -75,7 +76,7 @@ void check_rt_info(const std::shared_ptr<ngraph::Function> & f) {
std::ostringstream err_log;
for (auto & op : f->get_ops()) {
if (op->is_constant()) continue;
if (ngraph::op::is_constant(op)) continue;
const auto & rt_info = op->get_rt_info();
for (const auto & attr_name : attrs_to_check) {
@ -94,4 +95,4 @@ void check_rt_info(const std::shared_ptr<ngraph::Function> & f) {
void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name) {
std::vector<std::shared_ptr<ngraph::Function> > g{f};
ngraph::pass::VisualizeTree(file_name).run_on_module(g);
}
}

View File

@ -8,6 +8,7 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/op/util/op_types.hpp>
#include <ngraph/specialize_function.hpp>
#include <ngraph_functions/utils/ngraph_helpers.hpp>
@ -126,7 +127,7 @@ std::shared_ptr<Function> foldFunction(const std::shared_ptr<Function> &function
const auto &foldedFunc = specialize_function(function, paramElementTypes, paramShapes, inBuffers, true, true);
for (const auto &op : foldedFunc->get_ops()) {
NGRAPH_CHECK(op->is_constant() || op->is_output() || op->is_parameter(),
NGRAPH_CHECK(op::is_constant(op) || op::is_output(op) || op::is_parameter(op),
"Function was not fully folded to constant state!\n",
"At least one non constant node with type ", op->get_type_name(),
" present in function.");
@ -141,7 +142,7 @@ std::vector<std::vector<std::uint8_t>> getConstData(const std::shared_ptr<Functi
const auto &output = function->output(i).get_node_shared_ptr();
NGRAPH_CHECK(output->inputs().size() == 1);
auto parrentNode = output->input_value(0).get_node_shared_ptr();
NGRAPH_CHECK(parrentNode->is_constant(), "Function was not fully folded to constant state!\n",
NGRAPH_CHECK(op::is_constant(parrentNode), "Function was not fully folded to constant state!\n",
"Parent node of one of results is not constant and has type ", parrentNode->get_type_name());
const auto data = std::dynamic_pointer_cast<opset1::Constant>(parrentNode)->get_data_ptr<std::uint8_t>();

View File

@ -30,6 +30,7 @@
#include "ngraph/check.hpp"
#include "ngraph/except.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset.hpp"
#include "node_factory.hpp"
#include "tensor_iterator_builder.hpp"
@ -55,7 +56,7 @@ namespace
std::shared_ptr<ngraph::Node>(m_opset.create(op_type_name));
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
NGRAPH_CHECK(!op_node->is_constant(),
NGRAPH_CHECK(!ngraph::op::is_constant(op_node),
"Currently NodeFactory doesn't support Constant node: ",
op_type_name);

View File

@ -426,6 +426,8 @@ set (SRC
op/util/binary_elementwise_logical.hpp
op/util/broadcast_base.cpp
op/util/broadcast_base.hpp
op/util/elementwise_args.cpp
op/util/elementwise_args.hpp
op/util/embeddingbag_packed_base.cpp
op/util/embeddingbag_packed_base.hpp
op/util/embeddingbag_offsets_base.cpp
@ -447,6 +449,8 @@ set (SRC
op/util/unary_elementwise_arithmetic.cpp
op/util/unary_elementwise_arithmetic.hpp
op/util/variable.hpp
op/util/op_types.cpp
op/util/op_types.hpp
ops.hpp
opsets/opset.cpp
partial_shape.cpp
@ -559,8 +563,6 @@ set (SRC
pattern/op/skip.hpp
pattern/op/true.cpp
pattern/op/true.hpp
placement.cpp
placement.hpp
provenance.cpp
provenance.hpp
rank.hpp

View File

@ -31,3 +31,13 @@ namespace ngraph
}
} // namespace onnx_import
} // namespace ngraph
bool ngraph::op::is_null(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::onnx_import::NullNode*>(node) != nullptr;
}
bool ngraph::op::is_null(const std::shared_ptr<ngraph::Node>& node)
{
return is_null(node.get());
}

View File

@ -19,9 +19,17 @@
#include <memory>
#include "ngraph/node.hpp"
#include "utils/onnx_importer_visibility.hpp"
namespace ngraph
{
namespace op
{
ONNX_IMPORTER_API
bool is_null(const ngraph::Node* node);
ONNX_IMPORTER_API
bool is_null(const std::shared_ptr<ngraph::Node>& node);
}
namespace onnx_import
{
/// \brief Represents a missing optional input or output of an ONNX node
@ -40,7 +48,6 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; }
NullNode() = default;
bool is_null() const final override { return true; }
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
};

View File

@ -20,6 +20,7 @@
#include "clip.hpp"
#include "default_opset.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
namespace ngraph
{
@ -57,7 +58,7 @@ namespace ngraph
// If second input is provided, assign to min input, otherwise set lowest
// numeric limit of double as min input.
if (inputs.size() > 1 && !inputs.at(1)->is_null())
if (inputs.size() > 1 && !ngraph::op::is_null(inputs.at(1)))
{
min = inputs.at(1);
}
@ -69,7 +70,7 @@ namespace ngraph
// If third input is provided, assign to max input, otherwise set maximum
// numeric limit of double as max input.
if (inputs.size() == 3 && !inputs.at(2)->is_null())
if (inputs.size() == 3 && !ngraph::op::is_null(inputs.at(2)))
{
max = inputs.at(2);
}

View File

@ -21,6 +21,7 @@
#include "dequantize_linear.hpp"
#include "ngraph/axis_set.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/dequantize.hpp"
#include "ngraph/shape.hpp"
@ -37,7 +38,7 @@ namespace ngraph
{
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
{
if (inputs.size() == 3 && !inputs[2]->is_null())
if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2]))
{
auto zero_point = inputs[2];

View File

@ -20,6 +20,7 @@
#include "default_opset.hpp"
#include "gru.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/shape.hpp"
#include "utils/recurrent.hpp"
@ -47,7 +48,7 @@ namespace ngraph
const auto& ng_inputs = node.get_ng_inputs();
const auto el_type = ng_inputs.at(0)->get_output_element_type(0);
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
// gates_count * 2 since B is: [Wb, Rb]

View File

@ -21,7 +21,9 @@
#include "core/graph.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/function.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "utils/reshape.hpp"
namespace ngraph
@ -52,7 +54,7 @@ namespace ngraph
const std::shared_ptr<ngraph::Node>& body_cond)
{
bool loop_cond_value = false;
if (loop_cond->is_constant() &&
if (ngraph::op::is_constant(loop_cond) &&
loop_cond->get_element_type() == element::boolean)
{
loop_cond_value = as_type_ptr<default_opset::Constant>(loop_cond)
@ -61,7 +63,8 @@ namespace ngraph
}
// According to ONNX skipped cond input (is_null) means
// that is has true value
bool is_loop_cond_true = loop_cond->is_null() || loop_cond_value == true;
bool is_loop_cond_true =
ngraph::op::is_null(loop_cond) || loop_cond_value == true;
if (!is_loop_cond_true)
{
@ -76,7 +79,7 @@ namespace ngraph
{
const auto second_input =
body_cond->input_value(1).get_node_shared_ptr();
if (second_input->is_constant() &&
if (ngraph::op::is_constant(second_input) &&
second_input->get_element_type() == element::boolean &&
as_type_ptr<default_opset::Constant>(second_input)
->cast_vector<bool>()
@ -99,7 +102,7 @@ namespace ngraph
// At this moment nGraph TensorIterator doesn't have support for conditional
// termination of iterations.
CHECK_VALID_NODE(node,
!trip_count->is_null(),
!ngraph::op::is_null(trip_count),
"Currently nGraph requires trip count input to be provided.");
const OutputVector loop_carried_dependencies{std::next(ng_inputs.begin(), 2),

View File

@ -27,6 +27,7 @@
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/enum_names.hpp"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
@ -91,7 +92,7 @@ namespace ngraph
// ------ Optional inputs ------
// The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
auto split_bias = builder::opset1::split(bias, 2, 1);
@ -106,7 +107,7 @@ namespace ngraph
0.f));
}
// The lengths of the sequences in a batch. Shape [batch_size]
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
{
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
}
@ -122,7 +123,7 @@ namespace ngraph
}
// The initial value of the hidden.
// Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
{
m_map[LSTMInput::LSTM_INPUT_INIT_H] =
builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
@ -136,7 +137,7 @@ namespace ngraph
}
// The initial value of the cell.
// Shape [num_directions, batch_size, hidden_size]
if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
if (ng_inputs.size() > 6 && !ngraph::op::is_null(ng_inputs.at(6)))
{
m_map[LSTMInput::LSTM_INPUT_INIT_C] =
builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
@ -149,7 +150,7 @@ namespace ngraph
std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
}
// The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
if (ng_inputs.size() > 7 && !ngraph::op::is_null(ng_inputs.at(7)))
{
m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
}

View File

@ -23,6 +23,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/convert.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "pad.hpp"
#include "utils/convpool.hpp"
@ -112,7 +113,7 @@ namespace ngraph
data->get_element_type(), ngraph::Shape{}, {0});
}
if (pads->is_constant())
if (ngraph::op::is_constant(pads))
{
std::vector<std::int64_t> pads_vector =
ngraph::as_type_ptr<default_opset::Constant>(pads)

View File

@ -17,6 +17,7 @@
#include "resize.hpp"
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/util/op_types.hpp"
namespace ngraph
{
@ -75,7 +76,7 @@ namespace ngraph
attrs.mode = mode;
attrs.align_corners = false;
if (scales->is_constant() && data_shape.is_static())
if (ngraph::op::is_constant(scales) && data_shape.is_static())
{
const auto scales_const =
as_type_ptr<default_opset::Constant>(scales->shared_from_this());

View File

@ -23,6 +23,7 @@
#include "gather.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "utils/common.hpp"
namespace
@ -190,7 +191,8 @@ namespace ngraph
if (inputs.size() >= 4) // axes input provided
{
axes = inputs.at(3);
CHECK_VALID_NODE(node, axes->is_constant(), "Axes input must be constant");
CHECK_VALID_NODE(
node, ngraph::op::is_constant(axes), "Axes input must be constant");
}
else
{

View File

@ -18,6 +18,7 @@
#include "default_opset.hpp"
#include "exceptions.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "upsample.hpp"
namespace ngraph
@ -137,7 +138,7 @@ namespace ngraph
attrs.axes.insert(ax);
}
if (scales->is_constant() && data_shape.is_static())
if (ngraph::op::is_constant(scales) && data_shape.is_static())
{
const auto scales_const =
as_type_ptr<default_opset::Constant>(scales->shared_from_this());

View File

@ -24,6 +24,7 @@
#include "ngraph/builder/split.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "recurrent.hpp"
namespace ngraph
@ -60,7 +61,7 @@ namespace ngraph
const std::size_t batch_size = m_map[OpInput::X]->get_shape().at(1);
const std::size_t num_directions = m_map[OpInput::W]->get_shape().front();
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
auto split_bias = builder::opset1::split(bias, 2, 1);
@ -71,7 +72,7 @@ namespace ngraph
m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
}
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
{
m_map[OpInput::SEQ_LENGTHS] = ng_inputs.at(4);
}
@ -81,7 +82,7 @@ namespace ngraph
element::i32, Shape{batch_size}, m_map[OpInput::X]->get_shape().at(0));
}
// The initial value of the hidden.
if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
{
m_map[OpInput::INIT_H] = ng_inputs.at(5);
}

View File

@ -22,6 +22,7 @@
#include "default_opset.hpp"
#include "ngraph/builder/make_constant.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "reshape.hpp"
@ -102,7 +103,7 @@ namespace ngraph
node_shape);
// If node is a Constant, recreate as Constant with Shape{}
if (node->is_constant())
if (ngraph::op::is_constant(node))
{
const auto value =
ngraph::as_type_ptr<default_opset::Constant>(node)->get_data_ptr();

View File

@ -21,6 +21,7 @@
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/util.hpp"
using namespace std;
@ -83,7 +84,7 @@ void Function::validate_nodes_and_infer_types()
node->revalidate_and_infer_types();
// If we find a parameter make sure it is in the list of parameters of the function
if (node->is_parameter())
if (op::is_parameter(node))
{
auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
if (it == m_parameters.end())

View File

@ -29,6 +29,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/provenance.hpp"
@ -130,7 +131,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
std::shared_ptr<Node> replacement,
const std::vector<int64_t>& output_order)
{
if (target->is_output())
if (ngraph::op::is_output(target))
{
throw ngraph_error("Result nodes cannot be replaced.");
}
@ -185,7 +186,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
void ngraph::replace_node(const std::shared_ptr<Node>& target,
const OutputVector& replacement_values)
{
if (target->is_output())
if (ngraph::op::is_output(target))
{
throw ngraph_error("Result nodes cannot be replaced.");
}
@ -258,7 +259,7 @@ bool ngraph::is_post_dominated(Node* X, Node* Y)
{
ngraph::Node* curr = stack.top();
visited.insert(curr);
if (curr->is_output())
if (ngraph::op::is_output(curr))
{
return false;
}
@ -465,7 +466,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
// Make parameter node
shared_ptr<op::Parameter> par_node = make_shared<op::Parameter>(
src_node->get_output_element_type(0), src_node->get_output_shape(0));
par_node->set_placement(dst_node->get_placement());
// Fix input / output among src, dst and par
std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
@ -489,7 +489,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
// Add res node
// Add [4], [5], [6], [7]
shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node);
res_node->set_placement(src_node->get_placement());
return make_pair(res_node, par_node);
}
@ -641,7 +640,7 @@ bool ngraph::is_used(Node* node)
ngraph::Node* n = stack.top();
if (instances_seen.count(n) == 0)
{
if (n->is_output())
if (ngraph::op::is_output(n))
{
return true;
}
@ -675,7 +674,7 @@ bool ngraph::possibly_overwritten(Node* node)
{
for (auto& input : output.get_target_inputs())
{
if (input.get_node()->is_op())
if (op::is_op(input.get_node()))
{
auto op = static_cast<ngraph::op::Op*>(input.get_node());
if (auto op_annotations = op->get_op_annotations())
@ -714,7 +713,7 @@ bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t
bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
{
if (!(n1->is_constant() && n2->is_constant()))
if (!(op::is_constant(n1) && op::is_constant(n2)))
{
return false;
}

View File

@ -29,7 +29,6 @@
#include "ngraph/check.hpp"
#include "ngraph/function.hpp"
#include "ngraph/node.hpp"
#include "ngraph/placement.hpp"
namespace ngraph
{
@ -406,10 +405,6 @@ namespace ngraph
NGRAPH_API
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
// Assert that nodes in the function is colocated and return that placement
NGRAPH_API
Placement get_colocated_function_placement(std::shared_ptr<Function> func);
NGRAPH_API
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,

View File

@ -28,7 +28,6 @@
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/pattern/matcher.hpp"
#include "ngraph/placement.hpp"
using namespace std;
using namespace ngraph;
@ -298,16 +297,6 @@ const std::deque<descriptor::Output>& Node::get_outputs() const
return m_outputs;
}
bool Node::is_output() const
{
return false;
}
bool Node::is_constant() const
{
return false;
}
const std::string& Node::description() const
{
// Terrible transitional kludge to keep description working while we change
@ -339,16 +328,6 @@ void Node::set_friendly_name(const string& name)
m_friendly_name = name;
}
Placement Node::get_placement() const
{
return m_placement;
}
void Node::set_placement(Placement placement)
{
m_placement = placement;
}
void Node::add_provenance_group_member(const shared_ptr<Node>& node)
{
m_provenance_group.insert(node);
@ -865,76 +844,6 @@ ResultVector ngraph::as_result_vector(const OutputVector& values)
return result;
}
std::tuple<element::Type, PartialShape>
Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
{
element::Type element_type = get_input_element_type(0);
PartialShape pshape = get_input_partial_shape(0);
if (get_input_size() > 1)
{
for (size_t i = 1; i < get_input_size(); ++i)
{
NODE_VALIDATION_CHECK(
this,
element::Type::merge(element_type, element_type, get_input_element_type(i)),
"Argument element types are inconsistent.");
if (autob.m_type == op::AutoBroadcastType::NONE)
{
NODE_VALIDATION_CHECK(this,
PartialShape::merge_into(pshape, get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
}
else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
autob.m_type == op::AutoBroadcastType::PDPD)
{
NODE_VALIDATION_CHECK(
this,
PartialShape::broadcast_merge_into(pshape, get_input_partial_shape(i), autob),
"Argument shapes are inconsistent.");
}
else
{
NODE_VALIDATION_CHECK(this, false, "Unsupported auto broadcast specification");
}
}
}
return std::make_tuple(element_type, pshape);
}
void Node::validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob)
{
auto args_et_pshape = validate_and_infer_elementwise_args(autob);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_CHECK(this,
args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape);
}
void Node::validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& autob)
{
auto args_et_pshape = validate_and_infer_elementwise_args(autob);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_CHECK(
this,
args_et.is_dynamic() || args_et == element::boolean,
"Operands for logical operators must have boolean element type but have element type ",
args_et,
".");
set_output_type(0, element::boolean, args_pshape);
}
bool Node::match_value(pattern::Matcher* matcher,
const Output<Node>& pattern_value,
const Output<Node>& graph_value)

View File

@ -42,7 +42,6 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/output_vector.hpp"
#include "ngraph/placement.hpp"
#include "ngraph/strides.hpp"
#include "ngraph/type.hpp"
@ -154,13 +153,6 @@ namespace ngraph
using type_info_t = DiscreteTypeInfo;
protected:
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
void validate_and_infer_elementwise_arithmetic(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
void validate_and_infer_elementwise_logical(
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
/// \brief Construct an unitialized Node
Node() {}
/// \brief Construct an unitialized Node
@ -176,19 +168,14 @@ namespace ngraph
void safe_delete(NodeVector& nodes, bool recurse);
public:
virtual bool is_parameter() const { return false; }
virtual bool is_output() const { return false; }
virtual bool is_constant() const { return false; }
virtual ~Node();
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
virtual bool is_unary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_arithmetic() const { return false; }
virtual bool is_binary_elementwise_comparison() const { return false; }
virtual bool is_binary_elementwise_logical() const { return false; }
/// \returns true if node supports autobroadcast operations
virtual bool supports_auto_broadcast() const { return false; }
/// \returns the autobroadcasr spec
virtual const op::AutoBroadcastSpec& get_autob() const;
/// \returns true if the node can decompose
virtual bool supports_decompose() const { return false; }
/// \brief Evaluates the op on input_values putting results in output_values
/// \returns true if successful
virtual bool evaluate(const HostTensorVector& output_values,
@ -276,15 +263,7 @@ namespace ngraph
const element::Type& element_type,
const PartialShape& pshape);
virtual bool is_parameter() const { return false; }
virtual bool is_output() const;
virtual bool is_constant() const;
virtual bool is_null() const { return false; }
virtual bool is_op() const { return false; }
virtual bool is_pattern() const { return false; }
virtual bool is_commutative() const { return false; }
virtual bool is_dynamic() const;
virtual bool has_state() const { return false; }
size_t get_instance_id() const { return m_instance_id; }
/// \brief Writes a description of a node to a stream
/// \param os The stream; should be returned
@ -440,12 +419,6 @@ namespace ngraph
/// True if this and node have one output with same element type and shape
bool has_same_type(std::shared_ptr<const Node> node) const;
/// Get device placement
Placement get_placement() const;
/// Set device placement
void set_placement(Placement placement);
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
RTMap& get_rt_info() { return m_rt_info; }
@ -557,7 +530,6 @@ namespace ngraph
std::set<std::shared_ptr<Node>> m_provenance_group;
std::deque<descriptor::Input> m_inputs;
std::deque<descriptor::Output> m_outputs;
Placement m_placement = Placement::DEFAULT;
std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
};

View File

@ -57,7 +57,6 @@ namespace ngraph
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -97,7 +96,6 @@ namespace ngraph
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
size_t get_version() const override { return 1; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;

View File

@ -54,7 +54,6 @@ namespace ngraph
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -40,6 +40,7 @@ namespace ngraph
public:
static constexpr NodeTypeInfo type_info{"Constant", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
bool is_constant() const override { return true; }
Constant() = default;
/// \brief Initialize a constant from tensor
@ -442,7 +443,6 @@ namespace ngraph
get_data_ptr());
}
bool is_constant() const override { return true; }
bool get_all_data_elements_bitwise_identical() const
{
return m_all_elements_bitwise_identical;

View File

@ -18,6 +18,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/crop_and_resize.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -94,7 +95,8 @@ void op::CropAndResize::validate_and_infer_types()
auto& crop_size_et = crop_size.get_element_type();
NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral");
auto crop_size_node = crop_size.get_node_shared_ptr();
NODE_VALIDATION_CHECK(this, crop_size_node->is_constant(), "crop_size must be a constant");
NODE_VALIDATION_CHECK(
this, ngraph::op::is_constant(crop_size_node), "crop_size must be a constant");
auto crop_size_const = static_pointer_cast<op::Constant>(crop_size_node);
if (crop_size_et == element::i8)
{

View File

@ -63,7 +63,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -111,7 +110,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -23,6 +23,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/fused/batch_to_space.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
using namespace std;
@ -152,13 +153,13 @@ void ngraph::op::v1::BatchToSpace::pre_validate_and_infer_types()
auto block = input_value(1);
auto crops_begin = input_value(2);
auto crops_end = input_value(3);
NGRAPH_CHECK(block.get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(ngraph::op::is_constant(block.get_node()),
"block_shape input node is expected to be a static constant");
NGRAPH_CHECK(crops_begin.get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(ngraph::op::is_constant(crops_begin.get_node()),
"crops_begin input node is expected to be a static constant");
NGRAPH_CHECK(crops_end.get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(ngraph::op::is_constant(crops_end.get_node()),
"crops_end input node is expected to be a static constant");
const auto& data_type = get_input_element_type(0);

View File

@ -23,6 +23,7 @@
#include "ngraph/op/divide.hpp"
#include "ngraph/op/fused/normalize_l2.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -55,7 +56,7 @@ void op::NormalizeL2::pre_validate_and_infer_types()
const auto& input_rank = input_pshape.rank();
const auto& axes_rank = axes_pshape.rank();
NODE_VALIDATION_CHECK(this, axes_node->is_constant(), "Input axes must be Constant type");
NODE_VALIDATION_CHECK(this, op::is_constant(axes_node), "Input axes must be Constant type");
if (axes_rank.is_static())
{

View File

@ -22,6 +22,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/fused/space_to_batch.hpp"
#include "ngraph/op/pad.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
using namespace std;
@ -135,13 +136,13 @@ void ngraph::op::v1::SpaceToBatch::pre_validate_and_infer_types()
auto block = input_value(1);
auto crops_begin = input_value(2);
auto crops_end = input_value(3);
NGRAPH_CHECK(block.get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(ngraph::op::is_constant(block.get_node()),
"block_shape input node is expected to be a static constant");
NGRAPH_CHECK(crops_begin.get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(ngraph::op::is_constant(crops_begin.get_node()),
"crops_begin input node is expected to be a static constant");
NGRAPH_CHECK(crops_end.get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(ngraph::op::is_constant(crops_end.get_node()),
"crops_end input node is expected to be a static constant");
const auto& data_type = get_input_element_type(0);

View File

@ -20,6 +20,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/unsqueeze.hpp"
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/copy.hpp"
#include "ngraph/validation_util.hpp"
@ -42,7 +43,7 @@ void op::Unsqueeze::pre_validate_and_infer_types()
const auto axes_node = input_value(1).get_node_shared_ptr();
if (data_rank.is_dynamic() || !axes_node->is_constant())
if (data_rank.is_dynamic() || !op::is_constant(axes_node))
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
return;

View File

@ -47,7 +47,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -80,7 +79,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -47,7 +47,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -80,7 +79,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -47,7 +47,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -80,7 +79,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -17,6 +17,7 @@
#include "ngraph/op/non_max_suppression.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -170,7 +171,7 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
const auto max_output_boxes_per_class = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
max_output_boxes_per_class->is_constant())
op::is_constant(max_output_boxes_per_class))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto max_output_boxes_per_class = max_boxes_output_from_input();
@ -384,7 +385,7 @@ void op::v3::NonMaxSuppression::validate_and_infer_types()
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
max_output_boxes_per_class_node->is_constant())
op::is_constant(max_output_boxes_per_class_node))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();
@ -517,7 +518,7 @@ void op::v4::NonMaxSuppression::validate_and_infer_types()
const auto num_boxes_boxes = boxes_ps[1];
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() &&
max_output_boxes_per_class_node->is_constant())
op::is_constant(max_output_boxes_per_class_node))
{
const auto num_boxes = num_boxes_boxes.get_length();
const auto num_classes = scores_ps[1].get_length();

View File

@ -16,6 +16,7 @@
#include "ngraph/op/not.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/util/elementwise_args.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/not.hpp"
@ -39,7 +40,7 @@ bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor)
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v1::LogicalNot::validate_and_infer_types()
{
auto args_et_pshape = validate_and_infer_elementwise_args();
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
@ -106,7 +107,7 @@ op::v0::Not::Not(const Output<Node>& arg)
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
void op::v0::Not::validate_and_infer_types()
{
auto args_et_pshape = validate_and_infer_elementwise_args();
auto args_et_pshape = ngraph::op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);

View File

@ -47,7 +47,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -79,7 +78,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -16,6 +16,7 @@
#include "ngraph/op/one_hot.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
@ -152,7 +153,7 @@ void op::v1::OneHot::validate_and_infer_types()
const auto& depth = input_value(1).get_node_shared_ptr();
PartialShape result_shape{PartialShape::dynamic()};
if (indices_shape.is_static() && indices_shape.rank().is_static() && depth->is_constant())
if (indices_shape.is_static() && indices_shape.rank().is_static() && op::is_constant(depth))
{
const auto indices_rank = indices_shape.rank().get_length();

View File

@ -27,8 +27,6 @@ namespace ngraph
/// Root of all actual ops
class NGRAPH_API Op : public Node
{
public:
virtual bool is_op() const override { return true; }
protected:
Op()
: Node()

View File

@ -52,7 +52,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
@ -84,7 +83,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -19,6 +19,7 @@
#include "ngraph/except.hpp"
#include "ngraph/op/broadcast.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -271,8 +272,8 @@ void op::v1::Pad::validate_and_infer_types()
auto pads_begin_node = input_value(1).get_node_shared_ptr();
auto pads_end_node = input_value(2).get_node_shared_ptr();
if (arg_shape_rank.is_static() && pads_begin_node->is_constant() &&
pads_end_node->is_constant())
if (arg_shape_rank.is_static() && op::is_constant(pads_begin_node) &&
op::is_constant(pads_end_node))
{
const auto implied_rank = pads_begin_coord.size();
std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());

View File

@ -48,7 +48,6 @@ namespace ngraph
bool visit_attributes(AttributeVisitor& visitor) override;
bool is_parameter() const override { return true; }
void validate_and_infer_types() override;
bool get_cacheable() const { return m_cacheable; }
@ -69,7 +68,7 @@ namespace ngraph
{
m_element_type = element_type;
}
bool is_parameter() const override { return true; }
protected:
bool m_cacheable;
PartialShape m_partial_shape;

View File

@ -44,14 +44,13 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_output() const override { return true; }
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
bool needs_default_layout() const { return m_needs_default_layout; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
bool constant_fold(OutputVector& output_values,
const OutputVector& inputs_values) override;
bool is_output() const override { return true; }
private:
bool m_needs_default_layout{false};
};

View File

@ -21,6 +21,7 @@
#include "ngraph/function.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/reverse.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -133,7 +134,7 @@ void op::v1::Reverse::validate_and_infer_types()
const auto rank = input_rank.get_length();
const auto rev_axes_node = input_value(1).get_node_shared_ptr();
if (rev_axes_node->is_constant())
if (op::is_constant(rev_axes_node))
{
const auto rev_axes_constant = as_type_ptr<op::Constant>(rev_axes_node);

View File

@ -16,6 +16,7 @@
#include "ngraph/op/scatter_elements_update.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/scatter_elements_update.hpp"
#include "ngraph/validation_util.hpp"
@ -91,7 +92,7 @@ void op::v3::ScatterElementsUpdate::validate_and_infer_types()
" and: ",
updates_shape);
if (input_value(3).get_node_shared_ptr()->is_constant() && data_shape.rank().is_static())
if (ngraph::op::is_constant(input_value(3).get_node()) && data_shape.rank().is_static())
{
const auto axis_input = as_type_ptr<op::v0::Constant>(input_value(3).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>().at(0);

View File

@ -117,7 +117,6 @@ namespace ngraph
{
m_auto_broadcast = auto_broadcast;
}
bool supports_auto_broadcast() const override { return true; }
// TODO: Move all uses of get_autob to get_auto_broadcast() and remove this.
const AutoBroadcastSpec& get_autob() const override { return m_auto_broadcast; }
private:

View File

@ -25,6 +25,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/runtime/reference/softmax.hpp"
#include "ngraph/util.hpp"
@ -53,7 +54,7 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const Output<Node>& axes)
bool op::v0::Softmax::are_axes_constant() const
{
return input_value(1).get_node_shared_ptr()->is_constant();
return op::is_constant(input_value(1).get_node());
}
const AxisSet op::v0::Softmax::get_axes() const

View File

@ -19,6 +19,7 @@
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/split.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/validation_util.hpp"
using namespace std;
@ -51,7 +52,8 @@ void op::v0::Split::pre_validate_and_infer_types()
NODE_VALIDATION_CHECK(this, is_scalar(axis_shape), "The 'axis' input node must be scalar");
const auto axis_node = input_value(1).get_node_shared_ptr();
NODE_VALIDATION_CHECK(this, axis_node->is_constant(), "The 'axis' input node must be constant");
NODE_VALIDATION_CHECK(
this, op::is_constant(axis_node), "The 'axis' input node must be constant");
const auto axis_node_const = as_type_ptr<op::Constant>(axis_node);
m_axis = axis_node_const->get_data_ptr<int64_t>()[0];
@ -142,7 +144,7 @@ void op::v1::Split::validate_and_infer_types()
NODE_VALIDATION_CHECK(
this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
if (input_value(1).get_node_shared_ptr()->is_constant() && data_ps.is_static())
if (op::is_constant(input_value(1).get_node()) && data_ps.is_static())
{
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
auto axis = axis_input->cast_vector<int64_t>()[0];

View File

@ -20,6 +20,7 @@
#include "ngraph/axis_vector.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/topk.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
@ -482,7 +483,7 @@ void op::v1::TopK::validate_and_infer_types()
this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
size_t k = 0;
if (input_value(1).get_node_shared_ptr()->is_constant())
if (op::is_constant(input_value(1).get_node()))
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
@ -638,7 +639,7 @@ shared_ptr<Node> op::v1::TopK::clone_with_new_inputs(const OutputVector& new_arg
size_t op::v1::TopK::get_k() const
{
size_t k = 0;
if (input_value(1).get_node_shared_ptr()->is_constant())
if (op::is_constant(input_value(1).get_node()))
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));
@ -668,7 +669,7 @@ bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVec
// 2. get value of k - from constant node or from HT
size_t k = 0;
if (input_value(1).get_node_shared_ptr()->is_constant())
if (op::is_constant(input_value(1).get_node()))
{
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
get_input_element_type(1));

View File

@ -16,6 +16,7 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/elementwise_args.hpp"
using namespace std;
using namespace ngraph;
@ -33,6 +34,22 @@ op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Output<
{
}
void op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic(
const op::AutoBroadcastSpec& autob)
{
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_CHECK(this,
args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape);
}
void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic(m_autob);

View File

@ -69,12 +69,11 @@ namespace ngraph
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool is_binary_elementwise_arithmetic() const override { return true; }
bool supports_auto_broadcast() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:
AutoBroadcastSpec m_autob;
void validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob);
};
}
}

View File

@ -16,6 +16,7 @@
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/elementwise_args.hpp"
using namespace std;
using namespace ngraph;
@ -35,7 +36,7 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Output<
void op::util::BinaryElementwiseComparison::validate_and_infer_types()
{
auto args_et_pshape = validate_and_infer_elementwise_args(m_autob);
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
set_output_type(0, element::boolean, args_pshape);

View File

@ -71,8 +71,6 @@ namespace ngraph
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_comparison() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:

View File

@ -16,6 +16,7 @@
#include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/elementwise_args.hpp"
using namespace std;
using namespace ngraph;
@ -32,6 +33,23 @@ op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<Node>&
{
}
void op::util::BinaryElementwiseLogical::validate_and_infer_elementwise_logical(
const op::AutoBroadcastSpec& autob)
{
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_CHECK(
this,
args_et.is_dynamic() || args_et == element::boolean,
"Operands for logical operators must have boolean element type but have element type ",
args_et,
".");
set_output_type(0, element::boolean, args_pshape);
}
void op::util::BinaryElementwiseLogical::validate_and_infer_types()
{
validate_and_infer_elementwise_logical(m_autob);

View File

@ -68,11 +68,10 @@ namespace ngraph
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
bool supports_auto_broadcast() const override { return true; }
bool is_binary_elementwise_logical() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:
void validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& autob);
AutoBroadcastSpec m_autob;
};
}

View File

@ -19,6 +19,7 @@
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/partial_shape.hpp"
#include "ngraph/runtime/reference/broadcast.hpp"
@ -192,7 +193,7 @@ void op::util::BroadcastBase::validate_and_infer_types()
" doesn't match rank of input tensor ",
arg_shape.size());
if (shape_constant && input_value(2).get_node_shared_ptr()->is_constant())
if (shape_constant && op::is_constant(input_value(2).get_node()))
{
auto target_shape = shape_constant->get_shape_val();
auto axes_mapping_val =

View File

@ -0,0 +1,61 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "elementwise_args.hpp"
using namespace ngraph;
std::tuple<element::Type, PartialShape>
ngraph::op::util::validate_and_infer_elementwise_args(Node* node,
const op::AutoBroadcastSpec& autob)
{
NGRAPH_CHECK(node != nullptr, "nGraph node is empty! Cannot validate eltwise arguments.");
element::Type element_type = node->get_input_element_type(0);
PartialShape pshape = node->get_input_partial_shape(0);
if (node->get_input_size() > 1)
{
for (size_t i = 1; i < node->get_input_size(); ++i)
{
NODE_VALIDATION_CHECK(
node,
element::Type::merge(element_type, element_type, node->get_input_element_type(i)),
"Argument element types are inconsistent.");
if (autob.m_type == op::AutoBroadcastType::NONE)
{
NODE_VALIDATION_CHECK(
node,
PartialShape::merge_into(pshape, node->get_input_partial_shape(i)),
"Argument shapes are inconsistent.");
}
else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
autob.m_type == op::AutoBroadcastType::PDPD)
{
NODE_VALIDATION_CHECK(node,
PartialShape::broadcast_merge_into(
pshape, node->get_input_partial_shape(i), autob),
"Argument shapes are inconsistent.");
}
else
{
NODE_VALIDATION_CHECK(node, false, "Unsupported auto broadcast specification");
}
}
}
return std::make_tuple(element_type, pshape);
}

View File

@ -16,22 +16,16 @@
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ngraph/node.hpp"
namespace ngraph
{
enum class Placement
namespace op
{
DEFAULT,
INTERPRETER,
CPU,
GPU,
NNP,
};
std::string placement_to_string(Placement placement);
namespace util
{
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
Node* node, const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
}
}
}

View File

@ -30,7 +30,6 @@ namespace ngraph
class NGRAPH_API FusedOp : public Op
{
public:
bool supports_decompose() const final { return true; }
// Fused op decomposition can be performed in the presence of
// partial shapes
virtual bool can_decompose_with_partial_shapes() { return false; }

View File

@ -0,0 +1,158 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/and.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/equal.hpp"
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/not_equal.hpp"
#include "ngraph/op/op.hpp"
#include "ngraph/op/or.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/result.hpp"
#include "ngraph/op/select.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/op/xor.hpp"
#include "ngraph/type.hpp"
bool ngraph::op::is_unary_elementwise_arithmetic(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::util::UnaryElementwiseArithmetic*>(node) != nullptr;
}
bool ngraph::op::is_binary_elementwise_arithmetic(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
}
bool ngraph::op::is_binary_elementwise_comparison(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::util::BinaryElementwiseComparison*>(node) != nullptr;
}
bool ngraph::op::is_binary_elementwise_logical(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::util::BinaryElementwiseLogical*>(node) != nullptr;
}
bool ngraph::op::supports_auto_broadcast(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::v1::Select*>(node) != nullptr ||
dynamic_cast<const ngraph::op::util::BinaryElementwiseComparison*>(node) != nullptr ||
dynamic_cast<const ngraph::op::util::BinaryElementwiseLogical*>(node) != nullptr ||
dynamic_cast<const ngraph::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
}
bool ngraph::op::supports_decompose(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::util::FusedOp*>(node) != nullptr;
}
bool ngraph::op::is_op(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::Op*>(node) != nullptr;
}
bool ngraph::op::is_parameter(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::Parameter*>(node) != nullptr;
}
bool ngraph::op::is_output(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::Result*>(node) != nullptr;
}
bool ngraph::op::is_constant(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::Constant*>(node) != nullptr;
}
bool ngraph::op::is_commutative(const ngraph::Node* node)
{
return dynamic_cast<const ngraph::op::v0::Add*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::Add*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::Maximum*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::Maximum*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::Equal*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::Equal*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::NotEqual*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::NotEqual*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::LogicalAnd*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::Xor*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::LogicalXor*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::Minimum*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::Minimum*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::Multiply*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::Multiply*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v0::Or*>(node) != nullptr ||
dynamic_cast<const ngraph::op::v1::LogicalOr*>(node) != nullptr;
}
bool ngraph::op::is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node)
{
return is_unary_elementwise_arithmetic(node.get());
}
bool ngraph::op::is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node)
{
return is_binary_elementwise_arithmetic(node.get());
}
bool ngraph::op::is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node)
{
return is_binary_elementwise_comparison(node.get());
}
bool ngraph::op::is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node)
{
return is_binary_elementwise_logical(node.get());
}
bool ngraph::op::supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node)
{
return supports_auto_broadcast(node.get());
}
bool ngraph::op::supports_decompose(const std::shared_ptr<ngraph::Node>& node)
{
return supports_decompose(node.get());
}
bool ngraph::op::is_op(const std::shared_ptr<ngraph::Node>& node)
{
return is_op(node.get());
}
bool ngraph::op::is_parameter(const std::shared_ptr<ngraph::Node>& node)
{
return is_parameter(node.get());
}
bool ngraph::op::is_output(const std::shared_ptr<ngraph::Node>& node)
{
return is_output(node.get());
}
bool ngraph::op::is_constant(const std::shared_ptr<ngraph::Node>& node)
{
return is_constant(node.get());
}
bool ngraph::op::is_commutative(const std::shared_ptr<ngraph::Node>& node)
{
return is_commutative(node.get());
}

View File

@ -0,0 +1,79 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <memory>
#include "ngraph/ngraph_visibility.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace op
{
NGRAPH_API
bool is_unary_elementwise_arithmetic(const ngraph::Node* node);
NGRAPH_API
bool is_binary_elementwise_arithmetic(const ngraph::Node* node);
NGRAPH_API
bool is_binary_elementwise_comparison(const ngraph::Node* node);
NGRAPH_API
bool is_binary_elementwise_logical(const ngraph::Node* node);
NGRAPH_API
bool supports_auto_broadcast(const ngraph::Node* node);
NGRAPH_API
bool supports_decompose(const ngraph::Node* node);
NGRAPH_API
bool is_op(const ngraph::Node* node);
NGRAPH_API
bool is_parameter(const ngraph::Node* node);
NGRAPH_API
bool is_output(const ngraph::Node* node);
NGRAPH_API
bool is_constant(const ngraph::Node* node);
NGRAPH_API
bool is_commutative(const ngraph::Node* node);
NGRAPH_API
bool is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool supports_decompose(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_op(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_parameter(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_output(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_constant(const std::shared_ptr<ngraph::Node>& node);
NGRAPH_API
bool is_commutative(const std::shared_ptr<ngraph::Node>& node);
}
}

View File

@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/scatter_base.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/validation_util.hpp"
@ -80,7 +81,7 @@ void op::util::ScatterBase::validate_and_infer_types()
bool compatible = true;
int64_t axis;
bool is_axis_constant = input_value(AXIS).get_node_shared_ptr()->is_constant();
bool is_axis_constant = op::is_constant(input_value(AXIS).get_node());
// Get axis value if possible.
if (is_axis_constant && data_shape.rank().is_static())

View File

@ -15,6 +15,7 @@
//*****************************************************************************
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/elementwise_args.hpp"
using namespace ngraph;
@ -28,6 +29,21 @@ op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const Output<No
{
}
void op::util::UnaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic()
{
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
element::Type& args_et = std::get<0>(args_et_pshape);
PartialShape& args_pshape = std::get<1>(args_et_pshape);
NODE_VALIDATION_CHECK(this,
args_et.is_dynamic() || args_et != element::boolean,
"Arguments cannot have boolean element type (argument element type: ",
args_et,
").");
set_output_type(0, args_et, args_pshape);
}
void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
{
validate_and_infer_elementwise_arithmetic();

View File

@ -57,8 +57,10 @@ namespace ngraph
public:
void validate_and_infer_types() override;
bool is_unary_elementwise_arithmetic() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
private:
void validate_and_infer_elementwise_arithmetic();
};
}
}

View File

@ -17,6 +17,7 @@
#include <numeric>
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/validation_util.hpp"
@ -62,8 +63,8 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
const auto& data_type = data.get_element_type();
set_output_size(num_outputs);
if (data_shape.rank().is_static() && axis_input->is_constant() &&
split_lengths_input->is_constant())
if (data_shape.rank().is_static() && op::is_constant(axis_input) &&
op::is_constant(split_lengths_input))
{
const auto axis_input_constant = as_type_ptr<op::Constant>(axis_input);
auto axis_val = axis_input_constant->cast_vector<int64_t>()[0];

View File

@ -52,7 +52,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool visit_attributes(AttributeVisitor& visitor) override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
@ -85,7 +84,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual bool is_commutative() const override { return true; }
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -168,5 +168,6 @@
#include "ngraph/op/topk.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/variadic_split.hpp"
#include "ngraph/op/xor.hpp"

View File

@ -38,6 +38,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/transpose.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset2.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/pattern/matcher.hpp"
@ -818,7 +819,7 @@ bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
bool replaced = false;
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter())
if (op::is_output(n) || op::is_parameter(n))
{
continue;
}

View File

@ -17,6 +17,7 @@
#include <sstream>
#include "common_function_collection.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -48,11 +49,11 @@ bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>&
{
for (const shared_ptr<Node>& n : current_function->get_ordered_ops())
{
if (n->is_constant() || n->is_parameter())
if (op::is_constant(n) || op::is_parameter(n))
{
continue;
}
if (n->is_op())
if (op::is_op(n))
{
auto op = std::static_pointer_cast<op::Op>(n);
auto annotations = op->get_op_annotations();

View File

@ -58,6 +58,7 @@
#include "ngraph/op/sum.hpp"
#include "ngraph/op/tan.hpp"
#include "ngraph/op/tanh.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pattern/matcher.hpp"
using namespace std;
@ -277,7 +278,7 @@ namespace std
// TODO: Do we need another map, so we could
// specify how to compute hash for each op?
if (p_this.is_commutative())
if (ngraph::op::is_commutative(&p_this))
{
sort(begin(cargs), end(cargs));
}
@ -301,7 +302,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(shared_ptr<ng
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter())
if (op::is_output(n) || op::is_parameter(n))
{
continue;
}

View File

@ -16,6 +16,7 @@
#include "ngraph/pass/fused_op_decomposition.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/provenance.hpp"
using namespace std;
@ -30,7 +31,7 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
{
bool modified = false;
if (node->supports_decompose())
if (op::supports_decompose(node))
{
if (m_has_direct_support && m_has_direct_support(*node))
{

View File

@ -21,13 +21,14 @@
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
#include "ngraph/op/util/binary_elementwise_logical.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Node> node)
{
if (node->supports_auto_broadcast())
if (ngraph::op::supports_auto_broadcast(node))
{
if (node->get_autob().m_type != op::AutoBroadcastType::NONE)
{
@ -45,7 +46,7 @@ bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Nod
NodeVector ngraph::pass::explicit_broadcast(std::shared_ptr<Node>& node)
{
NodeVector rc;
if (node->supports_auto_broadcast())
if (ngraph::op::supports_auto_broadcast(node))
{
auto autob = node->get_autob();
if (autob.m_type == op::AutoBroadcastType::NONE)

View File

@ -22,6 +22,7 @@
#include "ngraph/op/concat.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/liveness.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pass/memory_layout.hpp"
@ -48,7 +49,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
std::set<const descriptor::Tensor*> reused_inputs;
if (node->is_op())
if (op::is_op(node))
{
auto op = std::static_pointer_cast<op::Op>(node);
// concat and slice in_place_oi should be treated differently
@ -67,7 +68,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
if ((node->liveness_free_list.count(input) != 0 ||
is_type<op::GetOutputElement>(node) ||
(m_disable_memory_sharing && !oi_pair.destructive &&
!input_node->is_parameter() && !input_node->is_constant())) &&
!op::is_parameter(input_node) && !op::is_constant(input_node))) &&
node->liveness_new_list.count(output) != 0)
{

View File

@ -35,6 +35,7 @@
#include "ngraph/op/slice.hpp"
#include "ngraph/op/stop_gradient.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/util.hpp"
#include "nop_elimination.hpp"
@ -344,7 +345,7 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
{
PartialShape data_shape;
if (input->is_parameter())
if (op::is_parameter(input))
{
data_shape = unsqueeze->input(0).get_partial_shape();
}
@ -393,7 +394,7 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
{
PartialShape data_shape;
if (input->is_parameter())
if (op::is_parameter(input))
{
data_shape = squeeze_i->input(0).get_partial_shape();
}

View File

@ -21,6 +21,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_annotations.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace std;
using namespace ngraph;
@ -29,7 +30,7 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
{
for (auto& node : function->get_ordered_ops())
{
if (node->is_op())
if (op::is_op(node))
{
auto op = static_pointer_cast<op::Op>(node);
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
@ -40,7 +41,7 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
op_annotations = op_annotations_factory();
op->set_op_annotations(op_annotations);
}
if (node->is_parameter())
if (op::is_parameter(node))
{
auto parameter = static_pointer_cast<op::Parameter>(node);
op_annotations->set_cacheable(parameter->get_cacheable());
@ -54,7 +55,7 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
{
auto input_value_node = input.get_source_output().get_node_shared_ptr();
NGRAPH_DEBUG << "propagate cacheability: arg is " << *input_value_node;
if (input_value_node->is_op())
if (op::is_op(input_value_node))
{
auto arg_op = static_pointer_cast<op::Op>(input_value_node);
auto arg_op_annotations = arg_op->get_op_annotations();

View File

@ -35,6 +35,7 @@
#include "ngraph/op/reshape.hpp"
#include "ngraph/op/slice.hpp"
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
#include "ngraph/pattern/op/label.hpp"
#include "ngraph/util.hpp"
@ -181,7 +182,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
continue;
}
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
if (n->is_unary_elementwise_arithmetic())
if (op::is_unary_elementwise_arithmetic(n))
{
Swimmer nsw{n->input(0), csw.reshape};
work_queue.push_back(nsw);
@ -549,7 +550,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
{
NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
// collect all Result nodes for a sanity check
if (n->is_output())
if (ngraph::op::is_output(n))
{
results.push_back(n);
}
@ -558,11 +559,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
{
sink_reshape(reshape, reorders, reshapes_to_delete);
}
else if (n->is_unary_elementwise_arithmetic())
else if (op::is_unary_elementwise_arithmetic(n))
{
sink_unary(n, reorders, reshapes_to_delete);
}
else if (n->is_binary_elementwise_arithmetic())
else if (op::is_binary_elementwise_arithmetic(n))
{
sink_binary(n, reorders, reshapes_to_delete);
}

View File

@ -17,6 +17,7 @@
#include "ngraph/pass/shape_relevance.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace ngraph;
@ -91,7 +92,7 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
shape_determinants.insert(node);
already_visited.insert(node);
if (node->is_parameter())
if (op::is_parameter(node))
{
auto node_as_param = static_cast<op::Parameter*>(node);
if (!node_as_param->is_relevant_to_shapes())

View File

@ -24,6 +24,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/pass.hpp"
#include "ngraph/pass/visualize_tree.hpp"
#include "ngraph/util.hpp"
@ -348,7 +349,7 @@ static std::string pretty_value(const vector<T>& value)
std::string pass::VisualizeTree::get_constant_value(std::shared_ptr<Node> node, size_t max_elements)
{
if (!node->is_constant())
if (!op::is_constant(node))
return {};
std::stringstream ss;
ss << "{" << node->get_element_type().get_type_name() << "}";
@ -392,7 +393,7 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
vector<string> attributes;
attributes.push_back("shape=box");
if (node->is_output())
if (ngraph::op::is_output(node))
{
attributes.push_back("color=crimson");
attributes.push_back("penwidth=1.5");

View File

@ -29,6 +29,7 @@
#include "ngraph/op/product.hpp"
#include "ngraph/op/replace_slice.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/type.hpp"
#include "zero_dim_tensor_elimination.hpp"
@ -46,7 +47,8 @@ static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
set<Output<Node>> zero_length_source_outputs;
for (auto n : f->get_ordered_ops())
{
if (n->is_output() || n->is_parameter() || n->is_constant() || n->get_output_size() > 1)
if (op::is_output(n) || op::is_parameter(n) || op::is_constant(n) ||
n->get_output_size() > 1)
{
continue;
}
@ -92,7 +94,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
// if any `GetOutputElement` is zero-length
// we replace it w/ a signalling constant
// so we don't have to deal w/ multi-output nodes directly
if (n->is_output() || n->is_parameter() || n->get_output_size() > 1)
if (op::is_output(n) || op::is_parameter(n) || n->get_output_size() > 1)
{
continue;
}

View File

@ -23,6 +23,7 @@
#include "ngraph/log.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/parameter.hpp"
#include "ngraph/op/util/op_types.hpp"
namespace ngraph
{
@ -187,7 +188,7 @@ namespace ngraph
return false;
}
if (graph_node->is_commutative())
if (ngraph::op::is_commutative(graph_node))
{
// TODO: [nikolayk] we don't really have to use lexicographically-based perms,
// heap's algo should be faster

View File

@ -85,7 +85,6 @@ namespace ngraph
ValuePredicate get_predicate() const;
bool is_pattern() const override { return true; }
protected:
ValuePredicate m_predicate;
};

View File

@ -1,40 +0,0 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <deque>
#include <sstream>
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/placement.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::string ngraph::placement_to_string(Placement placement)
{
switch (placement)
{
case Placement::DEFAULT: return "DEFAULT";
case Placement::INTERPRETER: return "INTERPRETER";
case Placement::CPU: return "CPU";
case Placement::GPU: return "GPU";
case Placement::NNP: return "NNP";
}
throw runtime_error("unhandled placement type");
}

View File

@ -19,6 +19,7 @@
#include "ngraph/op/assign.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/tensor_iterator.hpp"
#include "ngraph/op/util/op_types.hpp"
using namespace ngraph;
@ -67,7 +68,7 @@ std::shared_ptr<Function>
for (auto old_node : f->get_ordered_ops())
{
if (old_node->is_parameter())
if (op::is_parameter(old_node))
{
continue;
}

View File

@ -481,7 +481,7 @@ TEST(autobroadcast, axes_mapping_from_bcast_axes)
const AxisSet broadcast_axes{0, 2};
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 2);
EXPECT_EQ(axes_mapping_shape, (Shape{1, 3}));
@ -494,7 +494,7 @@ TEST(autobroadcast, axes_mapping_from_bcast_axes_scalar)
const AxisSet broadcast_axes{0, 1, 2, 3};
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 0);
EXPECT_EQ(axes_mapping_shape, (Shape{}));
@ -507,7 +507,7 @@ TEST(autobroadcast, axes_mapping_from_bcast_axes_identical)
const AxisSet broadcast_axes{};
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));
@ -521,7 +521,7 @@ TEST(autobroadcast, axes_mapping_start_match_axis)
auto axes_mapping =
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 2);
EXPECT_EQ(axes_mapping_shape, (Shape{1, 2}));
@ -535,7 +535,7 @@ TEST(autobroadcast, axes_mapping_start_match_axis_scalar)
auto axes_mapping =
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), 0);
EXPECT_EQ(axes_mapping_shape, (Shape{}));
@ -549,7 +549,7 @@ TEST(autobroadcast, axes_mapping_start_match_axis_identical)
auto axes_mapping =
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));

View File

@ -35,6 +35,7 @@
// clang-format on
#include "gtest/gtest.h"
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/frontend/onnx_import/onnx_utils.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp"
@ -376,7 +377,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_missing_input)
std::shared_ptr<ngraph::Node> C = ng_inputs.at(2);
A = A * C;
if (!B->is_null())
if (!ngraph::op::is_null(B))
{
B = B / C;
}
@ -393,7 +394,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_missing_input)
for (const auto& ng_input : ng_inputs)
{
if (!ng_input->is_null())
if (!ngraph::op::is_null(ng_input))
{
result = ng_input * result;
}

View File

@ -20,6 +20,7 @@
#include "ngraph/file_util.hpp"
#include "ngraph/frontend/onnx_import/default_opset.hpp"
#include "ngraph/frontend/onnx_import/onnx.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "ngraph/pass/manager.hpp"
#include "util/all_close.hpp"
@ -44,7 +45,7 @@ namespace
for (auto ng_node : ng_function->get_ordered_ops())
{
if (ng_node->is_constant())
if (op::is_constant(ng_node))
{
const auto folded_node = as_type_ptr<default_opset::Constant>(ng_node);
const auto output_values = folded_node->cast_vector<T>();

View File

@ -33,7 +33,7 @@ TEST(op, is_op)
{
auto arg0 = make_shared<op::Parameter>(element::f32, Shape{1});
ASSERT_NE(nullptr, arg0);
EXPECT_TRUE(arg0->is_parameter());
EXPECT_TRUE(op::is_parameter(arg0));
}
TEST(op, is_parameter)
@ -42,7 +42,7 @@ TEST(op, is_parameter)
ASSERT_NE(nullptr, arg0);
auto t0 = make_shared<op::Add>(arg0, arg0);
ASSERT_NE(nullptr, t0);
EXPECT_FALSE(t0->is_parameter());
EXPECT_FALSE(op::is_parameter(t0));
}
TEST(op, provenance_tag)

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
#include "ngraph/ngraph.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/manager.hpp"
#include "opset0_downgrade.hpp"
#include "opset1_upgrade.hpp"
@ -28,8 +29,8 @@ TEST(opset_transform, opset1_broadcast_upgrade_pass)
ASSERT_TRUE(bcast_v1);
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
ASSERT_TRUE(bcast_v1->input_value(1).get_node()->is_constant());
ASSERT_TRUE(bcast_v1->input_value(2).get_node()->is_constant());
ASSERT_TRUE(op::is_constant(bcast_v1->input_value(1).get_node()));
ASSERT_TRUE(op::is_constant(bcast_v1->input_value(2).get_node()));
EXPECT_EQ(
as_type_ptr<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())->get_shape_val(),
(Shape{3, 5, 4, 6}));

View File

@ -34,6 +34,7 @@
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/sum.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
#include "ngraph/pass/manager.hpp"
#include "ngraph/pattern/matcher.hpp"
@ -322,7 +323,7 @@ TEST(pattern, matcher)
auto b = make_shared<op::Parameter>(element::i32, shape);
auto is_bea = [](std::shared_ptr<Node> node) -> bool {
return node->is_binary_elementwise_arithmetic();
return op::is_binary_elementwise_arithmetic(node);
};
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
auto add_ab = a + b;

View File

@ -20,6 +20,7 @@
#include "ngraph/cpio.hpp"
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
#include "ngraph/except.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/assign_layout.hpp"
#include "ngraph/pass/core_fusion.hpp"
@ -164,7 +165,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
for (auto op : m_nodes)
{
event::Duration d2(op->description(), "Interpreter");
if (op->is_parameter())
if (op::is_parameter(op))
{
continue;
}

View File

@ -24,6 +24,7 @@
#include "ngraph/graph_util.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
#include "ngraph/provenance.hpp"
@ -135,7 +136,7 @@ namespace
*node);
const auto& arg_shape = arg_pshape.to_shape();
NGRAPH_CHECK(target_shape_input.get_node_shared_ptr()->is_constant());
NGRAPH_CHECK(op::is_constant(target_shape_input.get_node()));
auto target_shape = node->get_output_shape(0);
NGRAPH_CHECK(node->get_broadcast_axes().first);
@ -250,7 +251,7 @@ namespace
const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
const auto input_rank = node->get_input_partial_shape(0).rank();
if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static() &&
if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
input_rank.is_static())
{
const auto output_shape = node->get_output_shape(0);
@ -416,12 +417,12 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
{
const auto indices = node->input_value(0);
const auto depth = node->input_value(1).get_node_shared_ptr();
const auto depth = node->input_value(1).get_node();
auto on_value = node->input_value(2);
auto off_value = node->input_value(3);
const auto axis = node->get_axis();
NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node);
NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
const auto output_pshape = node->get_output_partial_shape(0);
NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
const auto output_shape = output_pshape.to_shape();
@ -529,7 +530,7 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
{
auto axes_node = node->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(axes_node->is_constant(),
NGRAPH_CHECK(op::is_constant(axes_node),
"Unable to convert Reverse:v1 to Reverse:v0 "
"if reduction axes are not constant. Node: ",
*node);
@ -685,7 +686,7 @@ namespace
const auto data_shape = data_pshape.to_shape();
const auto order_node = node->input_value(1).get_node_shared_ptr();
NGRAPH_CHECK(order_node->is_constant(),
NGRAPH_CHECK(op::is_constant(order_node),
"Unable to convert Transpose:v1 to Reshape:v0 "
"if order node is not constant. Node: ",
*node);
@ -715,7 +716,7 @@ namespace
{
const auto split_lengths = node->input_value(2).get_node_shared_ptr();
NGRAPH_CHECK(split_lengths->is_constant(),
NGRAPH_CHECK(op::is_constant(split_lengths),
"Unable to convert VariadicSplit:v1 to Split:v0 "
"if 'split_lengths' input is not constant. Node: ",
*node);

View File

@ -23,6 +23,7 @@
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/op/util/op_types.hpp"
#include "ngraph/ops.hpp"
#include "ngraph/provenance.hpp"
#include "op/avg_pool.hpp"
@ -404,7 +405,7 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
"axes parameter is expected to be a static constant");
AxisSet axes = node->get_axes();
@ -487,9 +488,9 @@ namespace
shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
{
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
"parameter k is expected to be a static constant");
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant(),
NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
"parameter top_k_axis is expected to be a static constant");
const auto k = node->get_k();

View File

@ -86,6 +86,6 @@ TEST(tensor, output_flag)
for (size_t i = 0; i < f0->get_output_size(); ++i)
{
EXPECT_TRUE(f0->get_output_op(i)->is_output());
EXPECT_TRUE(op::is_output(f0->get_output_op(i)));
}
}