Remove redundant node methods (#1324)
* Remove placement
* Removed validate and infer eltwise
* Remove is eltwise
* Remove support broadcast and decompose
* Removed is_op, is_parameter, is_pattern
* Fixed code style
* Added is_constant and is_output
* Removed is_commutative and is_null
* Fixed code style
* Fixed typo
* Fixed comments
* Fixed typo
* Revert is_parameter, is_output, is_result for OpenCV build
Parent: 898f0626ad
Commit: 54ae67414e
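All of the changes below follow one mechanical pattern: call sites stop using the virtual predicate methods on `ngraph::Node` (`is_constant()`, `is_parameter()`, `is_output()`, ...) and switch to the free functions declared in `ngraph/op/util/op_types.hpp`. A minimal sketch of that pattern, for illustration only (the helper name below is hypothetical, not taken from this change):

```cpp
#include <memory>

#include <ngraph/node.hpp>
#include <ngraph/op/util/op_types.hpp>

// Hypothetical helper: decide whether a node does "real" work, i.e. is not a
// Constant, Parameter or Result node.
bool is_real_op(const std::shared_ptr<ngraph::Node>& node) {
    // Before this change: !node->is_constant() && !node->is_parameter() && !node->is_output()
    return !ngraph::op::is_constant(node) &&
           !ngraph::op::is_parameter(node) &&
           !ngraph::op::is_output(node);
}
```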
@@ -32,6 +32,7 @@
 #include <ngraph/graph_util.hpp>
 #include <ngraph/op/result.hpp>
 #include <ngraph/op/parameter.hpp>
+#include <ngraph/op/util/op_types.hpp>
 #include <ngraph/rt_info.hpp>

 using namespace InferenceEngine;
@@ -364,7 +365,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     auto orderedOps = function->get_ordered_ops();
     orderedOps.erase(
         std::remove_if(std::begin(orderedOps), std::end(orderedOps), [] (const std::shared_ptr<ngraph::Node>& node) {
-            return node->is_constant();
+            return ngraph::op::is_constant(node);
         }),
         std::end(orderedOps));
     bool allEmpty = true;
@@ -401,7 +402,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     auto NoConstants = [] (std::vector<ngraph::Input<ngraph::Node>>&& inputs) {
         std::vector<ngraph::Input<ngraph::Node>> result;
         for (auto&& input : inputs) {
-            if (!(input.get_source_output().get_node()->is_constant())) {
+            if (!(ngraph::op::is_constant(input.get_source_output().get_node()))) {
                 result.emplace_back(std::move(input));
             }
         }
@@ -478,7 +479,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     InputSet subgraphInputs;
     // Get all subgraph inputs using just node affinities. Also collect transitive closure
     for (auto&& node : orderedOps) {
-        if (node->is_parameter()) {
+        if (ngraph::op::is_parameter(node)) {
             graphInputNodes.insert(node.get());
             subgraphInputs.insert(Input{node.get(), 0});
             nodeInputDependencies[node.get()].insert(Input{node.get(), 0});
@@ -550,7 +551,8 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     }
     auto& nodeSubgraphCyclicInputDependency = nodeSubgraphCyclicInputDependencies[node.get()];
     for (auto&& subgraphInput : allNodeSubgraphInputs) {
-        if (!subgraphInput.get_node()->is_parameter() && subgraphIds[node.get()] == subgraphIds[InputNode(subgraphInput)]) {
+        if (!ngraph::op::is_parameter(subgraphInput.get_node()) &&
+            subgraphIds[node.get()] == subgraphIds[InputNode(subgraphInput)]) {
             nodeSubgraphCyclicInputDependency.emplace(subgraphInput);
         }
     }
@@ -585,7 +587,7 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     NodeMap<ngraph::Node*> subgraphParameterToPrevResult;
     std::vector<std::shared_ptr<ngraph::op::Result>> results;
     for (auto&& input : subgraphInputs) {
-        if (!(input.get_node()->is_parameter())) {
+        if (!ngraph::op::is_parameter(input.get_node())) {
             auto output = input.get_source_output();
             output.remove_target_input(input);
             auto result = std::make_shared<ngraph::op::Result>(output);
@@ -614,10 +616,10 @@ void HeteroExecutableNetwork::InitNgraph(const InferenceEngine::ICNNNetwork& net
     for (auto&& subgraphIdPtrValue : subgraphIds) {
         auto node = subgraphIdPtrValue.first;
         auto& subgraph = subgraphs[subgraphIdPtrValue.second];
-        if (node->is_output()) {
+        if (ngraph::op::is_output(node)) {
             subgraph._results.emplace_back(
                 std::dynamic_pointer_cast<ngraph::op::v0::Result>(node->shared_from_this()));
-        } else if (node->is_parameter()) {
+        } else if (ngraph::op::is_parameter(node)) {
             subgraph._parameters.emplace_back(
                 std::dynamic_pointer_cast<ngraph::op::v0::Parameter>(node->shared_from_this()));
         }
@@ -26,6 +26,7 @@
 #include <ngraph/opsets/opset2.hpp>
 #include <ngraph/opsets/opset3.hpp>
 #include <ngraph/op/fused/gelu.hpp>
+#include <ngraph/op/util/op_types.hpp>
 #include "ngraph_ops/fully_connected.hpp"

 #if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
@@ -227,7 +228,7 @@ void Engine::QueryNetwork(const ICNNNetwork& network, const std::map<std::string
     if (function != nullptr) {
         std::unordered_set<std::string> originalOps;
         for (auto&& node : function->get_ops()) {
-            if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
+            if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
                 originalOps.emplace(node->get_friendly_name());
             }
         }
@@ -10,6 +10,7 @@

 #include <transformations_visibility.hpp>

+#include <ngraph/op/util/op_types.hpp>
 #include <ngraph/pass/graph_rewrite.hpp>
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/validation_util.hpp>
@@ -58,7 +59,7 @@ void ngraph::pass::ConvertReduceToPooling::convert_reduce_to_pooling() {
     auto input = reduce->input_value(0);

     auto axes_node = reduce->input_value(1).get_node_shared_ptr();
-    if (!axes_node->is_constant()) {
+    if (!ngraph::op::is_constant(axes_node)) {
         return false;
     }

@@ -15,6 +15,7 @@
 #include <graph_tools.hpp>
 #include <functional_test_utils/plugin_cache.hpp>
 #include <multi-device/multi_device_config.hpp>
+#include <ngraph/op/util/op_types.hpp>

 #include "common_test_utils/file_utils.hpp"
 #include "common_test_utils/unicode_utils.hpp"
@@ -1382,7 +1383,7 @@ TEST_P(IEClassLoadNetworkTest, QueryNetworkHETEROWithMULTINoThrow_V10) {
     ASSERT_NE(nullptr, function);
     std::unordered_set<std::string> expectedLayers;
     for (auto &&node : function->get_ops()) {
-        if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
+        if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
             expectedLayers.emplace(node->get_friendly_name());
         }
     }
@@ -1419,7 +1420,7 @@ TEST_P(IEClassLoadNetworkTest, QueryNetworkMULTIWithHETERONoThrow_V10) {
     ASSERT_NE(nullptr, function);
     std::unordered_set<std::string> expectedLayers;
     for (auto &&node : function->get_ops()) {
-        if (!node->is_constant() && !node->is_parameter() && !node->is_output()) {
+        if (!ngraph::op::is_constant(node) && !ngraph::op::is_parameter(node) && !ngraph::op::is_output(node)) {
             expectedLayers.emplace(node->get_friendly_name());
         }
     }
@@ -4,6 +4,7 @@
 //

 #include "hetero/query_network.hpp"
+#include <ngraph/op/util/op_types.hpp>
 #include <ngraph/variant.hpp>
 #include "ngraph_functions/builders.hpp"
 #include "ngraph_functions/subgraph_builders.hpp"
@@ -27,7 +28,9 @@ TEST_P(QueryNetworkTest, queryNetworkResultContainAllAndOnlyInputLayers) {
     ASSERT_NE(nullptr, cnnNetwork.getFunction());
     std::set<std::string> expectedLayers;
     for (auto&& node : function->get_ops()) {
-        if (!node->is_parameter() && !node->is_constant() && !node->is_output()) {
+        if (!ngraph::op::is_parameter(node) &&
+            !ngraph::op::is_constant(node) &&
+            !ngraph::op::is_output(node)) {
             expectedLayers.insert(node->get_friendly_name());
         }
     }
@@ -37,4 +40,4 @@ TEST_P(QueryNetworkTest, queryNetworkResultContainAllAndOnlyInputLayers) {
     }
     ASSERT_EQ(expectedLayers, actualLayers);
 }
-} // namespace HeteroTests
+} // namespace HeteroTests
|
@ -4,6 +4,7 @@
|
||||
//
|
||||
|
||||
#include "hetero/synthetic.hpp"
|
||||
#include <ngraph/op/util/op_types.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
@ -22,7 +23,9 @@ std::vector<FunctionParameter> HeteroSyntheticTest::_singleMajorNodeFunctions{[]
|
||||
for (auto&& builder : builders) {
|
||||
auto function = builder();
|
||||
for (auto&& node : function->get_ordered_ops()) {
|
||||
if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output())) {
|
||||
if (!ngraph::op::is_constant(node) &&
|
||||
!(ngraph::op::is_parameter(node)) &&
|
||||
!(ngraph::op::is_output(node))) {
|
||||
result.push_back(FunctionParameter{{node->get_friendly_name()}, function});
|
||||
}
|
||||
}
|
||||
@ -41,7 +44,9 @@ std::vector<FunctionParameter> HeteroSyntheticTest::_randomMajorNodeFunctions{[]
|
||||
for (std::size_t i = 0; i < ordered_ops.size(); ++i) {
|
||||
std::unordered_set<std::string> majorPluginNodeIds;
|
||||
for (auto&& node : ordered_ops) {
|
||||
if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output()) && d(e)) {
|
||||
if (!(ngraph::op::is_constant(node)) &&
|
||||
!(ngraph::op::is_parameter(node)) &&
|
||||
!(ngraph::op::is_output(node)) && d(e)) {
|
||||
majorPluginNodeIds.emplace(node->get_friendly_name());
|
||||
}
|
||||
}
|
||||
@ -117,7 +122,9 @@ std::string HeteroSyntheticTest::SetUpAffinity() {
|
||||
auto& pluginParameters = std::get<Plugin>(param);
|
||||
affinities += "\n{\n";
|
||||
for (auto&& node : std::get<Function>(param)._function->get_ordered_ops()) {
|
||||
if (!(node->is_constant()) && !(node->is_parameter()) && !(node->is_output())) {
|
||||
if (!ngraph::op::is_constant(node) &&
|
||||
!(ngraph::op::is_parameter(node)) &&
|
||||
!(ngraph::op::is_output(node))) {
|
||||
std::string affinity;
|
||||
if (std::get<Function>(param)._majorPluginNodeIds.end() !=
|
||||
std::get<Function>(param)._majorPluginNodeIds.find(node->get_friendly_name())) {
|
||||
@ -140,4 +147,4 @@ TEST_P(HeteroSyntheticTest, someLayersToMajorPluginOthersToFallback) {
|
||||
ASSERT_NE(nullptr, cnnNetwork.getFunction());
|
||||
}
|
||||
|
||||
} // namespace HeteroTests
|
||||
} // namespace HeteroTests
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include <assert.h>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/op/util/op_types.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
|
||||
std::pair<bool, std::string> compare_functions(const std::shared_ptr<ngraph::Function> & f1, const std::shared_ptr<ngraph::Function> & f2) {
|
||||
@ -75,7 +76,7 @@ void check_rt_info(const std::shared_ptr<ngraph::Function> & f) {
|
||||
|
||||
std::ostringstream err_log;
|
||||
for (auto & op : f->get_ops()) {
|
||||
if (op->is_constant()) continue;
|
||||
if (ngraph::op::is_constant(op)) continue;
|
||||
|
||||
const auto & rt_info = op->get_rt_info();
|
||||
for (const auto & attr_name : attrs_to_check) {
|
||||
@ -94,4 +95,4 @@ void check_rt_info(const std::shared_ptr<ngraph::Function> & f) {
|
||||
void visualize_function(std::shared_ptr<ngraph::Function> f, const std::string & file_name) {
|
||||
std::vector<std::shared_ptr<ngraph::Function> > g{f};
|
||||
ngraph::pass::VisualizeTree(file_name).run_on_module(g);
|
||||
}
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/op/util/op_types.hpp>
|
||||
#include <ngraph/specialize_function.hpp>
|
||||
|
||||
#include <ngraph_functions/utils/ngraph_helpers.hpp>
|
||||
@ -126,7 +127,7 @@ std::shared_ptr<Function> foldFunction(const std::shared_ptr<Function> &function
|
||||
|
||||
const auto &foldedFunc = specialize_function(function, paramElementTypes, paramShapes, inBuffers, true, true);
|
||||
for (const auto &op : foldedFunc->get_ops()) {
|
||||
NGRAPH_CHECK(op->is_constant() || op->is_output() || op->is_parameter(),
|
||||
NGRAPH_CHECK(op::is_constant(op) || op::is_output(op) || op::is_parameter(op),
|
||||
"Function was not fully folded to constant state!\n",
|
||||
"At least one non constant node with type ", op->get_type_name(),
|
||||
" present in function.");
|
||||
@ -141,7 +142,7 @@ std::vector<std::vector<std::uint8_t>> getConstData(const std::shared_ptr<Functi
|
||||
const auto &output = function->output(i).get_node_shared_ptr();
|
||||
NGRAPH_CHECK(output->inputs().size() == 1);
|
||||
auto parrentNode = output->input_value(0).get_node_shared_ptr();
|
||||
NGRAPH_CHECK(parrentNode->is_constant(), "Function was not fully folded to constant state!\n",
|
||||
NGRAPH_CHECK(op::is_constant(parrentNode), "Function was not fully folded to constant state!\n",
|
||||
"Parent node of one of results is not constant and has type ", parrentNode->get_type_name());
|
||||
|
||||
const auto data = std::dynamic_pointer_cast<opset1::Constant>(parrentNode)->get_data_ptr<std::uint8_t>();
|
||||
|
@ -30,6 +30,7 @@
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
#include "node_factory.hpp"
|
||||
#include "tensor_iterator_builder.hpp"
|
||||
@ -55,7 +56,7 @@ namespace
|
||||
std::shared_ptr<ngraph::Node>(m_opset.create(op_type_name));
|
||||
|
||||
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
|
||||
NGRAPH_CHECK(!op_node->is_constant(),
|
||||
NGRAPH_CHECK(!ngraph::op::is_constant(op_node),
|
||||
"Currently NodeFactory doesn't support Constant node: ",
|
||||
op_type_name);
|
||||
|
||||
|
@@ -426,6 +426,8 @@ set (SRC
     op/util/binary_elementwise_logical.hpp
     op/util/broadcast_base.cpp
     op/util/broadcast_base.hpp
+    op/util/elementwise_args.cpp
+    op/util/elementwise_args.hpp
     op/util/embeddingbag_packed_base.cpp
     op/util/embeddingbag_packed_base.hpp
     op/util/embeddingbag_offsets_base.cpp
@@ -447,6 +449,8 @@ set (SRC
     op/util/unary_elementwise_arithmetic.cpp
     op/util/unary_elementwise_arithmetic.hpp
     op/util/variable.hpp
+    op/util/op_types.cpp
+    op/util/op_types.hpp
     ops.hpp
     opsets/opset.cpp
     partial_shape.cpp
@@ -559,8 +563,6 @@ set (SRC
     pattern/op/skip.hpp
     pattern/op/true.cpp
     pattern/op/true.hpp
-    placement.cpp
-    placement.hpp
     provenance.cpp
     provenance.hpp
     rank.hpp
@@ -31,3 +31,13 @@ namespace ngraph
     }
 } // namespace onnx_import
 } // namespace ngraph
+
+bool ngraph::op::is_null(const ngraph::Node* node)
+{
+    return dynamic_cast<const ngraph::onnx_import::NullNode*>(node) != nullptr;
+}
+
+bool ngraph::op::is_null(const std::shared_ptr<ngraph::Node>& node)
+{
+    return is_null(node.get());
+}
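The ONNX importer's `is_null` check becomes a free function as well (declared in `null_node.hpp`, defined above). A small usage sketch mirroring the call sites updated later in this diff; the helper itself is illustrative, not part of the change:

```cpp
#include <cstddef>
#include <memory>
#include <vector>

#include "ngraph/frontend/onnx_import/core/null_node.hpp"

// Return the i-th optional ONNX input if it was provided, otherwise a fallback.
// Missing optional inputs are represented by NullNode instances, which
// ngraph::op::is_null() now detects instead of the old Node::is_null() method.
std::shared_ptr<ngraph::Node> optional_input(
    const std::vector<std::shared_ptr<ngraph::Node>>& inputs,
    std::size_t i,
    const std::shared_ptr<ngraph::Node>& fallback)
{
    if (inputs.size() > i && !ngraph::op::is_null(inputs.at(i)))
    {
        return inputs.at(i);
    }
    return fallback;
}
```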
@@ -19,9 +19,17 @@
 #include <memory>

 #include "ngraph/node.hpp"
+#include "utils/onnx_importer_visibility.hpp"

 namespace ngraph
 {
+    namespace op
+    {
+        ONNX_IMPORTER_API
+        bool is_null(const ngraph::Node* node);
+        ONNX_IMPORTER_API
+        bool is_null(const std::shared_ptr<ngraph::Node>& node);
+    }
     namespace onnx_import
     {
         /// \brief Represents a missing optional input or output of an ONNX node
@@ -40,7 +48,6 @@ namespace ngraph
             const NodeTypeInfo& get_type_info() const override { return type_info; }
             NullNode() = default;

-            bool is_null() const final override { return true; }
             virtual std::shared_ptr<Node>
                 copy_with_new_args(const NodeVector& new_args) const override;
         };
|
@ -20,6 +20,7 @@
|
||||
#include "clip.hpp"
|
||||
#include "default_opset.hpp"
|
||||
#include "ngraph/builder/make_constant.hpp"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -57,7 +58,7 @@ namespace ngraph
|
||||
|
||||
// If second input is provided, assign to min input, otherwise set lowest
|
||||
// numeric limit of double as min input.
|
||||
if (inputs.size() > 1 && !inputs.at(1)->is_null())
|
||||
if (inputs.size() > 1 && !ngraph::op::is_null(inputs.at(1)))
|
||||
{
|
||||
min = inputs.at(1);
|
||||
}
|
||||
@ -69,7 +70,7 @@ namespace ngraph
|
||||
|
||||
// If third input is provided, assign to max input, otherwise set maximum
|
||||
// numeric limit of double as max input.
|
||||
if (inputs.size() == 3 && !inputs.at(2)->is_null())
|
||||
if (inputs.size() == 3 && !ngraph::op::is_null(inputs.at(2)))
|
||||
{
|
||||
max = inputs.at(2);
|
||||
}
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "dequantize_linear.hpp"
|
||||
#include "ngraph/axis_set.hpp"
|
||||
#include "ngraph/builder/make_constant.hpp"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
#include "ngraph/op/convert.hpp"
|
||||
#include "ngraph/op/dequantize.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
@ -37,7 +38,7 @@ namespace ngraph
|
||||
{
|
||||
std::shared_ptr<ngraph::Node> get_zero_point(const NodeVector& inputs)
|
||||
{
|
||||
if (inputs.size() == 3 && !inputs[2]->is_null())
|
||||
if (inputs.size() == 3 && !ngraph::op::is_null(inputs[2]))
|
||||
{
|
||||
auto zero_point = inputs[2];
|
||||
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "default_opset.hpp"
|
||||
#include "gru.hpp"
|
||||
#include "ngraph/builder/split.hpp"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "utils/recurrent.hpp"
|
||||
|
||||
@ -47,7 +48,7 @@ namespace ngraph
|
||||
const auto& ng_inputs = node.get_ng_inputs();
|
||||
const auto el_type = ng_inputs.at(0)->get_output_element_type(0);
|
||||
|
||||
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
|
||||
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
|
||||
{
|
||||
auto bias = ng_inputs.at(3);
|
||||
// gates_count * 2 since B is: [Wb, Rb]
|
||||
|
@ -21,7 +21,9 @@
|
||||
#include "core/graph.hpp"
|
||||
#include "default_opset.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "utils/reshape.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -52,7 +54,7 @@ namespace ngraph
|
||||
const std::shared_ptr<ngraph::Node>& body_cond)
|
||||
{
|
||||
bool loop_cond_value = false;
|
||||
if (loop_cond->is_constant() &&
|
||||
if (ngraph::op::is_constant(loop_cond) &&
|
||||
loop_cond->get_element_type() == element::boolean)
|
||||
{
|
||||
loop_cond_value = as_type_ptr<default_opset::Constant>(loop_cond)
|
||||
@ -61,7 +63,8 @@ namespace ngraph
|
||||
}
|
||||
// According to ONNX skipped cond input (is_null) means
|
||||
// that is has true value
|
||||
bool is_loop_cond_true = loop_cond->is_null() || loop_cond_value == true;
|
||||
bool is_loop_cond_true =
|
||||
ngraph::op::is_null(loop_cond) || loop_cond_value == true;
|
||||
|
||||
if (!is_loop_cond_true)
|
||||
{
|
||||
@ -76,7 +79,7 @@ namespace ngraph
|
||||
{
|
||||
const auto second_input =
|
||||
body_cond->input_value(1).get_node_shared_ptr();
|
||||
if (second_input->is_constant() &&
|
||||
if (ngraph::op::is_constant(second_input) &&
|
||||
second_input->get_element_type() == element::boolean &&
|
||||
as_type_ptr<default_opset::Constant>(second_input)
|
||||
->cast_vector<bool>()
|
||||
@ -99,7 +102,7 @@ namespace ngraph
|
||||
// At this moment nGraph TensorIterator doesn't have support for conditional
|
||||
// termination of iterations.
|
||||
CHECK_VALID_NODE(node,
|
||||
!trip_count->is_null(),
|
||||
!ngraph::op::is_null(trip_count),
|
||||
"Currently nGraph requires trip count input to be provided.");
|
||||
|
||||
const OutputVector loop_carried_dependencies{std::next(ng_inputs.begin(), 2),
|
||||
|
@ -27,6 +27,7 @@
|
||||
#include "ngraph/builder/reshape.hpp"
|
||||
#include "ngraph/builder/split.hpp"
|
||||
#include "ngraph/enum_names.hpp"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
@ -91,7 +92,7 @@ namespace ngraph
|
||||
|
||||
// ------ Optional inputs ------
|
||||
// The bias tensor for input gate. Shape [num_directions, 4*hidden_size]
|
||||
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
|
||||
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
|
||||
{
|
||||
auto bias = ng_inputs.at(3);
|
||||
auto split_bias = builder::opset1::split(bias, 2, 1);
|
||||
@ -106,7 +107,7 @@ namespace ngraph
|
||||
0.f));
|
||||
}
|
||||
// The lengths of the sequences in a batch. Shape [batch_size]
|
||||
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
|
||||
if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
|
||||
{
|
||||
m_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] = ng_inputs.at(4);
|
||||
}
|
||||
@ -122,7 +123,7 @@ namespace ngraph
|
||||
}
|
||||
// The initial value of the hidden.
|
||||
// Shape [num_directions, batch_size, hidden_size]
|
||||
if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
|
||||
if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
|
||||
{
|
||||
m_map[LSTMInput::LSTM_INPUT_INIT_H] =
|
||||
builder::opset1::reorder_axes(ng_inputs.at(5), {1, 0, 2});
|
||||
@ -136,7 +137,7 @@ namespace ngraph
|
||||
}
|
||||
// The initial value of the cell.
|
||||
// Shape [num_directions, batch_size, hidden_size]
|
||||
if (ng_inputs.size() > 6 && !ng_inputs.at(6)->is_null())
|
||||
if (ng_inputs.size() > 6 && !ngraph::op::is_null(ng_inputs.at(6)))
|
||||
{
|
||||
m_map[LSTMInput::LSTM_INPUT_INIT_C] =
|
||||
builder::opset1::reorder_axes(ng_inputs.at(6), {1, 0, 2});
|
||||
@ -149,7 +150,7 @@ namespace ngraph
|
||||
std::vector<float>(batch_size * num_directions * hidden_size, 0.f));
|
||||
}
|
||||
// The weight tensor for peepholes. Shape [num_directions, 3*hidde_size]
|
||||
if (ng_inputs.size() > 7 && !ng_inputs.at(7)->is_null())
|
||||
if (ng_inputs.size() > 7 && !ngraph::op::is_null(ng_inputs.at(7)))
|
||||
{
|
||||
m_map[LSTMInput::LSTM_INPUT_P] = ng_inputs.at(7);
|
||||
}
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/convert.hpp"
|
||||
#include "ngraph/op/pad.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "pad.hpp"
|
||||
#include "utils/convpool.hpp"
|
||||
@ -112,7 +113,7 @@ namespace ngraph
|
||||
data->get_element_type(), ngraph::Shape{}, {0});
|
||||
}
|
||||
|
||||
if (pads->is_constant())
|
||||
if (ngraph::op::is_constant(pads))
|
||||
{
|
||||
std::vector<std::int64_t> pads_vector =
|
||||
ngraph::as_type_ptr<default_opset::Constant>(pads)
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "resize.hpp"
|
||||
#include "default_opset.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -75,7 +76,7 @@ namespace ngraph
|
||||
attrs.mode = mode;
|
||||
attrs.align_corners = false;
|
||||
|
||||
if (scales->is_constant() && data_shape.is_static())
|
||||
if (ngraph::op::is_constant(scales) && data_shape.is_static())
|
||||
{
|
||||
const auto scales_const =
|
||||
as_type_ptr<default_opset::Constant>(scales->shared_from_this());
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "gather.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "utils/common.hpp"
|
||||
|
||||
namespace
|
||||
@ -190,7 +191,8 @@ namespace ngraph
|
||||
if (inputs.size() >= 4) // axes input provided
|
||||
{
|
||||
axes = inputs.at(3);
|
||||
CHECK_VALID_NODE(node, axes->is_constant(), "Axes input must be constant");
|
||||
CHECK_VALID_NODE(
|
||||
node, ngraph::op::is_constant(axes), "Axes input must be constant");
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
#include "default_opset.hpp"
|
||||
#include "exceptions.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "upsample.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -137,7 +138,7 @@ namespace ngraph
|
||||
attrs.axes.insert(ax);
|
||||
}
|
||||
|
||||
if (scales->is_constant() && data_shape.is_static())
|
||||
if (ngraph::op::is_constant(scales) && data_shape.is_static())
|
||||
{
|
||||
const auto scales_const =
|
||||
as_type_ptr<default_opset::Constant>(scales->shared_from_this());
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "ngraph/builder/split.hpp"
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/enum_names.hpp"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
#include "recurrent.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -60,7 +61,7 @@ namespace ngraph
|
||||
const std::size_t batch_size = m_map[OpInput::X]->get_shape().at(1);
|
||||
const std::size_t num_directions = m_map[OpInput::W]->get_shape().front();
|
||||
|
||||
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
|
||||
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
|
||||
{
|
||||
auto bias = ng_inputs.at(3);
|
||||
auto split_bias = builder::opset1::split(bias, 2, 1);
|
||||
@ -71,7 +72,7 @@ namespace ngraph
|
||||
m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
|
||||
el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
|
||||
}
|
||||
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
|
||||
if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
|
||||
{
|
||||
m_map[OpInput::SEQ_LENGTHS] = ng_inputs.at(4);
|
||||
}
|
||||
@ -81,7 +82,7 @@ namespace ngraph
|
||||
element::i32, Shape{batch_size}, m_map[OpInput::X]->get_shape().at(0));
|
||||
}
|
||||
// The initial value of the hidden.
|
||||
if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
|
||||
if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
|
||||
{
|
||||
m_map[OpInput::INIT_H] = ng_inputs.at(5);
|
||||
}
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "default_opset.hpp"
|
||||
#include "ngraph/builder/make_constant.hpp"
|
||||
#include "ngraph/builder/reshape.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "reshape.hpp"
|
||||
|
||||
@ -102,7 +103,7 @@ namespace ngraph
|
||||
node_shape);
|
||||
|
||||
// If node is a Constant, recreate as Constant with Shape{}
|
||||
if (node->is_constant())
|
||||
if (ngraph::op::is_constant(node))
|
||||
{
|
||||
const auto value =
|
||||
ngraph::as_type_ptr<default_opset::Constant>(node)->get_data_ptr();
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -83,7 +84,7 @@ void Function::validate_nodes_and_infer_types()
|
||||
node->revalidate_and_infer_types();
|
||||
|
||||
// If we find a parameter make sure it is in the list of parameters of the function
|
||||
if (node->is_parameter())
|
||||
if (op::is_parameter(node))
|
||||
{
|
||||
auto it = std::find(m_parameters.begin(), m_parameters.end(), node);
|
||||
if (it == m_parameters.end())
|
||||
|
@ -29,6 +29,7 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/result.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pass/visualize_tree.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
@ -130,7 +131,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
|
||||
std::shared_ptr<Node> replacement,
|
||||
const std::vector<int64_t>& output_order)
|
||||
{
|
||||
if (target->is_output())
|
||||
if (ngraph::op::is_output(target))
|
||||
{
|
||||
throw ngraph_error("Result nodes cannot be replaced.");
|
||||
}
|
||||
@ -185,7 +186,7 @@ void ngraph::replace_node(std::shared_ptr<Node> target,
|
||||
void ngraph::replace_node(const std::shared_ptr<Node>& target,
|
||||
const OutputVector& replacement_values)
|
||||
{
|
||||
if (target->is_output())
|
||||
if (ngraph::op::is_output(target))
|
||||
{
|
||||
throw ngraph_error("Result nodes cannot be replaced.");
|
||||
}
|
||||
@ -258,7 +259,7 @@ bool ngraph::is_post_dominated(Node* X, Node* Y)
|
||||
{
|
||||
ngraph::Node* curr = stack.top();
|
||||
visited.insert(curr);
|
||||
if (curr->is_output())
|
||||
if (ngraph::op::is_output(curr))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@ -465,7 +466,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
|
||||
// Make parameter node
|
||||
shared_ptr<op::Parameter> par_node = make_shared<op::Parameter>(
|
||||
src_node->get_output_element_type(0), src_node->get_output_shape(0));
|
||||
par_node->set_placement(dst_node->get_placement());
|
||||
|
||||
// Fix input / output among src, dst and par
|
||||
std::vector<Input<Node>> dst_inputs = get_inputs_from(*src_node, *dst_node);
|
||||
@ -489,7 +489,6 @@ pair<shared_ptr<op::Result>, shared_ptr<op::Parameter>>
|
||||
// Add res node
|
||||
// Add [4], [5], [6], [7]
|
||||
shared_ptr<op::Result> res_node = make_shared<op::Result>(src_node);
|
||||
res_node->set_placement(src_node->get_placement());
|
||||
|
||||
return make_pair(res_node, par_node);
|
||||
}
|
||||
@ -641,7 +640,7 @@ bool ngraph::is_used(Node* node)
|
||||
ngraph::Node* n = stack.top();
|
||||
if (instances_seen.count(n) == 0)
|
||||
{
|
||||
if (n->is_output())
|
||||
if (ngraph::op::is_output(n))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
@ -675,7 +674,7 @@ bool ngraph::possibly_overwritten(Node* node)
|
||||
{
|
||||
for (auto& input : output.get_target_inputs())
|
||||
{
|
||||
if (input.get_node()->is_op())
|
||||
if (op::is_op(input.get_node()))
|
||||
{
|
||||
auto op = static_cast<ngraph::op::Op*>(input.get_node());
|
||||
if (auto op_annotations = op->get_op_annotations())
|
||||
@ -714,7 +713,7 @@ bool ngraph::is_valid_rank(const std::shared_ptr<Node>& node, std::vector<size_t
|
||||
|
||||
bool ngraph::compare_constants(const std::shared_ptr<Node>& n1, const std::shared_ptr<Node>& n2)
|
||||
{
|
||||
if (!(n1->is_constant() && n2->is_constant()))
|
||||
if (!(op::is_constant(n1) && op::is_constant(n2)))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
@ -29,7 +29,6 @@
|
||||
#include "ngraph/check.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/placement.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -406,10 +405,6 @@ namespace ngraph
|
||||
NGRAPH_API
|
||||
std::shared_ptr<ngraph::Function> clone_function(const ngraph::Function& func);
|
||||
|
||||
// Assert that nodes in the function is colocated and return that placement
|
||||
NGRAPH_API
|
||||
Placement get_colocated_function_placement(std::shared_ptr<Function> func);
|
||||
|
||||
NGRAPH_API
|
||||
std::pair<std::shared_ptr<op::Result>, std::shared_ptr<op::v0::Parameter>>
|
||||
insert_result_parameter_split(const std::shared_ptr<Node>& src_node,
|
||||
|
@ -28,7 +28,6 @@
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/result.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
#include "ngraph/placement.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -298,16 +297,6 @@ const std::deque<descriptor::Output>& Node::get_outputs() const
|
||||
return m_outputs;
|
||||
}
|
||||
|
||||
bool Node::is_output() const
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Node::is_constant() const
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::string& Node::description() const
|
||||
{
|
||||
// Terrible transitional kludge to keep description working while we change
|
||||
@ -339,16 +328,6 @@ void Node::set_friendly_name(const string& name)
|
||||
m_friendly_name = name;
|
||||
}
|
||||
|
||||
Placement Node::get_placement() const
|
||||
{
|
||||
return m_placement;
|
||||
}
|
||||
|
||||
void Node::set_placement(Placement placement)
|
||||
{
|
||||
m_placement = placement;
|
||||
}
|
||||
|
||||
void Node::add_provenance_group_member(const shared_ptr<Node>& node)
|
||||
{
|
||||
m_provenance_group.insert(node);
|
||||
@ -865,76 +844,6 @@ ResultVector ngraph::as_result_vector(const OutputVector& values)
|
||||
return result;
|
||||
}
|
||||
|
||||
std::tuple<element::Type, PartialShape>
|
||||
Node::validate_and_infer_elementwise_args(const op::AutoBroadcastSpec& autob)
|
||||
{
|
||||
element::Type element_type = get_input_element_type(0);
|
||||
PartialShape pshape = get_input_partial_shape(0);
|
||||
|
||||
if (get_input_size() > 1)
|
||||
{
|
||||
for (size_t i = 1; i < get_input_size(); ++i)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
element::Type::merge(element_type, element_type, get_input_element_type(i)),
|
||||
"Argument element types are inconsistent.");
|
||||
|
||||
if (autob.m_type == op::AutoBroadcastType::NONE)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
PartialShape::merge_into(pshape, get_input_partial_shape(i)),
|
||||
"Argument shapes are inconsistent.");
|
||||
}
|
||||
else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
|
||||
autob.m_type == op::AutoBroadcastType::PDPD)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
PartialShape::broadcast_merge_into(pshape, get_input_partial_shape(i), autob),
|
||||
"Argument shapes are inconsistent.");
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(this, false, "Unsupported auto broadcast specification");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(element_type, pshape);
|
||||
}
|
||||
|
||||
void Node::validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob)
|
||||
{
|
||||
auto args_et_pshape = validate_and_infer_elementwise_args(autob);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
args_et.is_dynamic() || args_et != element::boolean,
|
||||
"Arguments cannot have boolean element type (argument element type: ",
|
||||
args_et,
|
||||
").");
|
||||
|
||||
set_output_type(0, args_et, args_pshape);
|
||||
}
|
||||
|
||||
void Node::validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& autob)
|
||||
{
|
||||
auto args_et_pshape = validate_and_infer_elementwise_args(autob);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
args_et.is_dynamic() || args_et == element::boolean,
|
||||
"Operands for logical operators must have boolean element type but have element type ",
|
||||
args_et,
|
||||
".");
|
||||
|
||||
set_output_type(0, element::boolean, args_pshape);
|
||||
}
|
||||
|
||||
bool Node::match_value(pattern::Matcher* matcher,
|
||||
const Output<Node>& pattern_value,
|
||||
const Output<Node>& graph_value)
|
||||
|
@ -42,7 +42,6 @@
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/op_annotations.hpp"
|
||||
#include "ngraph/output_vector.hpp"
|
||||
#include "ngraph/placement.hpp"
|
||||
#include "ngraph/strides.hpp"
|
||||
#include "ngraph/type.hpp"
|
||||
|
||||
@ -154,13 +153,6 @@ namespace ngraph
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
|
||||
protected:
|
||||
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
|
||||
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
|
||||
void validate_and_infer_elementwise_arithmetic(
|
||||
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
|
||||
void validate_and_infer_elementwise_logical(
|
||||
const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
|
||||
|
||||
/// \brief Construct an unitialized Node
|
||||
Node() {}
|
||||
/// \brief Construct an unitialized Node
|
||||
@ -176,19 +168,14 @@ namespace ngraph
|
||||
void safe_delete(NodeVector& nodes, bool recurse);
|
||||
|
||||
public:
|
||||
virtual bool is_parameter() const { return false; }
|
||||
virtual bool is_output() const { return false; }
|
||||
virtual bool is_constant() const { return false; }
|
||||
virtual ~Node();
|
||||
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor) { return false; }
|
||||
virtual bool is_unary_elementwise_arithmetic() const { return false; }
|
||||
virtual bool is_binary_elementwise_arithmetic() const { return false; }
|
||||
virtual bool is_binary_elementwise_comparison() const { return false; }
|
||||
virtual bool is_binary_elementwise_logical() const { return false; }
|
||||
/// \returns true if node supports autobroadcast operations
|
||||
virtual bool supports_auto_broadcast() const { return false; }
|
||||
/// \returns the autobroadcasr spec
|
||||
virtual const op::AutoBroadcastSpec& get_autob() const;
|
||||
/// \returns true if the node can decompose
|
||||
virtual bool supports_decompose() const { return false; }
|
||||
/// \brief Evaluates the op on input_values putting results in output_values
|
||||
/// \returns true if successful
|
||||
virtual bool evaluate(const HostTensorVector& output_values,
|
||||
@ -276,15 +263,7 @@ namespace ngraph
|
||||
const element::Type& element_type,
|
||||
const PartialShape& pshape);
|
||||
|
||||
virtual bool is_parameter() const { return false; }
|
||||
virtual bool is_output() const;
|
||||
virtual bool is_constant() const;
|
||||
virtual bool is_null() const { return false; }
|
||||
virtual bool is_op() const { return false; }
|
||||
virtual bool is_pattern() const { return false; }
|
||||
virtual bool is_commutative() const { return false; }
|
||||
virtual bool is_dynamic() const;
|
||||
virtual bool has_state() const { return false; }
|
||||
size_t get_instance_id() const { return m_instance_id; }
|
||||
/// \brief Writes a description of a node to a stream
|
||||
/// \param os The stream; should be returned
|
||||
@ -440,12 +419,6 @@ namespace ngraph
|
||||
/// True if this and node have one output with same element type and shape
|
||||
bool has_same_type(std::shared_ptr<const Node> node) const;
|
||||
|
||||
/// Get device placement
|
||||
Placement get_placement() const;
|
||||
|
||||
/// Set device placement
|
||||
void set_placement(Placement placement);
|
||||
|
||||
using RTMap = std::map<std::string, std::shared_ptr<Variant>>;
|
||||
|
||||
RTMap& get_rt_info() { return m_rt_info; }
|
||||
@ -557,7 +530,6 @@ namespace ngraph
|
||||
std::set<std::shared_ptr<Node>> m_provenance_group;
|
||||
std::deque<descriptor::Input> m_inputs;
|
||||
std::deque<descriptor::Output> m_outputs;
|
||||
Placement m_placement = Placement::DEFAULT;
|
||||
std::shared_ptr<ngraph::op::util::OpAnnotations> m_op_annotations;
|
||||
std::map<std::string, std::shared_ptr<Variant>> m_rt_info;
|
||||
};
|
||||
|
@ -57,7 +57,6 @@ namespace ngraph
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -97,7 +96,6 @@ namespace ngraph
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
size_t get_version() const override { return 1; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
|
@ -54,7 +54,6 @@ namespace ngraph
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -40,6 +40,7 @@ namespace ngraph
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"Constant", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
bool is_constant() const override { return true; }
|
||||
Constant() = default;
|
||||
|
||||
/// \brief Initialize a constant from tensor
|
||||
@ -442,7 +443,6 @@ namespace ngraph
|
||||
get_data_ptr());
|
||||
}
|
||||
|
||||
bool is_constant() const override { return true; }
|
||||
bool get_all_data_elements_bitwise_identical() const
|
||||
{
|
||||
return m_all_elements_bitwise_identical;
|
||||
|
@ -18,6 +18,7 @@
|
||||
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/crop_and_resize.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -94,7 +95,8 @@ void op::CropAndResize::validate_and_infer_types()
|
||||
auto& crop_size_et = crop_size.get_element_type();
|
||||
NODE_VALIDATION_CHECK(this, crop_size_et.is_integral(), "crops_size must be integral");
|
||||
auto crop_size_node = crop_size.get_node_shared_ptr();
|
||||
NODE_VALIDATION_CHECK(this, crop_size_node->is_constant(), "crop_size must be a constant");
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, ngraph::op::is_constant(crop_size_node), "crop_size must be a constant");
|
||||
auto crop_size_const = static_pointer_cast<op::Constant>(crop_size_node);
|
||||
if (crop_size_et == element::i8)
|
||||
{
|
||||
|
@ -63,7 +63,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -111,7 +110,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/fused/batch_to_space.hpp"
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -152,13 +153,13 @@ void ngraph::op::v1::BatchToSpace::pre_validate_and_infer_types()
|
||||
auto block = input_value(1);
|
||||
auto crops_begin = input_value(2);
|
||||
auto crops_end = input_value(3);
|
||||
NGRAPH_CHECK(block.get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(ngraph::op::is_constant(block.get_node()),
|
||||
"block_shape input node is expected to be a static constant");
|
||||
|
||||
NGRAPH_CHECK(crops_begin.get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(ngraph::op::is_constant(crops_begin.get_node()),
|
||||
"crops_begin input node is expected to be a static constant");
|
||||
|
||||
NGRAPH_CHECK(crops_end.get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(ngraph::op::is_constant(crops_end.get_node()),
|
||||
"crops_end input node is expected to be a static constant");
|
||||
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "ngraph/op/divide.hpp"
|
||||
#include "ngraph/op/fused/normalize_l2.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -55,7 +56,7 @@ void op::NormalizeL2::pre_validate_and_infer_types()
|
||||
const auto& input_rank = input_pshape.rank();
|
||||
const auto& axes_rank = axes_pshape.rank();
|
||||
|
||||
NODE_VALIDATION_CHECK(this, axes_node->is_constant(), "Input axes must be Constant type");
|
||||
NODE_VALIDATION_CHECK(this, op::is_constant(axes_node), "Input axes must be Constant type");
|
||||
|
||||
if (axes_rank.is_static())
|
||||
{
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/fused/space_to_batch.hpp"
|
||||
#include "ngraph/op/pad.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -135,13 +136,13 @@ void ngraph::op::v1::SpaceToBatch::pre_validate_and_infer_types()
|
||||
auto block = input_value(1);
|
||||
auto crops_begin = input_value(2);
|
||||
auto crops_end = input_value(3);
|
||||
NGRAPH_CHECK(block.get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(ngraph::op::is_constant(block.get_node()),
|
||||
"block_shape input node is expected to be a static constant");
|
||||
|
||||
NGRAPH_CHECK(crops_begin.get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(ngraph::op::is_constant(crops_begin.get_node()),
|
||||
"crops_begin input node is expected to be a static constant");
|
||||
|
||||
NGRAPH_CHECK(crops_end.get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(ngraph::op::is_constant(crops_end.get_node()),
|
||||
"crops_end input node is expected to be a static constant");
|
||||
|
||||
const auto& data_type = get_input_element_type(0);
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/fused/unsqueeze.hpp"
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/runtime/reference/copy.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
@ -42,7 +43,7 @@ void op::Unsqueeze::pre_validate_and_infer_types()
|
||||
|
||||
const auto axes_node = input_value(1).get_node_shared_ptr();
|
||||
|
||||
if (data_rank.is_dynamic() || !axes_node->is_constant())
|
||||
if (data_rank.is_dynamic() || !op::is_constant(axes_node))
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
|
@ -47,7 +47,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -80,7 +79,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -47,7 +47,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -80,7 +79,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -47,7 +47,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -80,7 +79,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "ngraph/op/non_max_suppression.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -170,7 +171,7 @@ void op::v1::NonMaxSuppression::validate_and_infer_types()
|
||||
|
||||
const auto max_output_boxes_per_class = input_value(2).get_node_shared_ptr();
|
||||
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
|
||||
max_output_boxes_per_class->is_constant())
|
||||
op::is_constant(max_output_boxes_per_class))
|
||||
{
|
||||
const auto num_boxes = num_boxes_boxes.get_length();
|
||||
const auto max_output_boxes_per_class = max_boxes_output_from_input();
|
||||
@ -384,7 +385,7 @@ void op::v3::NonMaxSuppression::validate_and_infer_types()
|
||||
const auto num_boxes_boxes = boxes_ps[1];
|
||||
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
|
||||
if (num_boxes_boxes.is_static() && scores_ps[1].is_static() &&
|
||||
max_output_boxes_per_class_node->is_constant())
|
||||
op::is_constant(max_output_boxes_per_class_node))
|
||||
{
|
||||
const auto num_boxes = num_boxes_boxes.get_length();
|
||||
const auto num_classes = scores_ps[1].get_length();
|
||||
@ -517,7 +518,7 @@ void op::v4::NonMaxSuppression::validate_and_infer_types()
|
||||
const auto num_boxes_boxes = boxes_ps[1];
|
||||
const auto max_output_boxes_per_class_node = input_value(2).get_node_shared_ptr();
|
||||
if (num_boxes_boxes.is_static() && scores_ps[0].is_static() && scores_ps[1].is_static() &&
|
||||
max_output_boxes_per_class_node->is_constant())
|
||||
op::is_constant(max_output_boxes_per_class_node))
|
||||
{
|
||||
const auto num_boxes = num_boxes_boxes.get_length();
|
||||
const auto num_classes = scores_ps[1].get_length();
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/not.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/elementwise_args.hpp"
|
||||
|
||||
#include "ngraph/runtime/host_tensor.hpp"
|
||||
#include "ngraph/runtime/reference/not.hpp"
|
||||
@ -39,7 +40,7 @@ bool ngraph::op::v1::LogicalNot::visit_attributes(AttributeVisitor& visitor)
|
||||
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
|
||||
void op::v1::LogicalNot::validate_and_infer_types()
|
||||
{
|
||||
auto args_et_pshape = validate_and_infer_elementwise_args();
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
@ -106,7 +107,7 @@ op::v0::Not::Not(const Output<Node>& arg)
|
||||
// TODO(amprocte): Update this to allow only boolean, for consistency with logical binops.
|
||||
void op::v0::Not::validate_and_infer_types()
|
||||
{
|
||||
auto args_et_pshape = validate_and_infer_elementwise_args();
|
||||
auto args_et_pshape = ngraph::op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
|
@ -47,7 +47,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -79,7 +78,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/one_hot.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -152,7 +153,7 @@ void op::v1::OneHot::validate_and_infer_types()
|
||||
const auto& depth = input_value(1).get_node_shared_ptr();
|
||||
PartialShape result_shape{PartialShape::dynamic()};
|
||||
|
||||
if (indices_shape.is_static() && indices_shape.rank().is_static() && depth->is_constant())
|
||||
if (indices_shape.is_static() && indices_shape.rank().is_static() && op::is_constant(depth))
|
||||
{
|
||||
const auto indices_rank = indices_shape.rank().get_length();
|
||||
|
||||
|
@ -27,8 +27,6 @@ namespace ngraph
|
||||
/// Root of all actual ops
|
||||
class NGRAPH_API Op : public Node
|
||||
{
|
||||
public:
|
||||
virtual bool is_op() const override { return true; }
|
||||
protected:
|
||||
Op()
|
||||
: Node()
|
||||
|
@ -52,7 +52,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
@ -84,7 +83,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/op/broadcast.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -271,8 +272,8 @@ void op::v1::Pad::validate_and_infer_types()
|
||||
|
||||
auto pads_begin_node = input_value(1).get_node_shared_ptr();
|
||||
auto pads_end_node = input_value(2).get_node_shared_ptr();
|
||||
if (arg_shape_rank.is_static() && pads_begin_node->is_constant() &&
|
||||
pads_end_node->is_constant())
|
||||
if (arg_shape_rank.is_static() && op::is_constant(pads_begin_node) &&
|
||||
op::is_constant(pads_end_node))
|
||||
{
|
||||
const auto implied_rank = pads_begin_coord.size();
|
||||
std::vector<Dimension> result_dims(implied_rank, Dimension::dynamic());
|
||||
|
@ -48,7 +48,6 @@ namespace ngraph
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
bool is_parameter() const override { return true; }
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool get_cacheable() const { return m_cacheable; }
|
||||
@ -69,7 +68,7 @@ namespace ngraph
|
||||
{
|
||||
m_element_type = element_type;
|
||||
}
|
||||
|
||||
bool is_parameter() const override { return true; }
|
||||
protected:
|
||||
bool m_cacheable;
|
||||
PartialShape m_partial_shape;
|
||||
|
@ -44,14 +44,13 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_output() const override { return true; }
|
||||
void set_needs_default_layout(bool val) { m_needs_default_layout = val; }
|
||||
bool needs_default_layout() const { return m_needs_default_layout; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
bool is_output() const override { return true; }
|
||||
private:
|
||||
bool m_needs_default_layout{false};
|
||||
};
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/reverse.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -133,7 +134,7 @@ void op::v1::Reverse::validate_and_infer_types()
|
||||
const auto rank = input_rank.get_length();
|
||||
const auto rev_axes_node = input_value(1).get_node_shared_ptr();
|
||||
|
||||
if (rev_axes_node->is_constant())
|
||||
if (op::is_constant(rev_axes_node))
|
||||
{
|
||||
const auto rev_axes_constant = as_type_ptr<op::Constant>(rev_axes_node);
|
||||
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/scatter_elements_update.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_elements_update.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
@ -91,7 +92,7 @@ void op::v3::ScatterElementsUpdate::validate_and_infer_types()
|
||||
" and: ",
|
||||
updates_shape);
|
||||
|
||||
if (input_value(3).get_node_shared_ptr()->is_constant() && data_shape.rank().is_static())
|
||||
if (ngraph::op::is_constant(input_value(3).get_node()) && data_shape.rank().is_static())
|
||||
{
|
||||
const auto axis_input = as_type_ptr<op::v0::Constant>(input_value(3).get_node_shared_ptr());
|
||||
auto axis = axis_input->cast_vector<int64_t>().at(0);
|
||||
|
@ -117,7 +117,6 @@ namespace ngraph
|
||||
{
|
||||
m_auto_broadcast = auto_broadcast;
|
||||
}
|
||||
bool supports_auto_broadcast() const override { return true; }
|
||||
// TODO: Move all uses of get_autob to get_auto_broadcast() and remove this.
|
||||
const AutoBroadcastSpec& get_autob() const override { return m_auto_broadcast; }
|
||||
private:
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/runtime/reference/softmax.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
|
||||
@ -53,7 +54,7 @@ op::v0::Softmax::Softmax(const Output<Node>& arg, const Output<Node>& axes)
|
||||
|
||||
bool op::v0::Softmax::are_axes_constant() const
|
||||
{
|
||||
return input_value(1).get_node_shared_ptr()->is_constant();
|
||||
return op::is_constant(input_value(1).get_node());
|
||||
}
|
||||
|
||||
const AxisSet op::v0::Softmax::get_axes() const
|
||||
|
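Note: an illustrative aside, not part of this commit. The new op::is_constant overloads accept a raw Node*, so call sites can pass Output<Node>::get_node() instead of get_node_shared_ptr() and avoid a shared_ptr copy per check; the helper name axes_are_constant below is hypothetical.

    #include "ngraph/node.hpp"
    #include "ngraph/op/util/op_types.hpp"

    // Both calls are equivalent; the first skips bumping the node's reference count.
    bool axes_are_constant(const ngraph::Output<ngraph::Node>& axes)
    {
        bool via_raw_ptr = ngraph::op::is_constant(axes.get_node());
        bool via_shared  = ngraph::op::is_constant(axes.get_node_shared_ptr());
        return via_raw_ptr && via_shared;
    }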
@ -19,6 +19,7 @@
|
||||
#include "ngraph/builder/split.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/split.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -51,7 +52,8 @@ void op::v0::Split::pre_validate_and_infer_types()
|
||||
NODE_VALIDATION_CHECK(this, is_scalar(axis_shape), "The 'axis' input node must be scalar");
|
||||
|
||||
const auto axis_node = input_value(1).get_node_shared_ptr();
|
||||
NODE_VALIDATION_CHECK(this, axis_node->is_constant(), "The 'axis' input node must be constant");
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, op::is_constant(axis_node), "The 'axis' input node must be constant");
|
||||
const auto axis_node_const = as_type_ptr<op::Constant>(axis_node);
|
||||
m_axis = axis_node_const->get_data_ptr<int64_t>()[0];
|
||||
|
||||
@ -142,7 +144,7 @@ void op::v1::Split::validate_and_infer_types()
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, axis_et.is_integral(), "The 'axis' input only accepts integral types");
|
||||
|
||||
if (input_value(1).get_node_shared_ptr()->is_constant() && data_ps.is_static())
|
||||
if (op::is_constant(input_value(1).get_node()) && data_ps.is_static())
|
||||
{
|
||||
const auto axis_input = as_type_ptr<op::Constant>(input_value(1).get_node_shared_ptr());
|
||||
auto axis = axis_input->cast_vector<int64_t>()[0];
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "ngraph/axis_vector.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/topk.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
@ -482,7 +483,7 @@ void op::v1::TopK::validate_and_infer_types()
|
||||
this, k_partial_shape.rank().compatible(0), "The 'K' input must be a scalar.");
|
||||
|
||||
size_t k = 0;
|
||||
if (input_value(1).get_node_shared_ptr()->is_constant())
|
||||
if (op::is_constant(input_value(1).get_node()))
|
||||
{
|
||||
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
|
||||
get_input_element_type(1));
|
||||
@ -638,7 +639,7 @@ shared_ptr<Node> op::v1::TopK::clone_with_new_inputs(const OutputVector& new_arg
|
||||
size_t op::v1::TopK::get_k() const
|
||||
{
|
||||
size_t k = 0;
|
||||
if (input_value(1).get_node_shared_ptr()->is_constant())
|
||||
if (op::is_constant(input_value(1).get_node()))
|
||||
{
|
||||
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
|
||||
get_input_element_type(1));
|
||||
@ -668,7 +669,7 @@ bool op::v1::TopK::evaluate(const HostTensorVector& outputs, const HostTensorVec
|
||||
|
||||
// 2. get value of k - from constant node or from HT
|
||||
size_t k = 0;
|
||||
if (input_value(1).get_node_shared_ptr()->is_constant())
|
||||
if (op::is_constant(input_value(1).get_node()))
|
||||
{
|
||||
k = read_k_from_constant_node(input_value(1).get_node_shared_ptr(),
|
||||
get_input_element_type(1));
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/util/elementwise_args.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -33,6 +34,22 @@ op::util::BinaryElementwiseArithmetic::BinaryElementwiseArithmetic(const Output<
|
||||
{
|
||||
}
|
||||
|
||||
void op::util::BinaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic(
|
||||
const op::AutoBroadcastSpec& autob)
|
||||
{
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
args_et.is_dynamic() || args_et != element::boolean,
|
||||
"Arguments cannot have boolean element type (argument element type: ",
|
||||
args_et,
|
||||
").");
|
||||
|
||||
set_output_type(0, args_et, args_pshape);
|
||||
}
|
||||
|
||||
void op::util::BinaryElementwiseArithmetic::validate_and_infer_types()
|
||||
{
|
||||
validate_and_infer_elementwise_arithmetic(m_autob);
|
||||
|
@ -69,12 +69,11 @@ namespace ngraph
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
|
||||
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
|
||||
bool is_binary_elementwise_arithmetic() const override { return true; }
|
||||
bool supports_auto_broadcast() const override { return true; }
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
AutoBroadcastSpec m_autob;
|
||||
void validate_and_infer_elementwise_arithmetic(const op::AutoBroadcastSpec& autob);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/util/elementwise_args.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -35,7 +36,7 @@ op::util::BinaryElementwiseComparison::BinaryElementwiseComparison(const Output<
|
||||
|
||||
void op::util::BinaryElementwiseComparison::validate_and_infer_types()
|
||||
{
|
||||
auto args_et_pshape = validate_and_infer_elementwise_args(m_autob);
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, m_autob);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
set_output_type(0, element::boolean, args_pshape);
|
||||
|
@ -71,8 +71,6 @@ namespace ngraph
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
|
||||
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
|
||||
bool supports_auto_broadcast() const override { return true; }
|
||||
bool is_binary_elementwise_comparison() const override { return true; }
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
|
@ -16,6 +16,7 @@
|
||||
|
||||
#include "ngraph/op/util/binary_elementwise_logical.hpp"
|
||||
#include "ngraph/attribute_visitor.hpp"
|
||||
#include "ngraph/op/util/elementwise_args.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -32,6 +33,23 @@ op::util::BinaryElementwiseLogical::BinaryElementwiseLogical(const Output<Node>&
|
||||
{
|
||||
}
|
||||
|
||||
void op::util::BinaryElementwiseLogical::validate_and_infer_elementwise_logical(
|
||||
const op::AutoBroadcastSpec& autob)
|
||||
{
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this, autob);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
args_et.is_dynamic() || args_et == element::boolean,
|
||||
"Operands for logical operators must have boolean element type but have element type ",
|
||||
args_et,
|
||||
".");
|
||||
|
||||
set_output_type(0, element::boolean, args_pshape);
|
||||
}
|
||||
|
||||
void op::util::BinaryElementwiseLogical::validate_and_infer_types()
|
||||
{
|
||||
validate_and_infer_elementwise_logical(m_autob);
|
||||
|
@ -68,11 +68,10 @@ namespace ngraph
|
||||
|
||||
const AutoBroadcastSpec& get_autob() const override { return m_autob; }
|
||||
void set_autob(const AutoBroadcastSpec& autob) { m_autob = autob; }
|
||||
bool supports_auto_broadcast() const override { return true; }
|
||||
bool is_binary_elementwise_logical() const override { return true; }
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
void validate_and_infer_elementwise_logical(const op::AutoBroadcastSpec& autob);
|
||||
AutoBroadcastSpec m_autob;
|
||||
};
|
||||
}
|
||||
|
@ -19,6 +19,7 @@
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/partial_shape.hpp"
|
||||
|
||||
#include "ngraph/runtime/reference/broadcast.hpp"
|
||||
@ -192,7 +193,7 @@ void op::util::BroadcastBase::validate_and_infer_types()
|
||||
" doesn't match rank of input tensor ",
|
||||
arg_shape.size());
|
||||
|
||||
if (shape_constant && input_value(2).get_node_shared_ptr()->is_constant())
|
||||
if (shape_constant && op::is_constant(input_value(2).get_node()))
|
||||
{
|
||||
auto target_shape = shape_constant->get_shape_val();
|
||||
auto axes_mapping_val =
|
||||
|
ngraph/src/ngraph/op/util/elementwise_args.cpp (new file, 61 lines)
@ -0,0 +1,61 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "elementwise_args.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
std::tuple<element::Type, PartialShape>
|
||||
ngraph::op::util::validate_and_infer_elementwise_args(Node* node,
|
||||
const op::AutoBroadcastSpec& autob)
|
||||
{
|
||||
NGRAPH_CHECK(node != nullptr, "nGraph node is empty! Cannot validate eltwise arguments.");
|
||||
element::Type element_type = node->get_input_element_type(0);
|
||||
PartialShape pshape = node->get_input_partial_shape(0);
|
||||
|
||||
if (node->get_input_size() > 1)
|
||||
{
|
||||
for (size_t i = 1; i < node->get_input_size(); ++i)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
node,
|
||||
element::Type::merge(element_type, element_type, node->get_input_element_type(i)),
|
||||
"Argument element types are inconsistent.");
|
||||
|
||||
if (autob.m_type == op::AutoBroadcastType::NONE)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
node,
|
||||
PartialShape::merge_into(pshape, node->get_input_partial_shape(i)),
|
||||
"Argument shapes are inconsistent.");
|
||||
}
|
||||
else if (autob.m_type == op::AutoBroadcastType::NUMPY ||
|
||||
autob.m_type == op::AutoBroadcastType::PDPD)
|
||||
{
|
||||
NODE_VALIDATION_CHECK(node,
|
||||
PartialShape::broadcast_merge_into(
|
||||
pshape, node->get_input_partial_shape(i), autob),
|
||||
"Argument shapes are inconsistent.");
|
||||
}
|
||||
else
|
||||
{
|
||||
NODE_VALIDATION_CHECK(node, false, "Unsupported auto broadcast specification");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_tuple(element_type, pshape);
|
||||
}
|
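Note: a minimal usage sketch, not part of this commit; the wrapper name infer_elementwise and the include path are assumptions, while the helper's signature and default AutoBroadcastSpec argument come from the header declared in the next hunk.

    #include <tuple>
    #include "ngraph/op/util/elementwise_args.hpp"

    using namespace ngraph;

    // Merges the element types of all inputs of 'node' and (broadcast-)merges their
    // partial shapes, returning the common {element type, shape} pair.
    std::tuple<element::Type, PartialShape> infer_elementwise(Node* node)
    {
        return op::util::validate_and_infer_elementwise_args(node, op::AutoBroadcastSpec());
    }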
@ -16,22 +16,16 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "ngraph/node.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
enum class Placement
|
||||
namespace op
|
||||
{
|
||||
DEFAULT,
|
||||
INTERPRETER,
|
||||
CPU,
|
||||
GPU,
|
||||
NNP,
|
||||
};
|
||||
|
||||
std::string placement_to_string(Placement placement);
|
||||
namespace util
|
||||
{
|
||||
std::tuple<element::Type, PartialShape> validate_and_infer_elementwise_args(
|
||||
Node* node, const op::AutoBroadcastSpec& autob = op::AutoBroadcastSpec());
|
||||
}
|
||||
}
|
||||
}
|
@ -30,7 +30,6 @@ namespace ngraph
|
||||
class NGRAPH_API FusedOp : public Op
|
||||
{
|
||||
public:
|
||||
bool supports_decompose() const final { return true; }
|
||||
// Fused op decomposition can be performed in the presence of
|
||||
// partial shapes
|
||||
virtual bool can_decompose_with_partial_shapes() { return false; }
|
||||
|
ngraph/src/ngraph/op/util/op_types.cpp (new file, 158 lines)
@ -0,0 +1,158 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/op/add.hpp"
|
||||
#include "ngraph/op/and.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/equal.hpp"
|
||||
#include "ngraph/op/maximum.hpp"
|
||||
#include "ngraph/op/minimum.hpp"
|
||||
#include "ngraph/op/multiply.hpp"
|
||||
#include "ngraph/op/not_equal.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/or.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/result.hpp"
|
||||
#include "ngraph/op/select.hpp"
|
||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
|
||||
#include "ngraph/op/util/binary_elementwise_logical.hpp"
|
||||
#include "ngraph/op/util/fused_op.hpp"
|
||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/op/xor.hpp"
|
||||
#include "ngraph/type.hpp"
|
||||
|
||||
bool ngraph::op::is_unary_elementwise_arithmetic(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::util::UnaryElementwiseArithmetic*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_binary_elementwise_arithmetic(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_binary_elementwise_comparison(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::util::BinaryElementwiseComparison*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_binary_elementwise_logical(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::util::BinaryElementwiseLogical*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::supports_auto_broadcast(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::v1::Select*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::util::BinaryElementwiseComparison*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::util::BinaryElementwiseLogical*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::util::BinaryElementwiseArithmetic*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::supports_decompose(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::util::FusedOp*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_op(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::Op*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_parameter(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::Parameter*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_output(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::Result*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_constant(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::Constant*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_commutative(const ngraph::Node* node)
|
||||
{
|
||||
return dynamic_cast<const ngraph::op::v0::Add*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::Add*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::Maximum*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::Maximum*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::Equal*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::Equal*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::NotEqual*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::NotEqual*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::LogicalAnd*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::Xor*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::LogicalXor*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::Minimum*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::Minimum*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::Multiply*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::Multiply*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v0::Or*>(node) != nullptr ||
|
||||
dynamic_cast<const ngraph::op::v1::LogicalOr*>(node) != nullptr;
|
||||
}
|
||||
|
||||
bool ngraph::op::is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_unary_elementwise_arithmetic(node.get());
|
||||
}
|
||||
bool ngraph::op::is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_binary_elementwise_arithmetic(node.get());
|
||||
}
|
||||
bool ngraph::op::is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_binary_elementwise_comparison(node.get());
|
||||
}
|
||||
bool ngraph::op::is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_binary_elementwise_logical(node.get());
|
||||
}
|
||||
|
||||
bool ngraph::op::supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return supports_auto_broadcast(node.get());
|
||||
}
|
||||
|
||||
bool ngraph::op::supports_decompose(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return supports_decompose(node.get());
|
||||
}
|
||||
|
||||
bool ngraph::op::is_op(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_op(node.get());
|
||||
}
|
||||
bool ngraph::op::is_parameter(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_parameter(node.get());
|
||||
}
|
||||
bool ngraph::op::is_output(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_output(node.get());
|
||||
}
|
||||
bool ngraph::op::is_constant(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_constant(node.get());
|
||||
}
|
||||
bool ngraph::op::is_commutative(const std::shared_ptr<ngraph::Node>& node)
|
||||
{
|
||||
return is_commutative(node.get());
|
||||
}
|
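Note: an illustrative usage sketch, not part of this commit, mirroring the unit tests near the end of this diff; it assumes a standard nGraph build with the v1::Add constructor taking two outputs.

    #include <memory>
    #include "ngraph/op/add.hpp"
    #include "ngraph/op/parameter.hpp"
    #include "ngraph/op/util/op_types.hpp"

    using namespace ngraph;

    int main()
    {
        auto arg = std::make_shared<op::Parameter>(element::f32, Shape{2, 2});
        auto sum = std::make_shared<op::v1::Add>(arg, arg);

        bool p = op::is_parameter(arg);                            // true: shared_ptr overload
        bool c = op::is_constant(sum->input_value(0).get_node());  // false: raw-pointer overload
        bool k = op::is_commutative(sum);                          // true: v1::Add is commutative
        return (p && !c && k) ? 0 : 1;
    }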
ngraph/src/ngraph/op/util/op_types.hpp (new file, 79 lines)
@ -0,0 +1,79 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "ngraph/ngraph_visibility.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
NGRAPH_API
|
||||
bool is_unary_elementwise_arithmetic(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_arithmetic(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_comparison(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_logical(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool supports_auto_broadcast(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool supports_decompose(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool is_op(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_parameter(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_output(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_constant(const ngraph::Node* node);
|
||||
NGRAPH_API
|
||||
bool is_commutative(const ngraph::Node* node);
|
||||
|
||||
NGRAPH_API
|
||||
bool is_unary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_arithmetic(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_comparison(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_binary_elementwise_logical(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
NGRAPH_API
|
||||
bool supports_auto_broadcast(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
NGRAPH_API
|
||||
bool supports_decompose(const std::shared_ptr<ngraph::Node>& node);
|
||||
|
||||
NGRAPH_API
|
||||
bool is_op(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_parameter(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_output(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_constant(const std::shared_ptr<ngraph::Node>& node);
|
||||
NGRAPH_API
|
||||
bool is_commutative(const std::shared_ptr<ngraph::Node>& node);
|
||||
}
|
||||
}
|
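Note: a sketch of the call-site pattern the passes below adopt, not from the commit itself; the helper name is_graph_endpoint is hypothetical.

    #include <memory>
    #include "ngraph/node.hpp"
    #include "ngraph/op/util/op_types.hpp"

    // Replaces the removed n->is_parameter() / n->is_output() member calls; several
    // passes below skip such endpoint nodes at the top of their main loop.
    static bool is_graph_endpoint(const std::shared_ptr<ngraph::Node>& n)
    {
        return ngraph::op::is_parameter(n) || ngraph::op::is_output(n);
    }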
@ -15,6 +15,7 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/util/scatter_base.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
@ -80,7 +81,7 @@ void op::util::ScatterBase::validate_and_infer_types()
|
||||
|
||||
bool compatible = true;
|
||||
int64_t axis;
|
||||
bool is_axis_constant = input_value(AXIS).get_node_shared_ptr()->is_constant();
|
||||
bool is_axis_constant = op::is_constant(input_value(AXIS).get_node());
|
||||
|
||||
// Get axis value if possible.
|
||||
if (is_axis_constant && data_shape.rank().is_static())
|
||||
|
@ -15,6 +15,7 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/op/util/elementwise_args.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
@ -28,6 +29,21 @@ op::util::UnaryElementwiseArithmetic::UnaryElementwiseArithmetic(const Output<No
|
||||
{
|
||||
}
|
||||
|
||||
void op::util::UnaryElementwiseArithmetic::validate_and_infer_elementwise_arithmetic()
|
||||
{
|
||||
auto args_et_pshape = op::util::validate_and_infer_elementwise_args(this);
|
||||
element::Type& args_et = std::get<0>(args_et_pshape);
|
||||
PartialShape& args_pshape = std::get<1>(args_et_pshape);
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
args_et.is_dynamic() || args_et != element::boolean,
|
||||
"Arguments cannot have boolean element type (argument element type: ",
|
||||
args_et,
|
||||
").");
|
||||
|
||||
set_output_type(0, args_et, args_pshape);
|
||||
}
|
||||
|
||||
void op::util::UnaryElementwiseArithmetic::validate_and_infer_types()
|
||||
{
|
||||
validate_and_infer_elementwise_arithmetic();
|
||||
|
@ -57,8 +57,10 @@ namespace ngraph
|
||||
|
||||
public:
|
||||
void validate_and_infer_types() override;
|
||||
bool is_unary_elementwise_arithmetic() const override { return true; }
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
private:
|
||||
void validate_and_infer_elementwise_arithmetic();
|
||||
};
|
||||
}
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/op/variadic_split.hpp"
|
||||
#include "ngraph/validation_util.hpp"
|
||||
|
||||
@ -62,8 +63,8 @@ void ngraph::op::v1::VariadicSplit::validate_and_infer_types()
|
||||
const auto& data_type = data.get_element_type();
|
||||
|
||||
set_output_size(num_outputs);
|
||||
if (data_shape.rank().is_static() && axis_input->is_constant() &&
|
||||
split_lengths_input->is_constant())
|
||||
if (data_shape.rank().is_static() && op::is_constant(axis_input) &&
|
||||
op::is_constant(split_lengths_input))
|
||||
{
|
||||
const auto axis_input_constant = as_type_ptr<op::Constant>(axis_input);
|
||||
auto axis_val = axis_input_constant->cast_vector<int64_t>()[0];
|
||||
|
@ -52,7 +52,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
@ -85,7 +84,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual bool is_commutative() const override { return true; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -168,5 +168,6 @@
|
||||
#include "ngraph/op/topk.hpp"
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/op/variadic_split.hpp"
|
||||
#include "ngraph/op/xor.hpp"
|
||||
|
@ -38,6 +38,7 @@
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/transpose.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/opsets/opset2.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
@ -818,7 +819,7 @@ bool pass::AlgebraicSimplification::run_on_function(shared_ptr<Function> f)
|
||||
bool replaced = false;
|
||||
for (auto n : f->get_ordered_ops())
|
||||
{
|
||||
if (n->is_output() || n->is_parameter())
|
||||
if (op::is_output(n) || op::is_parameter(n))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include <sstream>
|
||||
|
||||
#include "common_function_collection.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -48,11 +49,11 @@ bool pass::CommonFunctionCollection::run_on_module(vector<shared_ptr<Function>>&
|
||||
{
|
||||
for (const shared_ptr<Node>& n : current_function->get_ordered_ops())
|
||||
{
|
||||
if (n->is_constant() || n->is_parameter())
|
||||
if (op::is_constant(n) || op::is_parameter(n))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if (n->is_op())
|
||||
if (op::is_op(n))
|
||||
{
|
||||
auto op = std::static_pointer_cast<op::Op>(n);
|
||||
auto annotations = op->get_op_annotations();
|
||||
|
@ -58,6 +58,7 @@
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/tan.hpp"
|
||||
#include "ngraph/op/tanh.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -277,7 +278,7 @@ namespace std
|
||||
|
||||
// TODO: Do we need another map, so we could
|
||||
// specify how to compute hash for each op?
|
||||
if (p_this.is_commutative())
|
||||
if (ngraph::op::is_commutative(&p_this))
|
||||
{
|
||||
sort(begin(cargs), end(cargs));
|
||||
}
|
||||
@ -301,7 +302,7 @@ bool ngraph::pass::CommonSubexpressionElimination::run_on_function(shared_ptr<ng
|
||||
|
||||
for (auto n : f->get_ordered_ops())
|
||||
{
|
||||
if (n->is_output() || n->is_parameter())
|
||||
if (op::is_output(n) || op::is_parameter(n))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
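Note: an aside, illustrative only and not nGraph code. The CSE hash above sorts the argument list when op::is_commutative() holds, so that expressions such as add(a, b) and add(b, a) land in the same bucket; a toy version of that idea:

    #include <algorithm>
    #include <cstddef>
    #include <functional>
    #include <string>
    #include <vector>

    // Toy node key: op name plus argument ids, with the ids sorted whenever the op
    // is commutative so that operand order does not affect the hash.
    std::size_t node_hash(const std::string& op, std::vector<std::size_t> arg_ids, bool commutative)
    {
        if (commutative)
            std::sort(arg_ids.begin(), arg_ids.end());
        std::size_t h = std::hash<std::string>{}(op);
        for (std::size_t id : arg_ids)
            h = h * 1099511628211ULL + id; // FNV-style mixing step
        return h;
    }
    // node_hash("Add", {1, 2}, true) == node_hash("Add", {2, 1}, true)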
@ -16,6 +16,7 @@
|
||||
#include "ngraph/pass/fused_op_decomposition.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/op/get_output_element.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -30,7 +31,7 @@ bool pass::FusedOpDecomposition::run_on_node(shared_ptr<Node> node)
|
||||
{
|
||||
bool modified = false;
|
||||
|
||||
if (node->supports_decompose())
|
||||
if (op::supports_decompose(node))
|
||||
{
|
||||
if (m_has_direct_support && m_has_direct_support(*node))
|
||||
{
|
||||
|
@ -21,13 +21,14 @@
|
||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/op/util/binary_elementwise_comparison.hpp"
|
||||
#include "ngraph/op/util/binary_elementwise_logical.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Node> node)
|
||||
{
|
||||
if (node->supports_auto_broadcast())
|
||||
if (ngraph::op::supports_auto_broadcast(node))
|
||||
{
|
||||
if (node->get_autob().m_type != op::AutoBroadcastType::NONE)
|
||||
{
|
||||
@ -45,7 +46,7 @@ bool ngraph::pass::ImplicitBroadcastElimination::run_on_node(std::shared_ptr<Nod
|
||||
NodeVector ngraph::pass::explicit_broadcast(std::shared_ptr<Node>& node)
|
||||
{
|
||||
NodeVector rc;
|
||||
if (node->supports_auto_broadcast())
|
||||
if (ngraph::op::supports_auto_broadcast(node))
|
||||
{
|
||||
auto autob = node->get_autob();
|
||||
if (autob.m_type == op::AutoBroadcastType::NONE)
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "ngraph/op/concat.hpp"
|
||||
#include "ngraph/op/get_output_element.hpp"
|
||||
#include "ngraph/op/slice.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pass/liveness.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pass/memory_layout.hpp"
|
||||
@ -48,7 +49,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
|
||||
std::map<descriptor::Tensor*, descriptor::Tensor*> in_place_outputs;
|
||||
std::set<const descriptor::Tensor*> reused_inputs;
|
||||
|
||||
if (node->is_op())
|
||||
if (op::is_op(node))
|
||||
{
|
||||
auto op = std::static_pointer_cast<op::Op>(node);
|
||||
// concat and slice in_place_oi should be treated differently
|
||||
@ -67,7 +68,7 @@ bool pass::MemoryLayout::run_on_function(shared_ptr<Function> function)
|
||||
if ((node->liveness_free_list.count(input) != 0 ||
|
||||
is_type<op::GetOutputElement>(node) ||
|
||||
(m_disable_memory_sharing && !oi_pair.destructive &&
|
||||
!input_node->is_parameter() && !input_node->is_constant())) &&
|
||||
!op::is_parameter(input_node) && !op::is_constant(input_node))) &&
|
||||
node->liveness_new_list.count(output) != 0)
|
||||
|
||||
{
|
||||
|
@ -35,6 +35,7 @@
|
||||
#include "ngraph/op/slice.hpp"
|
||||
#include "ngraph/op/stop_gradient.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
#include "nop_elimination.hpp"
|
||||
@ -344,7 +345,7 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
|
||||
if (auto unsqueeze = as_type_ptr<opset3::Unsqueeze>(input))
|
||||
{
|
||||
PartialShape data_shape;
|
||||
if (input->is_parameter())
|
||||
if (op::is_parameter(input))
|
||||
{
|
||||
data_shape = unsqueeze->input(0).get_partial_shape();
|
||||
}
|
||||
@ -393,7 +394,7 @@ static bool eliminate_squeeze(const std::shared_ptr<Node>& node)
|
||||
if (auto squeeze_i = as_type_ptr<opset3::Squeeze>(input))
|
||||
{
|
||||
PartialShape data_shape;
|
||||
if (input->is_parameter())
|
||||
if (op::is_parameter(input))
|
||||
{
|
||||
data_shape = squeeze_i->input(0).get_partial_shape();
|
||||
}
|
||||
|
@ -21,6 +21,7 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/util/op_annotations.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -29,7 +30,7 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
|
||||
{
|
||||
for (auto& node : function->get_ordered_ops())
|
||||
{
|
||||
if (node->is_op())
|
||||
if (op::is_op(node))
|
||||
{
|
||||
auto op = static_pointer_cast<op::Op>(node);
|
||||
NGRAPH_DEBUG << "propagate cacheability: node is " << node->get_name();
|
||||
@ -40,7 +41,7 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
|
||||
op_annotations = op_annotations_factory();
|
||||
op->set_op_annotations(op_annotations);
|
||||
}
|
||||
if (node->is_parameter())
|
||||
if (op::is_parameter(node))
|
||||
{
|
||||
auto parameter = static_pointer_cast<op::Parameter>(node);
|
||||
op_annotations->set_cacheable(parameter->get_cacheable());
|
||||
@ -54,7 +55,7 @@ bool pass::PropagateCacheability::run_on_function(shared_ptr<Function> function)
|
||||
{
|
||||
auto input_value_node = input.get_source_output().get_node_shared_ptr();
|
||||
NGRAPH_DEBUG << "propagate cacheability: arg is " << *input_value_node;
|
||||
if (input_value_node->is_op())
|
||||
if (op::is_op(input_value_node))
|
||||
{
|
||||
auto arg_op = static_pointer_cast<op::Op>(input_value_node);
|
||||
auto arg_op_annotations = arg_op->get_op_annotations();
|
||||
|
@ -35,6 +35,7 @@
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
#include "ngraph/op/slice.hpp"
|
||||
#include "ngraph/op/util/binary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/op/util/unary_elementwise_arithmetic.hpp"
|
||||
#include "ngraph/pattern/op/label.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
@ -181,7 +182,7 @@ void swim(Input<Node> input, shared_ptr<op::Reshape> reshape)
|
||||
continue;
|
||||
}
|
||||
NGRAPH_DEBUG << "Processing (swimming) " << n->get_name();
|
||||
if (n->is_unary_elementwise_arithmetic())
|
||||
if (op::is_unary_elementwise_arithmetic(n))
|
||||
{
|
||||
Swimmer nsw{n->input(0), csw.reshape};
|
||||
work_queue.push_back(nsw);
|
||||
@ -549,7 +550,7 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
|
||||
{
|
||||
NGRAPH_DEBUG << "Start: Processing node " << n->get_name();
|
||||
// collect all Result nodes for a sanity check
|
||||
if (n->is_output())
|
||||
if (ngraph::op::is_output(n))
|
||||
{
|
||||
results.push_back(n);
|
||||
}
|
||||
@ -558,11 +559,11 @@ bool ngraph::pass::ReshapeSinking::run_on_function(shared_ptr<ngraph::Function>
|
||||
{
|
||||
sink_reshape(reshape, reorders, reshapes_to_delete);
|
||||
}
|
||||
else if (n->is_unary_elementwise_arithmetic())
|
||||
else if (op::is_unary_elementwise_arithmetic(n))
|
||||
{
|
||||
sink_unary(n, reorders, reshapes_to_delete);
|
||||
}
|
||||
else if (n->is_binary_elementwise_arithmetic())
|
||||
else if (op::is_binary_elementwise_arithmetic(n))
|
||||
{
|
||||
sink_binary(n, reorders, reshapes_to_delete);
|
||||
}
|
||||
|
@ -17,6 +17,7 @@
|
||||
#include "ngraph/pass/shape_relevance.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
@ -91,7 +92,7 @@ bool pass::ShapeRelevance::run_on_function(std::shared_ptr<Function> f)
|
||||
shape_determinants.insert(node);
|
||||
already_visited.insert(node);
|
||||
|
||||
if (node->is_parameter())
|
||||
if (op::is_parameter(node))
|
||||
{
|
||||
auto node_as_param = static_cast<op::Parameter*>(node);
|
||||
if (!node_as_param->is_relevant_to_shapes())
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/get_output_element.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pass/pass.hpp"
|
||||
#include "ngraph/pass/visualize_tree.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
@ -348,7 +349,7 @@ static std::string pretty_value(const vector<T>& value)
|
||||
|
||||
std::string pass::VisualizeTree::get_constant_value(std::shared_ptr<Node> node, size_t max_elements)
|
||||
{
|
||||
if (!node->is_constant())
|
||||
if (!op::is_constant(node))
|
||||
return {};
|
||||
std::stringstream ss;
|
||||
ss << "{" << node->get_element_type().get_type_name() << "}";
|
||||
@ -392,7 +393,7 @@ string pass::VisualizeTree::get_attributes(shared_ptr<Node> node)
|
||||
vector<string> attributes;
|
||||
attributes.push_back("shape=box");
|
||||
|
||||
if (node->is_output())
|
||||
if (ngraph::op::is_output(node))
|
||||
{
|
||||
attributes.push_back("color=crimson");
|
||||
attributes.push_back("penwidth=1.5");
|
||||
|
@ -29,6 +29,7 @@
|
||||
#include "ngraph/op/product.hpp"
|
||||
#include "ngraph/op/replace_slice.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/type.hpp"
|
||||
#include "zero_dim_tensor_elimination.hpp"
|
||||
|
||||
@ -46,7 +47,8 @@ static bool verify_no_internal_zero_length_ops(shared_ptr<Function> f)
|
||||
set<Output<Node>> zero_length_source_outputs;
|
||||
for (auto n : f->get_ordered_ops())
|
||||
{
|
||||
if (n->is_output() || n->is_parameter() || n->is_constant() || n->get_output_size() > 1)
|
||||
if (op::is_output(n) || op::is_parameter(n) || op::is_constant(n) ||
|
||||
n->get_output_size() > 1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
@ -92,7 +94,7 @@ bool pass::ZeroDimTensorElimination::run_on_function(shared_ptr<Function> f)
|
||||
// if any `GetOutputElement` is zero-length
|
||||
// we replace it w/ a signalling constant
|
||||
// so we don't have to deal w/ multi-output nodes directly
|
||||
if (n->is_output() || n->is_parameter() || n->get_output_size() > 1)
|
||||
if (op::is_output(n) || op::is_parameter(n) || n->get_output_size() > 1)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "ngraph/log.hpp"
|
||||
#include "ngraph/op/get_output_element.hpp"
|
||||
#include "ngraph/op/parameter.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -187,7 +188,7 @@ namespace ngraph
|
||||
return false;
|
||||
}
|
||||
|
||||
if (graph_node->is_commutative())
|
||||
if (ngraph::op::is_commutative(graph_node))
|
||||
{
|
||||
// TODO: [nikolayk] we don't really have to use lexicographically-based perms,
|
||||
// heap's algo should be faster
|
||||
|
@ -85,7 +85,6 @@ namespace ngraph
|
||||
|
||||
ValuePredicate get_predicate() const;
|
||||
|
||||
bool is_pattern() const override { return true; }
|
||||
protected:
|
||||
ValuePredicate m_predicate;
|
||||
};
|
||||
|
@ -1,40 +0,0 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2017-2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <deque>
|
||||
#include <sstream>
|
||||
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/placement.hpp"
|
||||
#include "ngraph/util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
std::string ngraph::placement_to_string(Placement placement)
|
||||
{
|
||||
switch (placement)
|
||||
{
|
||||
case Placement::DEFAULT: return "DEFAULT";
|
||||
case Placement::INTERPRETER: return "INTERPRETER";
|
||||
case Placement::CPU: return "CPU";
|
||||
case Placement::GPU: return "GPU";
|
||||
case Placement::NNP: return "NNP";
|
||||
}
|
||||
throw runtime_error("unhandled placement type");
|
||||
}
|
@ -19,6 +19,7 @@
|
||||
#include "ngraph/op/assign.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/op/tensor_iterator.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
@ -67,7 +68,7 @@ std::shared_ptr<Function>
|
||||
|
||||
for (auto old_node : f->get_ordered_ops())
|
||||
{
|
||||
if (old_node->is_parameter())
|
||||
if (op::is_parameter(old_node))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
@ -481,7 +481,7 @@ TEST(autobroadcast, axes_mapping_from_bcast_axes)
|
||||
const AxisSet broadcast_axes{0, 2};
|
||||
|
||||
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
|
||||
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
|
||||
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
|
||||
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
|
||||
EXPECT_EQ(axes_mapping_shape.size(), 2);
|
||||
EXPECT_EQ(axes_mapping_shape, (Shape{1, 3}));
|
||||
@ -494,7 +494,7 @@ TEST(autobroadcast, axes_mapping_from_bcast_axes_scalar)
|
||||
const AxisSet broadcast_axes{0, 1, 2, 3};
|
||||
|
||||
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
|
||||
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
|
||||
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
|
||||
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
|
||||
EXPECT_EQ(axes_mapping_shape.size(), 0);
|
||||
EXPECT_EQ(axes_mapping_shape, (Shape{}));
|
||||
@ -507,7 +507,7 @@ TEST(autobroadcast, axes_mapping_from_bcast_axes_identical)
|
||||
const AxisSet broadcast_axes{};
|
||||
|
||||
auto axes_mapping = builder::opset1::get_axes_mapping_output(output_shape, broadcast_axes);
|
||||
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
|
||||
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
|
||||
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
|
||||
EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
|
||||
EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));
|
||||
@ -521,7 +521,7 @@ TEST(autobroadcast, axes_mapping_start_match_axis)
|
||||
|
||||
auto axes_mapping =
|
||||
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
|
||||
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
|
||||
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
|
||||
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
|
||||
EXPECT_EQ(axes_mapping_shape.size(), 2);
|
||||
EXPECT_EQ(axes_mapping_shape, (Shape{1, 2}));
|
||||
@ -535,7 +535,7 @@ TEST(autobroadcast, axes_mapping_start_match_axis_scalar)
|
||||
|
||||
auto axes_mapping =
|
||||
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
|
||||
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
|
||||
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
|
||||
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
|
||||
EXPECT_EQ(axes_mapping_shape.size(), 0);
|
||||
EXPECT_EQ(axes_mapping_shape, (Shape{}));
|
||||
@ -549,7 +549,7 @@ TEST(autobroadcast, axes_mapping_start_match_axis_identical)
|
||||
|
||||
auto axes_mapping =
|
||||
builder::opset1::get_axes_mapping_output(output_shape, input_shape, start_match_axis);
|
||||
EXPECT_TRUE(axes_mapping.get_node()->is_constant());
|
||||
EXPECT_TRUE(op::is_constant(axes_mapping.get_node()));
|
||||
Shape axes_mapping_shape = as_type<op::v0::Constant>(axes_mapping.get_node())->get_shape_val();
|
||||
EXPECT_EQ(axes_mapping_shape.size(), output_shape.size());
|
||||
EXPECT_EQ(axes_mapping_shape, (Shape{0, 1, 2, 3}));
|
||||
|
@ -35,6 +35,7 @@
|
||||
// clang-format on
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/frontend/onnx_import/core/null_node.hpp"
|
||||
#include "ngraph/frontend/onnx_import/onnx.hpp"
|
||||
#include "ngraph/frontend/onnx_import/onnx_utils.hpp"
|
||||
#include "ngraph/frontend/onnx_import/default_opset.hpp"
|
||||
@ -376,7 +377,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_missing_input)
|
||||
std::shared_ptr<ngraph::Node> C = ng_inputs.at(2);
|
||||
|
||||
A = A * C;
|
||||
if (!B->is_null())
|
||||
if (!ngraph::op::is_null(B))
|
||||
{
|
||||
B = B / C;
|
||||
}
|
||||
@ -393,7 +394,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_missing_input)
|
||||
|
||||
for (const auto& ng_input : ng_inputs)
|
||||
{
|
||||
if (!ng_input->is_null())
|
||||
if (!ngraph::op::is_null(ng_input))
|
||||
{
|
||||
result = ng_input * result;
|
||||
}
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "ngraph/file_util.hpp"
|
||||
#include "ngraph/frontend/onnx_import/default_opset.hpp"
|
||||
#include "ngraph/frontend/onnx_import/onnx.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pass/constant_folding.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "util/all_close.hpp"
|
||||
@ -44,7 +45,7 @@ namespace
|
||||
|
||||
for (auto ng_node : ng_function->get_ordered_ops())
|
||||
{
|
||||
if (ng_node->is_constant())
|
||||
if (op::is_constant(ng_node))
|
||||
{
|
||||
const auto folded_node = as_type_ptr<default_opset::Constant>(ng_node);
|
||||
const auto output_values = folded_node->cast_vector<T>();
|
||||
|
@ -33,7 +33,7 @@ TEST(op, is_op)
|
||||
{
|
||||
auto arg0 = make_shared<op::Parameter>(element::f32, Shape{1});
|
||||
ASSERT_NE(nullptr, arg0);
|
||||
EXPECT_TRUE(arg0->is_parameter());
|
||||
EXPECT_TRUE(op::is_parameter(arg0));
|
||||
}
|
||||
|
||||
TEST(op, is_parameter)
|
||||
@ -42,7 +42,7 @@ TEST(op, is_parameter)
|
||||
ASSERT_NE(nullptr, arg0);
|
||||
auto t0 = make_shared<op::Add>(arg0, arg0);
|
||||
ASSERT_NE(nullptr, t0);
|
||||
EXPECT_FALSE(t0->is_parameter());
|
||||
EXPECT_FALSE(op::is_parameter(t0));
|
||||
}
|
||||
|
||||
TEST(op, provenance_tag)
|
||||
|
File diff suppressed because it is too large.
@ -3,6 +3,7 @@
|
||||
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "opset0_downgrade.hpp"
|
||||
#include "opset1_upgrade.hpp"
|
||||
@ -28,8 +29,8 @@ TEST(opset_transform, opset1_broadcast_upgrade_pass)
|
||||
ASSERT_TRUE(bcast_v1);
|
||||
EXPECT_EQ(bcast_v1->get_broadcast_spec(), op::AutoBroadcastSpec());
|
||||
EXPECT_EQ(bcast_v1->get_broadcast_axes(), (std::make_pair<bool, AxisSet>(true, AxisSet{0, 2})));
|
||||
ASSERT_TRUE(bcast_v1->input_value(1).get_node()->is_constant());
|
||||
ASSERT_TRUE(bcast_v1->input_value(2).get_node()->is_constant());
|
||||
ASSERT_TRUE(op::is_constant(bcast_v1->input_value(1).get_node()));
|
||||
ASSERT_TRUE(op::is_constant(bcast_v1->input_value(2).get_node()));
|
||||
EXPECT_EQ(
|
||||
as_type_ptr<op::Constant>(bcast_v1->input_value(1).get_node_shared_ptr())->get_shape_val(),
|
||||
(Shape{3, 5, 4, 6}));
|
||||
|
@ -34,6 +34,7 @@
|
||||
#include "ngraph/op/subtract.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/sum.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/pass/graph_rewrite.hpp"
|
||||
#include "ngraph/pass/manager.hpp"
|
||||
#include "ngraph/pattern/matcher.hpp"
|
||||
@ -322,7 +323,7 @@ TEST(pattern, matcher)
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
|
||||
auto is_bea = [](std::shared_ptr<Node> node) -> bool {
|
||||
return node->is_binary_elementwise_arithmetic();
|
||||
return op::is_binary_elementwise_arithmetic(node);
|
||||
};
|
||||
auto bea = std::make_shared<pattern::op::Any>(a, is_bea, NodeVector{a, b});
|
||||
auto add_ab = a + b;
|
||||
|
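Note: a small illustrative sketch, not from the commit, showing that the free-function predicates compose directly into pattern callbacks like the is_bea lambda above; the predicate below is hypothetical.

    #include <memory>
    #include "ngraph/node.hpp"
    #include "ngraph/op/util/op_types.hpp"

    // Matches any binary elementwise arithmetic or comparison node; usable wherever
    // a shared_ptr<Node> predicate is accepted, e.g. pattern::op::Any.
    auto is_binary_elementwise = [](const std::shared_ptr<ngraph::Node>& node) -> bool {
        return ngraph::op::is_binary_elementwise_arithmetic(node) ||
               ngraph::op::is_binary_elementwise_comparison(node);
    };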
@ -20,6 +20,7 @@
|
||||
#include "ngraph/cpio.hpp"
|
||||
#include "ngraph/descriptor/layout/dense_tensor_layout.hpp"
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/ops.hpp"
|
||||
#include "ngraph/pass/assign_layout.hpp"
|
||||
#include "ngraph/pass/core_fusion.hpp"
|
||||
@ -164,7 +165,7 @@ bool runtime::interpreter::INTExecutable::call(const vector<shared_ptr<runtime::
|
||||
for (auto op : m_nodes)
|
||||
{
|
||||
event::Duration d2(op->description(), "Interpreter");
|
||||
if (op->is_parameter())
|
||||
if (op::is_parameter(op))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
@ -24,6 +24,7 @@
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/node.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/ops.hpp"
|
||||
#include "ngraph/pass/implicit_broadcast_elimination.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
@ -135,7 +136,7 @@ namespace
|
||||
*node);
|
||||
const auto& arg_shape = arg_pshape.to_shape();
|
||||
|
||||
NGRAPH_CHECK(target_shape_input.get_node_shared_ptr()->is_constant());
|
||||
NGRAPH_CHECK(op::is_constant(target_shape_input.get_node()));
|
||||
auto target_shape = node->get_output_shape(0);
|
||||
NGRAPH_CHECK(node->get_broadcast_axes().first);
|
||||
|
||||
@ -250,7 +251,7 @@ namespace
|
||||
|
||||
const auto target_shape_input = node->input_value(1).get_node_shared_ptr();
|
||||
const auto input_rank = node->get_input_partial_shape(0).rank();
|
||||
if (target_shape_input->is_constant() && node->get_output_partial_shape(0).is_static() &&
|
||||
if (op::is_constant(target_shape_input) && node->get_output_partial_shape(0).is_static() &&
|
||||
input_rank.is_static())
|
||||
{
|
||||
const auto output_shape = node->get_output_shape(0);
|
||||
@ -416,12 +417,12 @@ namespace
|
||||
shared_ptr<Node> op_cast(shared_ptr<op::v1::OneHot> node)
|
||||
{
|
||||
const auto indices = node->input_value(0);
|
||||
const auto depth = node->input_value(1).get_node_shared_ptr();
|
||||
const auto depth = node->input_value(1).get_node();
|
||||
auto on_value = node->input_value(2);
|
||||
auto off_value = node->input_value(3);
|
||||
const auto axis = node->get_axis();
|
||||
|
||||
NGRAPH_CHECK(depth->is_constant(), "depth input must be constant", *node);
|
||||
NGRAPH_CHECK(op::is_constant(depth), "depth input must be constant", *node);
|
||||
const auto output_pshape = node->get_output_partial_shape(0);
|
||||
NGRAPH_CHECK(output_pshape.is_static(), "output shape must be static", *node);
|
||||
const auto output_shape = output_pshape.to_shape();
|
||||
@ -529,7 +530,7 @@ namespace
|
||||
shared_ptr<Node> op_cast(shared_ptr<op::v1::Reverse> node)
|
||||
{
|
||||
auto axes_node = node->input_value(1).get_node_shared_ptr();
|
||||
NGRAPH_CHECK(axes_node->is_constant(),
|
||||
NGRAPH_CHECK(op::is_constant(axes_node),
|
||||
"Unable to convert Reverse:v1 to Reverse:v0 "
|
||||
"if reduction axes are not constant. Node: ",
|
||||
*node);
|
||||
@ -685,7 +686,7 @@ namespace
|
||||
const auto data_shape = data_pshape.to_shape();
|
||||
|
||||
const auto order_node = node->input_value(1).get_node_shared_ptr();
|
||||
NGRAPH_CHECK(order_node->is_constant(),
|
||||
NGRAPH_CHECK(op::is_constant(order_node),
|
||||
"Unable to convert Transpose:v1 to Reshape:v0 "
|
||||
"if order node is not constant. Node: ",
|
||||
*node);
|
||||
@ -715,7 +716,7 @@ namespace
|
||||
{
|
||||
const auto split_lengths = node->input_value(2).get_node_shared_ptr();
|
||||
|
||||
NGRAPH_CHECK(split_lengths->is_constant(),
|
||||
NGRAPH_CHECK(op::is_constant(split_lengths),
|
||||
"Unable to convert VariadicSplit:v1 to Split:v0 "
|
||||
"if 'split_lengths' input is not constant. Node: ",
|
||||
*node);
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include "ngraph/builder/autobroadcast.hpp"
|
||||
#include "ngraph/builder/reshape.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/op/util/op_types.hpp"
|
||||
#include "ngraph/ops.hpp"
|
||||
#include "ngraph/provenance.hpp"
|
||||
#include "op/avg_pool.hpp"
|
||||
@ -404,7 +405,7 @@ namespace
|
||||
|
||||
shared_ptr<Node> op_cast(shared_ptr<op::Softmax> node)
|
||||
{
|
||||
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
|
||||
"axes parameter is expected to be a static constant");
|
||||
|
||||
AxisSet axes = node->get_axes();
|
||||
@ -487,9 +488,9 @@ namespace
|
||||
|
||||
shared_ptr<Node> op_cast(shared_ptr<op::TopK> node)
|
||||
{
|
||||
NGRAPH_CHECK(node->input_value(1).get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(op::is_constant(node->input_value(1).get_node()),
|
||||
"parameter k is expected to be a static constant");
|
||||
NGRAPH_CHECK(node->input_value(2).get_node_shared_ptr()->is_constant(),
|
||||
NGRAPH_CHECK(op::is_constant(node->input_value(2).get_node()),
|
||||
"parameter top_k_axis is expected to be a static constant");
|
||||
|
||||
const auto k = node->get_k();
|
||||
|
@ -86,6 +86,6 @@ TEST(tensor, output_flag)
|
||||
|
||||
for (size_t i = 0; i < f0->get_output_size(); ++i)
|
||||
{
|
||||
EXPECT_TRUE(f0->get_output_op(i)->is_output());
|
||||
EXPECT_TRUE(op::is_output(f0->get_output_op(i)));
|
||||
}
|
||||
}
|
||||
|