Visitor api ti implementation (#3576)
* TensorIterator deserialization. Introduce new on_adapter(Function) and add implementation in on_adapter(void) for Input and Output Descriptions. Remove factory adapter. * Add comments to functions provided. Add missing on_adapter() after rebase. * Apply formatting. * Remove visit_attributes from SubGraphOp, remove declaration for createSubGraphLayer. * Add port map parsing to address not consecutive order of external_port_id appearance. * Remove header for factory_adapter. * Add on_adapter() in V10Parse::parse() function. * Add m_num_iterations initialization for concat output. * Remove redundant lines, add doxygen comments. * Change cpp/ie_cnn_network.h to local include, remove temporary map object from range for loop. * Restore protected access for SubGraphOp.
This commit is contained in:
parent
4f0720176c
commit
431485e4a6
@ -27,7 +27,7 @@ convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function>& gr
|
||||
const ICNNNetwork &ngraphNetwork,
|
||||
CNNNetworkImpl* cnnNetworkImpl,
|
||||
bool keep_constant_inputs = false);
|
||||
|
||||
|
||||
// TODO: move ConstAllocatorWrapper class, shareWeights add addBlob into CNNLayerCreator when NodeConverter class is removed
|
||||
class ConstAllocatorWrapper : public IAllocator {
|
||||
public:
|
||||
|
@ -49,8 +49,11 @@
|
||||
#include "transformations/utils/utils.hpp"
|
||||
#include "transformations/rt_info/fused_names_attribute.hpp"
|
||||
#include "transformations/rt_info/primitives_priority_attribute.hpp"
|
||||
#include "cpp/ie_cnn_network.h"
|
||||
|
||||
#include "legacy/convert_function_to_cnn_network.hpp"
|
||||
#include "legacy/graph_tools.hpp"
|
||||
#include "legacy/net_pass.h"
|
||||
#include "ie_legacy_itt.hpp"
|
||||
#include "ie_cnn_layer_builder_ngraph.h"
|
||||
|
||||
@ -66,6 +69,210 @@ namespace details {
|
||||
return nullptr;\
|
||||
});\
|
||||
|
||||
/// \brief Creates legacy representation of CNNLayer for SubGraphOp.
|
||||
/// \param layer node type
|
||||
/// \return pointer to CNNLayer with legacy representation of SubGraphOp.
|
||||
CNNLayer::Ptr createSubGraphLayer(const std::shared_ptr<ngraph::Node>& layer) {
|
||||
auto sub_graph = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(layer);
|
||||
if (!sub_graph) {
|
||||
THROW_IE_EXCEPTION << "Cannot cast layer to SubGraphOp.";
|
||||
}
|
||||
|
||||
// inputs/outputs of TensorIterator (ngraph representation)
|
||||
auto parameters = sub_graph->get_function()->get_parameters();
|
||||
auto results = sub_graph->get_function()->get_results();
|
||||
|
||||
// Convert body (ngraph representation) to CNNNetwork.
|
||||
// This network will contain nodes of type = "Input" and data nodes with wrong names.
|
||||
// IE TensorIterator doesn't include such nodes so we create CNNNetwork in a separate scope
|
||||
// to call the destructor and delete these "Input"/data nodes.
|
||||
|
||||
TensorIterator::Body body;
|
||||
{
|
||||
InferenceEngine::CNNNetwork body_net(sub_graph->get_function());
|
||||
InferenceEngine::CNNNetwork net(InferenceEngine::details::convertFunctionToICNNNetwork(body_net.getFunction(), body_net));
|
||||
// Paranoid check for cycles
|
||||
bool res = CNNNetForestDFS(
|
||||
CNNNetGetAllInputLayers(net), [](const CNNLayerPtr& layer) {}, false);
|
||||
if (!res) {
|
||||
THROW_IE_EXCEPTION << "Loop detected. SubGraphOp body should not contain loops.";
|
||||
}
|
||||
|
||||
// Get inputs/outputs of cnn network
|
||||
auto in_info_map_with_parameters = net.getInputsInfo();
|
||||
auto out_info_map = net.getOutputsInfo();
|
||||
|
||||
IE_ASSERT(in_info_map_with_parameters.size() == parameters.size());
|
||||
IE_ASSERT(out_info_map.size() == results.size());
|
||||
|
||||
InferenceEngine::TensorIterator::Body temp_body;
|
||||
temp_body.inputs.resize(in_info_map_with_parameters.size());
|
||||
temp_body.outputs.resize(out_info_map.size());
|
||||
|
||||
// Fill inputs/outs in order aligned with ng representation
|
||||
uint64_t counter = 0;
|
||||
for (const auto& param : parameters) {
|
||||
auto info = in_info_map_with_parameters.at(param->get_friendly_name());
|
||||
temp_body.inputs[counter++] = info->getInputData();
|
||||
}
|
||||
|
||||
auto map_ng_result_to_ie_name = [] (std::shared_ptr<ngraph::op::v0::Result> res_op) {
|
||||
auto result = res_op->input(0).get_source_output();
|
||||
|
||||
std::string name = result.get_node()->get_friendly_name();
|
||||
if (result.get_node()->get_output_size() > 1) {
|
||||
name += "." + std::to_string(result.get_index());
|
||||
}
|
||||
return name;
|
||||
};
|
||||
|
||||
counter = 0;
|
||||
for (const auto& result : results) {
|
||||
auto data = out_info_map.at(map_ng_result_to_ie_name(result));
|
||||
temp_body.outputs[counter++] = data;
|
||||
}
|
||||
|
||||
// This deep copy will hold all unreachable constants. See the comment in CopyTIBody function.
|
||||
body = InferenceEngine::NetPass::CopyTIBody(temp_body);
|
||||
|
||||
// Check if data is really const layer holder
|
||||
auto is_constant_holder = [] (const DataPtr data) {
|
||||
return data->getPrecision() == Precision::UNSPECIFIED;
|
||||
};
|
||||
|
||||
// Strip unreached node holder from Inputs node.
|
||||
auto holder = body.inputs.back();
|
||||
if (is_constant_holder(holder)) {
|
||||
auto& holder_map = getInputTo(holder);
|
||||
|
||||
for( auto it = holder_map.begin(); it != holder_map.end(); ) {
|
||||
if( it->second->type == "Input")
|
||||
it = holder_map.erase(it);
|
||||
else
|
||||
++it;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Disable this WA after total switch onto Ngraph
|
||||
// WA: Some plugins (like GPU) require matching of Data object name and producer Layer name.
|
||||
// Data name is expected in format "[layer_name]" or "[layer_name].[port_idx]" in case
|
||||
// of multiple inputs. We have to restore it if possible and ignore original names of
|
||||
// Ngraph parameter and result ops.
|
||||
// Will not change data name if:
|
||||
// - data has several consumer layers
|
||||
// - data has no consumer (example if data is straight used as output)
|
||||
//
|
||||
for (auto &in : body.inputs) {
|
||||
if (is_constant_holder(in))
|
||||
continue;
|
||||
|
||||
const auto input_to = getInputTo(in);
|
||||
if (input_to.size() != 1)
|
||||
continue;
|
||||
|
||||
const auto consumer_layer = input_to.begin()->second;
|
||||
const auto consumer_in_port_set = consumer_layer->insData;
|
||||
const auto found = std::find_if(consumer_in_port_set.begin(), consumer_in_port_set.end(),
|
||||
[&in] (const DataWeakPtr &wptr) { return wptr.lock() == in; });
|
||||
IE_ASSERT(found != consumer_in_port_set.end());
|
||||
const auto consumer_port_idx = std::distance(consumer_in_port_set.begin(), found);
|
||||
|
||||
auto new_name = consumer_layer->name;
|
||||
if (consumer_in_port_set.size() > 1) {
|
||||
new_name += '.' + std::to_string(consumer_port_idx);
|
||||
}
|
||||
in->setName(new_name);
|
||||
}
|
||||
|
||||
// TODO: this WA restore original precisions of outputs.
|
||||
// convertFunctionToICNNNetwork has internal fallback policy for unsupported
|
||||
// precisions for inputs/outputs ports. Particular for U8 will be translated
|
||||
// to FP32. However Loop body has strong requirements for continue_condition
|
||||
// port, it should be BOOL(U8).
|
||||
//
|
||||
for (int i = 0; i < results.size(); i++) {
|
||||
auto result = results[i];
|
||||
auto output = body.outputs[i];
|
||||
if (result->get_element_type() == ngraph::element::u8) {
|
||||
output->setPrecision(InferenceEngine::Precision::U8);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create Inference Engine representation of TensorIterator
|
||||
LayerParams params = {layer->get_friendly_name(), "TensorIterator",
|
||||
details::convertPrecision(layer->get_output_element_type(0))};
|
||||
auto res = std::make_shared<InferenceEngine::TensorIterator>(params);
|
||||
res->body = body;
|
||||
|
||||
// Port map: outputs
|
||||
for (const auto& desc : sub_graph->get_output_descriptions()) {
|
||||
auto body_output_idx = desc->m_body_value_index;
|
||||
|
||||
std::string type_name = desc->get_type_info().name;
|
||||
if (type_name == "ConcatOutputDescription") {
|
||||
auto output_desc = ::ngraph::as_type_ptr<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(desc);
|
||||
IE_ASSERT(output_desc != nullptr);
|
||||
|
||||
res->output_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
|
||||
static_cast<int>(output_desc->m_output_index), static_cast<int>(body_output_idx),
|
||||
static_cast<int>(output_desc->m_axis), static_cast<int>(output_desc->m_stride),
|
||||
static_cast<int>(output_desc->m_start), static_cast<int>(output_desc->m_end),
|
||||
static_cast<int>(output_desc->m_part_size)});
|
||||
} else if (type_name == "BodyOutputDescription") {
|
||||
auto output_desc = ::ngraph::as_type_ptr<ngraph::op::util::SubGraphOp::BodyOutputDescription>(desc);
|
||||
IE_ASSERT(output_desc != nullptr);
|
||||
|
||||
res->output_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
|
||||
static_cast<int>(output_desc->m_output_index), static_cast<int>(body_output_idx), -1, 1, 0, -1, 1});
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Incorrect type of the output description.";
|
||||
}
|
||||
}
|
||||
|
||||
// Port map : inputs and back edges
|
||||
for (const auto& desc : sub_graph->get_input_descriptions()) {
|
||||
auto body_input_index = desc->m_body_parameter_index;
|
||||
|
||||
if (const auto slice_desc = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp::SliceInputDescription>(desc)) {
|
||||
res->input_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
|
||||
static_cast<int>(slice_desc->m_input_index), static_cast<int>(body_input_index),
|
||||
static_cast<int>(slice_desc->m_axis), static_cast<int>(slice_desc->m_stride),
|
||||
static_cast<int>(slice_desc->m_start), static_cast<int>(slice_desc->m_end),
|
||||
static_cast<int>(slice_desc->m_part_size)});
|
||||
} else if (const auto merge_desc = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp::MergedInputDescription>(desc)) {
|
||||
res->input_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
|
||||
static_cast<int>(merge_desc->m_input_index), static_cast<int>(body_input_index), -1, 1, 0, -1, 1});
|
||||
|
||||
auto body_output_idx = merge_desc->m_body_value_index;
|
||||
|
||||
res->back_edges.emplace_back(InferenceEngine::TensorIterator::PortMap {
|
||||
static_cast<int>(body_output_idx), static_cast<int>(body_input_index), -1, 1, 0, -1, 1});
|
||||
} else if (const auto inv_desc = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp::InvariantInputDescription>(desc)) {
|
||||
res->input_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
|
||||
static_cast<int>(inv_desc->m_input_index), static_cast<int>(body_input_index), -1, 1, 0, -1, 1});
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Incorrect type of the input description.";
|
||||
}
|
||||
}
|
||||
|
||||
if (const auto loop_op = std::dynamic_pointer_cast<const ngraph::opset5::Loop>(layer)) {
|
||||
auto spec_port = loop_op->get_special_body_ports();
|
||||
if (spec_port.current_iteration_input_idx != -1) {
|
||||
auto ie_port_idx = spec_port.current_iteration_input_idx;
|
||||
res->params["loop_body_current_iteration_idx"] = std::to_string(ie_port_idx);
|
||||
}
|
||||
if (spec_port.body_condition_output_idx != -1) {
|
||||
auto body_output_idx = spec_port.body_condition_output_idx;
|
||||
res->params["loop_body_condition_output_idx"] = std::to_string(body_output_idx);
|
||||
}
|
||||
res->params["loop_trip_count_idx"] = "0";
|
||||
res->params["loop_execution_condition_idx"] = "1";
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Creator for CNNLayer from nGraph op
|
||||
*/
|
||||
@ -134,6 +341,9 @@ public:
|
||||
params[name] = joinVec(data);
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) override {
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
|
||||
|
||||
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void*>& adapter) override {
|
||||
@ -171,6 +381,10 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
|
||||
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::vector<size_t>>>(& adapter)) {
|
||||
auto data = a->get();
|
||||
params[name] = joinVec(data);
|
||||
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::vector<std::shared_ptr<
|
||||
ngraph::op::util::SubGraphOp::InputDescription>>>>(& adapter)) {
|
||||
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::vector<std::shared_ptr<
|
||||
ngraph::op::util::SubGraphOp::OutputDescription>>>>(& adapter)) {
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
|
||||
"Attribute adapter can not be found for " << name << " parameter";
|
||||
@ -1300,6 +1514,12 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
|
||||
res->params["ctc_merge_repeated"] = res->getBoolStrParamAsIntStr("ctc_merge_repeated");
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"TensorIterator"}, [](const std::shared_ptr<::ngraph::Node>& node, const std::map<std::string, std::string>& params) -> CNNLayerPtr {
|
||||
auto res = createSubGraphLayer(node);
|
||||
res->type = "TensorIterator";
|
||||
return res;
|
||||
});
|
||||
}
|
||||
|
||||
CNNLayerPtr InferenceEngine::details::CNNLayerCreator::create() {
|
||||
@ -1344,7 +1564,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::ScaleShiftIE>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::SquaredDifference>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::VariadicSplit>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::TensorIterator>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::opset5::Loop>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::ShuffleChannels>>(),
|
||||
std::make_shared<Builder::NodeConverter<::ngraph::op::v4::Interpolate>>(),
|
||||
|
@ -23,6 +23,7 @@
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/variant.hpp>
|
||||
#include <ngraph/op/util/sub_graph_base.hpp>
|
||||
|
||||
#include <cpp/ie_cnn_network.h>
|
||||
#include "ie_blob_stream.hpp"
|
||||
@ -46,44 +47,223 @@ IRParser::IRParser(size_t version, const std::vector<InferenceEngine::IExtension
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> IRParser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
|
||||
return parser->parse(root, weights);
|
||||
}
|
||||
void V10Parser::XmlDeserializer::map_type_in_function(const pugi::xml_node& node,
|
||||
const std::string map_type, std::map<uint64_t, uint64_t>& type_id_in_function) {
|
||||
uint64_t map_type_number = 0;
|
||||
auto body_node = node.child("body");
|
||||
|
||||
/**
|
||||
* Hold original blob in order to avoid situations when original blob is allocated on stack
|
||||
*/
|
||||
class WeightsHolderBlob : public TBlob<uint8_t> {
|
||||
Blob::CPtr originBlob;
|
||||
if (body_node.empty()) {
|
||||
THROW_IE_EXCEPTION << "Missing body part.";
|
||||
}
|
||||
|
||||
public:
|
||||
explicit WeightsHolderBlob(const Blob::CPtr& weights) :
|
||||
TBlob<uint8_t>(weights->getTensorDesc(),
|
||||
weights->cbuffer().as<uint8_t*>()),
|
||||
originBlob(weights) { }
|
||||
};
|
||||
// Fill map: parameter/result id to parameter/result number in Function
|
||||
FOREACH_CHILD(_layer, body_node.child("layers"), "layer") {
|
||||
auto type = XMLParseUtils::GetStrAttr(_layer, "type");
|
||||
|
||||
V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
|
||||
// Load default opsets
|
||||
opsets["opset1"] = ngraph::get_opset1();
|
||||
opsets["opset2"] = ngraph::get_opset2();
|
||||
opsets["opset3"] = ngraph::get_opset3();
|
||||
opsets["opset4"] = ngraph::get_opset4();
|
||||
opsets["opset5"] = ngraph::get_opset5();
|
||||
opsets["opset6"] = ngraph::get_opset6();
|
||||
|
||||
// Load custom opsets
|
||||
for (const auto& ext : exts) {
|
||||
std::map<std::string, ngraph::OpSet> extOpsets = ext->getOpSets();
|
||||
for (const auto& it : extOpsets) {
|
||||
if (opsets.find(it.first) != opsets.end())
|
||||
THROW_IE_EXCEPTION << "Cannot add opset with name: " << it.first << ". Opset with the same name already exists.";
|
||||
opsets[it.first] = it.second;
|
||||
if (type == map_type) {
|
||||
auto id = XMLParseUtils::GetUIntAttr(_layer, "id");
|
||||
type_id_in_function[id] = map_type_number;
|
||||
map_type_number++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> V10Parser::XmlDeserializer::parseInputDescription(const pugi::xml_node& node) {
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> inputs;
|
||||
std::map<uint64_t, uint64_t> param_id_in_function;
|
||||
std::map<uint64_t, uint64_t> result_id_in_function;
|
||||
map_type_in_function(node, "Parameter", param_id_in_function);
|
||||
map_type_in_function(node, "Result", result_id_in_function);
|
||||
|
||||
// Parse PortMap: external_port_id for inputs does not always appear in consecutive order
|
||||
std::map<uint64_t, pugi::xml_node> input_map;
|
||||
FOREACH_CHILD(_input, node.child("port_map"), "input") {
|
||||
int64_t ext_port_id = GetInt64Attr(_input, "external_port_id");
|
||||
input_map[ext_port_id] = _input;
|
||||
}
|
||||
|
||||
for (const auto& input : input_map) {
|
||||
auto &_input = input.second;
|
||||
auto axis_attr = _input.attribute("axis");
|
||||
auto purpose = XMLParseUtils::GetStrAttr(_input, "purpose", "");
|
||||
int64_t ti_input_index = XMLParseUtils::GetInt64Attr(_input, "external_port_id");
|
||||
size_t body_parameter_index = XMLParseUtils::GetUIntAttr(_input, "internal_layer_id");
|
||||
|
||||
// if axis is set, then slicing is enabled. Create ngraph::TensorIterator::SlicedInput.
|
||||
if (!axis_attr.empty()) {
|
||||
size_t axis = XMLParseUtils::GetUIntAttr(_input, "axis");
|
||||
int64_t start = XMLParseUtils::GetInt64Attr(_input, "start", 0);
|
||||
int64_t stride = XMLParseUtils::GetInt64Attr(_input, "stride", 1);
|
||||
int64_t end = XMLParseUtils::GetInt64Attr(_input, "end", -1);
|
||||
int64_t part_size = XMLParseUtils::GetInt64Attr(_input, "part_size", 1);
|
||||
|
||||
inputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>
|
||||
(ti_input_index,
|
||||
param_id_in_function[body_parameter_index],
|
||||
start,
|
||||
stride,
|
||||
part_size,
|
||||
end,
|
||||
axis));
|
||||
} else {
|
||||
// otherwise find corresponding back edge and create ngraph::TensorIterator::MergedInput
|
||||
bool is_back_edge_exist = false;
|
||||
FOREACH_CHILD(_edge, node.child("back_edges"), "edge") {
|
||||
size_t to_layer = XMLParseUtils::GetUIntAttr(_edge, "to-layer");
|
||||
|
||||
if (to_layer == body_parameter_index) {
|
||||
size_t from_layer = XMLParseUtils::GetUIntAttr(_edge, "from-layer");
|
||||
inputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>
|
||||
(ti_input_index,
|
||||
param_id_in_function[body_parameter_index],
|
||||
result_id_in_function[from_layer]));
|
||||
|
||||
is_back_edge_exist = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// ti_input_index = -1 means that Parameter of the body is not connected to inputs of TensorIterator
|
||||
// and is used only for internal needs.
|
||||
if (!is_back_edge_exist && ti_input_index >= 0) {
|
||||
inputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>
|
||||
(ti_input_index,
|
||||
param_id_in_function[body_parameter_index]));
|
||||
}
|
||||
}
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> V10Parser::XmlDeserializer::parseOutputDescription(const pugi::xml_node& node) {
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> outputs;
|
||||
std::map<uint64_t, uint64_t> result_id_in_function;
|
||||
map_type_in_function(node, "Result", result_id_in_function);
|
||||
|
||||
// Parse PortMap: outputs
|
||||
std::map<int64_t, pugi::xml_node> output_map;
|
||||
FOREACH_CHILD(_output, node.child("port_map"), "output") {
|
||||
int64_t ext_port_id = GetInt64Attr(_output, "external_port_id");
|
||||
output_map[ext_port_id] = _output;
|
||||
}
|
||||
|
||||
uint64_t output_number = 0;
|
||||
for (const auto& output : output_map) {
|
||||
auto& _output = output.second;
|
||||
auto axis_attr = _output.attribute("axis");
|
||||
auto purpose = XMLParseUtils::GetStrAttr(_output, "purpose", "");
|
||||
size_t body_result_index = XMLParseUtils::GetUIntAttr(_output, "internal_layer_id");
|
||||
|
||||
// if axis is set, then concatenation is enabled. Create ngraph::TensorIterator::ConcatOutput.
|
||||
if (!axis_attr.empty()) {
|
||||
int64_t axis = XMLParseUtils::GetInt64Attr(_output, "axis");
|
||||
int64_t start = XMLParseUtils::GetInt64Attr(_output, "start", 0);
|
||||
int64_t stride = XMLParseUtils::GetInt64Attr(_output, "stride", 1);
|
||||
int64_t end = XMLParseUtils::GetInt64Attr(_output, "end", -1);
|
||||
int64_t part_size = XMLParseUtils::GetInt64Attr(_output, "part_size", 1);
|
||||
|
||||
outputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>
|
||||
(result_id_in_function[body_result_index],
|
||||
output_number,
|
||||
start,
|
||||
stride,
|
||||
part_size,
|
||||
end,
|
||||
axis));
|
||||
} else {
|
||||
// otherwise create ngraph::TensorIterator::BodyOutput. -1 means last iteration.
|
||||
outputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>
|
||||
(result_id_in_function[body_result_index],
|
||||
output_number,
|
||||
-1));
|
||||
}
|
||||
output_number++;
|
||||
}
|
||||
return outputs;
|
||||
}
|
||||
|
||||
void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
|
||||
std::string val;
|
||||
|
||||
// for TensorIterator look for 'port_map' as 'data' does not exist
|
||||
if (node.child("port_map")) {
|
||||
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<std::shared_ptr
|
||||
<ngraph::op::util::SubGraphOp::InputDescription>>>>(&adapter)) {
|
||||
a->set(parseInputDescription(node));
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<std::shared_ptr
|
||||
<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
|
||||
a->set(parseOutputDescription(node));
|
||||
}
|
||||
}
|
||||
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::element::Type>>(&adapter)) {
|
||||
static_cast<ngraph::element::Type&>(*a) = details::convertPrecision(val);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::PartialShape>>(&adapter)) {
|
||||
std::vector<int64_t> shape;
|
||||
std::vector<ngraph::Dimension> dims;
|
||||
if (!getParameters<int64_t>(node.child("data"), name, shape)) return;
|
||||
for (const auto& dim : shape) dims.emplace_back(dim);
|
||||
static_cast<ngraph::PartialShape&>(*a) = ngraph::PartialShape(dims);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Shape>>(&adapter)) {
|
||||
std::vector<size_t> shape;
|
||||
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
|
||||
static_cast<ngraph::Shape&>(*a) = ngraph::Shape(shape);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Strides>>(&adapter)) {
|
||||
std::vector<size_t> shape;
|
||||
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
|
||||
static_cast<ngraph::Strides&>(*a) = ngraph::Strides(shape);
|
||||
#ifdef __APPLE__
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
|
||||
std::vector<size_t> result;
|
||||
if (!getParameters<size_t>(node.child("data"), name, result)) return;
|
||||
static_cast<std::vector<size_t>&>(*a) = result;
|
||||
#else
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
|
||||
std::vector<size_t> result;
|
||||
if (!getParameters<size_t>(node.child("data"), name, result)) return;
|
||||
a->set(result);
|
||||
#endif
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::AxisSet>>(&adapter)) {
|
||||
std::vector<size_t> axes;
|
||||
if (!getParameters<size_t>(node.child("data"), name, axes)) return;
|
||||
static_cast<ngraph::AxisSet&>(*a) = ngraph::AxisSet(axes);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKSortType>>(&adapter)) {
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
static_cast<ngraph::op::TopKSortType&>(*a) = ngraph::as_enum<ngraph::op::TopKSortType>(val);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::CoordinateDiff>>(&adapter)) {
|
||||
std::vector<size_t> shape;
|
||||
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
|
||||
std::vector<std::ptrdiff_t> coord_diff(shape.begin(), shape.end());
|
||||
static_cast<ngraph::CoordinateDiff&>(*a) = ngraph::CoordinateDiff(coord_diff);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
|
||||
<< " parameter";
|
||||
}
|
||||
}
|
||||
|
||||
void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
|
||||
std::shared_ptr<ngraph::Function> ngraph_function;
|
||||
if (!name.compare("body")) {
|
||||
auto body_node = node.child(name.c_str());
|
||||
if (body_node.empty()) {
|
||||
THROW_IE_EXCEPTION << "TensorIterator has no body.";
|
||||
}
|
||||
ngraph_function = parse_function(node.child(name.c_str()), weights);
|
||||
} else if (!name.compare("net")) {
|
||||
ngraph_function = parse_function(node, weights);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Error: not recognized adapter name: " << name << ".";
|
||||
}
|
||||
// Disabled reshape for generic operations in the TI body
|
||||
ngraph::op::GenericIE::DisableReshape noReshape(ngraph_function);
|
||||
adapter.set(ngraph_function);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> V10Parser::XmlDeserializer::parse_function(const pugi::xml_node& root, const Blob::CPtr& weights) {
|
||||
OV_ITT_TASK_CHAIN(taskChain, itt::domains::V10Reader_RT, "V10Parser", "Parse");
|
||||
|
||||
using node_params = struct {
|
||||
@ -202,8 +382,53 @@ std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, const
|
||||
|
||||
OV_ITT_TASK_NEXT(taskChain, "ConstructCNNNetwork");
|
||||
|
||||
CNNNetwork net(function, _exts);
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> IRParser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
|
||||
return parser->parse(root, weights);
|
||||
}
|
||||
|
||||
/**
|
||||
* Hold original blob in order to avoid situations when original blob is allocated on stack
|
||||
*/
|
||||
class WeightsHolderBlob : public TBlob<uint8_t> {
|
||||
Blob::CPtr originBlob;
|
||||
|
||||
public:
|
||||
explicit WeightsHolderBlob(const Blob::CPtr& weights) :
|
||||
TBlob<uint8_t>(weights->getTensorDesc(),
|
||||
weights->cbuffer().as<uint8_t*>()),
|
||||
originBlob(weights) { }
|
||||
};
|
||||
|
||||
V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
|
||||
// Load default opsets
|
||||
opsets["opset1"] = ngraph::get_opset1();
|
||||
opsets["opset2"] = ngraph::get_opset2();
|
||||
opsets["opset3"] = ngraph::get_opset3();
|
||||
opsets["opset4"] = ngraph::get_opset4();
|
||||
opsets["opset5"] = ngraph::get_opset5();
|
||||
opsets["opset6"] = ngraph::get_opset6();
|
||||
|
||||
// Load custom opsets
|
||||
for (const auto& ext : exts) {
|
||||
for (const auto& it : ext->getOpSets()) {
|
||||
if (opsets.find(it.first) != opsets.end())
|
||||
THROW_IE_EXCEPTION << "Cannot add opset with name: " << it.first << ". Opset with the same name already exists.";
|
||||
opsets[it.first] = it.second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
|
||||
OV_ITT_TASK_CHAIN(taskChain, itt::domains::V10Reader_RT, "V10Parser", "Parse");
|
||||
|
||||
std::shared_ptr<ngraph::Function> function;
|
||||
XmlDeserializer visitor(root, weights, opsets);
|
||||
visitor.on_attribute("net", function);
|
||||
|
||||
CNNNetwork net(function, _exts);
|
||||
parsePreProcess(net, root, weights);
|
||||
|
||||
return net;
|
||||
@ -335,7 +560,7 @@ void V10Parser::parsePreProcess(CNNNetwork& network, const pugi::xml_node& root,
|
||||
}
|
||||
}
|
||||
|
||||
V10Parser::GenericLayerParams V10Parser::parseGenericParams(const pugi::xml_node& node) {
|
||||
V10Parser::GenericLayerParams V10Parser::XmlDeserializer::parseGenericParams(const pugi::xml_node& node) {
|
||||
const auto parsePort = [](const pugi::xml_node& parentNode,
|
||||
const GenericLayerParams& params,
|
||||
bool input) -> GenericLayerParams::LayerPortData {
|
||||
@ -392,8 +617,10 @@ bool V10Parser::LayerBaseCreator::shouldCreate(const std::string& nodeType) cons
|
||||
return comparator(nodeType, type);
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Output<ngraph::Node>>& inputs,
|
||||
const pugi::xml_node& node, const Blob::CPtr& weights,
|
||||
std::shared_ptr<ngraph::Node> V10Parser::XmlDeserializer::createNode(
|
||||
const std::vector<ngraph::Output<ngraph::Node>>& inputs,
|
||||
const pugi::xml_node& node,
|
||||
const Blob::CPtr& weights,
|
||||
const GenericLayerParams& params) {
|
||||
static std::vector<std::shared_ptr<LayerBaseCreator>> creators = {
|
||||
std::make_shared<LayerCreator<ngraph::op::v1::DeformableConvolution>>("DeformableConvolution"),
|
||||
@ -412,7 +639,6 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
||||
std::make_shared<LayerCreator<ngraph::op::Result>>("Result"),
|
||||
std::make_shared<LayerCreator<ngraph::op::PSROIPooling>>("PSROIPooling"),
|
||||
std::make_shared<LayerCreator<ngraph::op::VariadicSplit>>("VariadicSplit"),
|
||||
std::make_shared<LayerCreator<ngraph::op::TensorIterator>>("TensorIterator"),
|
||||
std::make_shared<LayerCreator<ngraph::opset5::Loop>>("Loop"),
|
||||
std::make_shared<LayerCreator<ngraph::op::v1::LogicalAnd>>("LogicalAnd"),
|
||||
std::make_shared<LayerCreator<ngraph::op::v1::LogicalOr>>("LogicalOr"),
|
||||
@ -483,7 +709,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
|
||||
ngraphNode = std::shared_ptr<ngraph::Node>(opset.create_insensitive(type));
|
||||
ngraphNode->set_friendly_name(params.name);
|
||||
ngraphNode->set_arguments(inputs);
|
||||
XmlDeserializer visitor(node, weights);
|
||||
XmlDeserializer visitor(node, weights, opsets);
|
||||
if (ngraphNode->visit_attributes(visitor))
|
||||
ngraphNode->constructor_validate_and_infer_types();
|
||||
}
|
||||
|
@ -8,6 +8,7 @@
|
||||
# include <ngraph/node.hpp>
|
||||
# include <ngraph/op/util/sub_graph_base.hpp>
|
||||
# include <ie_ngraph_utils.hpp>
|
||||
# include <ngraph/opsets/opset.hpp>
|
||||
#endif // IR_READER_V10
|
||||
|
||||
#include <ie_blob.h>
|
||||
@ -51,7 +52,6 @@ public:
|
||||
};
|
||||
|
||||
#ifdef IR_READER_V10
|
||||
|
||||
class V10Parser : public IParser {
|
||||
public:
|
||||
explicit V10Parser(const std::vector<IExtensionPtr>& exts);
|
||||
@ -171,10 +171,6 @@ private:
|
||||
}
|
||||
};
|
||||
|
||||
std::shared_ptr<ngraph::Node> createNode(const ngraph::OutputVector& inputs, const pugi::xml_node& node,
|
||||
const Blob::CPtr& weights, const GenericLayerParams& params);
|
||||
|
||||
GenericLayerParams parseGenericParams(const pugi::xml_node& node);
|
||||
void parsePreProcess(CNNNetwork& network, const pugi::xml_node& root, const Blob::CPtr& weights);
|
||||
|
||||
std::map<std::string, DataPtr> portsToData;
|
||||
@ -182,7 +178,8 @@ private:
|
||||
|
||||
class XmlDeserializer : public ngraph::AttributeVisitor {
|
||||
public:
|
||||
explicit XmlDeserializer(const pugi::xml_node& node, const Blob::CPtr& weights): node(node), weights(weights) {}
|
||||
explicit XmlDeserializer(const pugi::xml_node& node, const Blob::CPtr& weights,
|
||||
const std::map<std::string, ngraph::OpSet>& opsets) : node(node), weights(weights), opsets(opsets) {}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::string>& value) override {
|
||||
std::string val;
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
@ -203,56 +200,7 @@ private:
|
||||
if (!is_true && !is_false) return;
|
||||
value.set(is_true);
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
|
||||
std::string val;
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::element::Type>>(&adapter)) {
|
||||
static_cast<ngraph::element::Type&>(*a) = details::convertPrecision(val);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::PartialShape>>(&adapter)) {
|
||||
std::vector<int64_t> shape;
|
||||
std::vector<ngraph::Dimension> dims;
|
||||
if (!getParameters<int64_t>(node.child("data"), name, shape)) return;
|
||||
for (const auto& dim : shape) dims.emplace_back(dim);
|
||||
static_cast<ngraph::PartialShape&>(*a) = ngraph::PartialShape(dims);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Shape>>(&adapter)) {
|
||||
std::vector<size_t> shape;
|
||||
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
|
||||
static_cast<ngraph::Shape&>(*a) = ngraph::Shape(shape);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Strides>>(&adapter)) {
|
||||
std::vector<size_t> shape;
|
||||
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
|
||||
static_cast<ngraph::Strides&>(*a) = ngraph::Strides(shape);
|
||||
#ifdef __APPLE__
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
|
||||
std::vector<size_t> result;
|
||||
if (!getParameters<size_t>(node.child("data"), name, result)) return;
|
||||
static_cast<std::vector<size_t>&>(*a) = result;
|
||||
#else
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
|
||||
std::vector<size_t> result;
|
||||
if (!getParameters<size_t>(node.child("data"), name, result)) return;
|
||||
a->set(result);
|
||||
#endif
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::AxisSet>>(&adapter)) {
|
||||
std::vector<size_t> axes;
|
||||
if (!getParameters<size_t>(node.child("data"), name, axes)) return;
|
||||
static_cast<ngraph::AxisSet&>(*a) = ngraph::AxisSet(axes);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKSortType>>(&adapter)) {
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
static_cast<ngraph::op::TopKSortType&>(*a) = ngraph::as_enum<ngraph::op::TopKSortType>(val);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
|
||||
if (!getStrAttribute(node.child("data"), name, val)) return;
|
||||
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
|
||||
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::CoordinateDiff>>(&adapter)) {
|
||||
std::vector<size_t> shape;
|
||||
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
|
||||
std::vector<std::ptrdiff_t> coord_diff(shape.begin(), shape.end());
|
||||
static_cast<ngraph::CoordinateDiff&>(*a) = ngraph::CoordinateDiff(coord_diff);
|
||||
} else {
|
||||
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
|
||||
<< " parameter";
|
||||
}
|
||||
}
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
|
||||
std::string val;
|
||||
if (!getStrAttribute(node.child("data"), name, val))
|
||||
@ -307,6 +255,8 @@ private:
|
||||
adapter.set(value);
|
||||
}
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) override;
|
||||
|
||||
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override {
|
||||
std::vector<int32_t> value;
|
||||
if (!getParameters<int32_t>(node.child("data"), name, value)) return;
|
||||
@ -334,6 +284,32 @@ private:
|
||||
private:
|
||||
const pugi::xml_node node;
|
||||
const Blob::CPtr& weights;
|
||||
const std::map<std::string, ngraph::OpSet>& opsets;
|
||||
/// \brief Traverses port_map in order to create vector of InputDescription shared_ptrs.
|
||||
/// Shall be used only for ops which have port_map attribute.
|
||||
/// \param node xml op representation
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> parseInputDescription(
|
||||
const pugi::xml_node& node);
|
||||
/// \brief Traverses port_map in order to create vector of OutputDescription shared_ptrs.
|
||||
/// Shall be used only for ops which have port_map attribute.
|
||||
/// \param node xml op representation
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> parseOutputDescription(
|
||||
const pugi::xml_node& node);
|
||||
/// \brief Traverses nGraph body function for specified op type and creates a map of all
|
||||
/// op iterations. Map constains type id and assigned to it consecutive number starting from 0.
|
||||
/// \param node xml op representation
|
||||
/// \param type op type name to find
|
||||
/// \param type_id_in_function map container
|
||||
void map_type_in_function(const pugi::xml_node& node, std::string type, std::map<uint64_t, uint64_t>& type_id_in_function);
|
||||
/// \brief Traverses xml node representation in order to create nGraph function for it.
|
||||
/// \param node xml node representation
|
||||
/// \param weights weights attached to current node
|
||||
/// \return shared pointer to function representing input node
|
||||
std::shared_ptr<ngraph::Function> parse_function(const pugi::xml_node& root, const Blob::CPtr& weights);
|
||||
|
||||
GenericLayerParams parseGenericParams(const pugi::xml_node& node);
|
||||
std::shared_ptr<ngraph::Node> createNode(const ngraph::OutputVector& inputs, const pugi::xml_node& node,
|
||||
const Blob::CPtr& weights, const GenericLayerParams& params);
|
||||
|
||||
bool getStrAttribute(const pugi::xml_node& node, const std::string& name, std::string& value) {
|
||||
if (!node) return false;
|
||||
|
@ -30,6 +30,7 @@ namespace ngraph
|
||||
class ValueAccessor;
|
||||
class VisitorAdapter;
|
||||
class Node;
|
||||
class Function;
|
||||
|
||||
/// \brief Visits the attributes of a node, primarily for serialization-like tasks.
|
||||
///
|
||||
@ -116,6 +117,12 @@ namespace ngraph
|
||||
/// \brief Hook for adapters that need visitor access
|
||||
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
|
||||
|
||||
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
|
||||
/// \param name attribute name
|
||||
/// \param adapter reference to a Function ValueAccessor<VAT>
|
||||
virtual void on_adapter(const std::string& name,
|
||||
ValueAccessor<std::shared_ptr<Function>>& adapter);
|
||||
|
||||
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
|
||||
/// to a ValueAccessor<U> for one of the on_adpater methods.
|
||||
template <typename AT>
|
||||
|
@ -190,16 +190,17 @@ namespace ngraph
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public VisitorAdapter
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>>
|
||||
: public DirectValueAccessor<std::shared_ptr<Function>>
|
||||
{
|
||||
public:
|
||||
AttributeAdapter(std::shared_ptr<Function>& ref);
|
||||
AttributeAdapter(std::shared_ptr<Function>& value)
|
||||
: DirectValueAccessor<std::shared_ptr<Function>>(value)
|
||||
{
|
||||
}
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<shared_ptr<Function>>", 0};
|
||||
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::shared_ptr<Function>& m_ref;
|
||||
};
|
||||
}
|
||||
|
@ -17,7 +17,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/op/parameter.hpp>
|
||||
#include "ngraph/factory_adapter.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph
|
||||
@ -50,7 +49,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<InputDescription> copy() const = 0;
|
||||
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
|
||||
uint64_t m_input_index{0};
|
||||
uint64_t m_body_parameter_index{0};
|
||||
@ -85,7 +83,6 @@ namespace ngraph
|
||||
int64_t axis);
|
||||
SliceInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
@ -118,7 +115,6 @@ namespace ngraph
|
||||
uint64_t body_value_index);
|
||||
MergedInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
uint64_t m_body_value_index{0};
|
||||
};
|
||||
|
||||
@ -140,7 +136,6 @@ namespace ngraph
|
||||
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
|
||||
InvariantInputDescription() = default;
|
||||
std::shared_ptr<InputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
};
|
||||
|
||||
/// \brief Describes how a SubGraphOp output is produced from the body.
|
||||
@ -160,7 +155,6 @@ namespace ngraph
|
||||
using type_info_t = DiscreteTypeInfo;
|
||||
virtual ~OutputDescription() = default;
|
||||
virtual std::shared_ptr<OutputDescription> copy() const = 0;
|
||||
virtual bool visit_attributes(AttributeVisitor& visitor);
|
||||
virtual const type_info_t& get_type_info() const = 0;
|
||||
|
||||
uint64_t m_body_value_index{0};
|
||||
@ -194,7 +188,6 @@ namespace ngraph
|
||||
ConcatOutputDescription() = default;
|
||||
|
||||
std::shared_ptr<OutputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_start{0};
|
||||
int64_t m_stride{0};
|
||||
int64_t m_part_size{0};
|
||||
@ -221,7 +214,6 @@ namespace ngraph
|
||||
int64_t iteration);
|
||||
BodyOutputDescription() = default;
|
||||
std::shared_ptr<OutputDescription> copy() const override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
int64_t m_iteration{0};
|
||||
};
|
||||
|
||||
@ -347,79 +339,48 @@ namespace ngraph
|
||||
using OutputDescriptionVector = std::vector<OutputDescriptionPtr>;
|
||||
}
|
||||
}
|
||||
template class NGRAPH_API FactoryRegistry<op::util::SubGraphOp::InputDescription>;
|
||||
|
||||
template <>
|
||||
FactoryRegistry<op::util::SubGraphOp::InputDescription>&
|
||||
FactoryRegistry<op::util::SubGraphOp::InputDescription>::get();
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
|
||||
: public FactoryAttributeAdapter<op::util::SubGraphOp::InputDescription>
|
||||
class NGRAPH_API AttributeAdapter<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
|
||||
: public DirectValueAccessor<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
AttributeAdapter(
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>& value)
|
||||
: DirectValueAccessor<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>(
|
||||
value)
|
||||
{
|
||||
}
|
||||
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>"
|
||||
">>",
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
|
||||
"InputDescription>>>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>
|
||||
: public VisitorAdapter
|
||||
class NGRAPH_API AttributeAdapter<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
|
||||
: public DirectValueAccessor<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
|
||||
{
|
||||
public:
|
||||
explicit AttributeAdapter(
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& ref);
|
||||
AttributeAdapter(
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>& value)
|
||||
: DirectValueAccessor<
|
||||
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>(
|
||||
value)
|
||||
{
|
||||
}
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>"
|
||||
">>",
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
|
||||
"OutputDescription>>>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& m_ref;
|
||||
};
|
||||
|
||||
template class NGRAPH_API FactoryRegistry<op::util::SubGraphOp::OutputDescription>;
|
||||
|
||||
template <>
|
||||
FactoryRegistry<op::util::SubGraphOp::OutputDescription>&
|
||||
FactoryRegistry<op::util::SubGraphOp::OutputDescription>::get();
|
||||
|
||||
template <>
|
||||
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
|
||||
: public FactoryAttributeAdapter<op::util::SubGraphOp::OutputDescription>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
};
|
||||
|
||||
template <>
|
||||
class NGRAPH_API
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>
|
||||
: public VisitorAdapter
|
||||
{
|
||||
public:
|
||||
explicit AttributeAdapter(
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& ref);
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
static constexpr DiscreteTypeInfo type_info{
|
||||
"AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>"
|
||||
">>",
|
||||
0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
protected:
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& m_ref;
|
||||
};
|
||||
}
|
||||
|
@ -170,6 +170,12 @@ void AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<
|
||||
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
|
||||
}
|
||||
|
||||
void AttributeVisitor::on_adapter(const string& name,
|
||||
ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter)
|
||||
{
|
||||
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
|
||||
}
|
||||
|
||||
const AttributeVisitor::node_id_t AttributeVisitor::invalid_node_id = "";
|
||||
|
||||
void AttributeVisitor::register_node(const std::shared_ptr<Node>& node, node_id_t id)
|
||||
|
@ -19,7 +19,6 @@
|
||||
#include <memory>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/factory_adapter.hpp"
|
||||
#include "ngraph/function.hpp"
|
||||
#include "ngraph/graph_util.hpp"
|
||||
#include "ngraph/log.hpp"
|
||||
@ -407,269 +406,3 @@ void Function::remove_result(const std::shared_ptr<op::Result>& result)
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
|
||||
|
||||
AttributeAdapter<shared_ptr<Function>>::AttributeAdapter(shared_ptr<Function>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
class NodeAttributeAdapter : public FactoryAttributeAdapter<Node>
|
||||
{
|
||||
public:
|
||||
using FactoryAttributeAdapter::FactoryAttributeAdapter;
|
||||
bool on_start(AttributeVisitor& visitor) override
|
||||
{
|
||||
// Indicate that there is a node following
|
||||
m_id = visitor.get_registered_node_id(m_ref);
|
||||
m_set_id = (m_ref == nullptr);
|
||||
visitor.on_attribute("id", m_id);
|
||||
return m_ref == nullptr || m_id != AttributeVisitor::invalid_node_id;
|
||||
}
|
||||
bool on_finish(AttributeVisitor&) override
|
||||
{
|
||||
if (m_set_id && m_ref)
|
||||
{
|
||||
m_ref->set_friendly_name(m_id);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
void visit(AttributeVisitor& visitor, const std::string& id)
|
||||
{
|
||||
visitor.start_structure(id);
|
||||
visitor.on_adapter(id, *this);
|
||||
visitor.finish_structure();
|
||||
}
|
||||
static constexpr DiscreteTypeInfo type_info{"Lambda.NodeAttributeAdapter", 0};
|
||||
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
|
||||
string m_id;
|
||||
bool m_set_id;
|
||||
};
|
||||
|
||||
constexpr DiscreteTypeInfo NodeAttributeAdapter::type_info;
|
||||
|
||||
bool AttributeAdapter<shared_ptr<Function>>::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
if (m_ref->get_results().size() > 0)
|
||||
{
|
||||
NodeVector serialized_nodes;
|
||||
{
|
||||
// Start with all nodes not already serialized
|
||||
visitor.start_structure("nodes");
|
||||
NodeVector results;
|
||||
for (auto result : m_ref->get_results())
|
||||
{
|
||||
results.push_back(result);
|
||||
}
|
||||
for (auto sink : m_ref->get_sinks())
|
||||
{
|
||||
results.push_back(sink);
|
||||
}
|
||||
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
traverse_nodes(
|
||||
results, [&i, &index, &visitor, &serialized_nodes](shared_ptr<Node> node) -> void {
|
||||
if (AttributeVisitor::invalid_node_id == visitor.get_registered_node_id(node))
|
||||
{
|
||||
// This node hasn't been seen before
|
||||
visitor.register_node(node);
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
NodeAttributeAdapter adapter(node);
|
||||
adapter.visit(visitor, id);
|
||||
serialized_nodes.push_back(node);
|
||||
}
|
||||
});
|
||||
{
|
||||
// Sentinel at end
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
shared_ptr<Node> null_node;
|
||||
NodeAttributeAdapter adapter(null_node);
|
||||
adapter.visit(visitor, id);
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
{
|
||||
// Now do all the edges
|
||||
visitor.start_structure("edges");
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
for (auto node : serialized_nodes)
|
||||
{
|
||||
for (auto input : node->inputs())
|
||||
{
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
visitor.start_structure(id);
|
||||
string input_node_id = visitor.get_registered_node_id(node);
|
||||
uint64_t input_index = input.get_index();
|
||||
visitor.on_attribute("input_node", input_node_id);
|
||||
visitor.on_attribute("input_index", input_index);
|
||||
auto output = input.get_source_output();
|
||||
string output_node_id =
|
||||
visitor.get_registered_node_id(output.get_node_shared_ptr());
|
||||
uint64_t output_index = output.get_index();
|
||||
visitor.on_attribute("output_node", output_node_id);
|
||||
visitor.on_attribute("output_index", output_index);
|
||||
visitor.finish_structure();
|
||||
}
|
||||
}
|
||||
{
|
||||
// Add a sentinel
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
visitor.start_structure(id);
|
||||
string input_node_id = AttributeVisitor::invalid_node_id;
|
||||
visitor.on_attribute("input_node", input_node_id);
|
||||
visitor.finish_structure();
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
{
|
||||
// Control dependencies
|
||||
visitor.start_structure("control");
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
for (auto node : serialized_nodes)
|
||||
{
|
||||
for (auto control : node->get_control_dependencies())
|
||||
{
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
visitor.start_structure(id);
|
||||
string node_id = visitor.get_registered_node_id(node);
|
||||
string dependency_id = visitor.get_registered_node_id(control);
|
||||
visitor.on_attribute("node", node_id);
|
||||
visitor.on_attribute("dependency", dependency_id);
|
||||
visitor.finish_structure();
|
||||
}
|
||||
}
|
||||
{
|
||||
// Add a sentinel
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
visitor.start_structure(id);
|
||||
string node_id = AttributeVisitor::invalid_node_id;
|
||||
visitor.on_attribute("node", node_id);
|
||||
visitor.finish_structure();
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
NodeVector deserialized_nodes;
|
||||
{
|
||||
// Read the graph
|
||||
visitor.start_structure("nodes");
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
while (true)
|
||||
{
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
shared_ptr<Node> node;
|
||||
NodeAttributeAdapter adapter(node);
|
||||
adapter.visit(visitor, id);
|
||||
if (node)
|
||||
{
|
||||
visitor.register_node(node);
|
||||
deserialized_nodes.push_back(node);
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
{
|
||||
visitor.start_structure("edges");
|
||||
// Connect the nodes
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
bool more_edges = true;
|
||||
while (more_edges)
|
||||
{
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
visitor.start_structure(id);
|
||||
string input_node_id;
|
||||
visitor.on_attribute("input_node", input_node_id);
|
||||
if (!input_node_id.empty())
|
||||
{
|
||||
shared_ptr<Node> input_node = visitor.get_registered_node(input_node_id);
|
||||
NGRAPH_CHECK(input_node, "input node of edge not known");
|
||||
uint64_t input_index;
|
||||
string output_node_id;
|
||||
uint64_t output_index;
|
||||
visitor.on_attribute("input_index", input_index);
|
||||
visitor.on_attribute("output_node", output_node_id);
|
||||
visitor.on_attribute("output_index", output_index);
|
||||
shared_ptr<Node> output_node = visitor.get_registered_node(output_node_id);
|
||||
NGRAPH_CHECK(output_node, "output_node of edge not known");
|
||||
input_node->set_argument(input_index, output_node->output(output_index));
|
||||
}
|
||||
else
|
||||
{
|
||||
more_edges = false;
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
{
|
||||
// Control dependencies
|
||||
visitor.start_structure("control");
|
||||
int64_t i = 0;
|
||||
ostringstream index;
|
||||
bool more_control = true;
|
||||
while (more_control)
|
||||
{
|
||||
index.str("");
|
||||
index << i++;
|
||||
string id = index.str();
|
||||
visitor.start_structure(id);
|
||||
string node_id;
|
||||
visitor.on_attribute("node", node_id);
|
||||
if (!node_id.empty())
|
||||
{
|
||||
shared_ptr<Node> node = visitor.get_registered_node(node_id);
|
||||
NGRAPH_CHECK(node, "node of control edge not known");
|
||||
string dependency_id;
|
||||
visitor.on_attribute("dependency", dependency_id);
|
||||
shared_ptr<Node> dependency = visitor.get_registered_node(dependency_id);
|
||||
NGRAPH_CHECK(dependency, "dependency of control edge not known");
|
||||
node->add_control_dependency(dependency);
|
||||
}
|
||||
else
|
||||
{
|
||||
more_control = false;
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
visitor.finish_structure();
|
||||
}
|
||||
for (auto node : topological_sort(deserialized_nodes))
|
||||
{
|
||||
node->validate_and_infer_types();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// Finally visit the object attributes
|
||||
visitor.start_structure("value");
|
||||
m_ref->visit_attributes(visitor);
|
||||
visitor.finish_structure();
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -35,7 +35,15 @@ bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor)
|
||||
visitor.on_attribute("input_descriptions", m_input_descriptions);
|
||||
visitor.on_attribute("output_descriptions", m_output_descriptions);
|
||||
|
||||
return false;
|
||||
for (const auto& output_description : m_output_descriptions)
|
||||
{
|
||||
if (auto concat = as_type_ptr<ConcatOutputDescription>(output_description))
|
||||
{
|
||||
m_num_iterations = ((std::abs(concat->m_end - concat->m_start)) / concat->m_part_size);
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void op::v0::TensorIterator::revalidate_and_infer_types_for_body_ops()
|
||||
@ -88,10 +96,6 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
get_input_size() == m_input_descriptions.size(),
|
||||
"Number of inputs must be the same as number of input descriptions");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_output_size() == m_output_descriptions.size(),
|
||||
"Number of outputs must be the same as number of output descriptions");
|
||||
|
||||
std::vector<std::shared_ptr<Node>> ends;
|
||||
|
||||
auto make_positive = [](int64_t value, uint64_t dim_size) -> int64_t {
|
||||
@ -226,6 +230,10 @@ void op::v0::TensorIterator::validate_and_infer_types()
|
||||
set_output_type(index, body_value.get_element_type(), body_value.get_partial_shape());
|
||||
}
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
get_output_size() == m_output_descriptions.size(),
|
||||
"Number of outputs must be the same as number of output descriptions");
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> op::v0::TensorIterator::get_function()
|
||||
|
@ -35,13 +35,6 @@ op::util::SubGraphOp::InputDescription::InputDescription(uint64_t input_index,
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::InputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("input_index", m_input_index);
|
||||
visitor.on_attribute("body_parameter_index", m_body_parameter_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::SliceInputDescription::SliceInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
int64_t start,
|
||||
@ -65,17 +58,6 @@ std::shared_ptr<op::util::SubGraphOp::InputDescription>
|
||||
m_input_index, m_body_parameter_index, m_start, m_stride, m_part_size, m_end, m_axis);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::SliceInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("start", m_start);
|
||||
visitor.on_attribute("stride", m_stride);
|
||||
visitor.on_attribute("part_size", m_part_size);
|
||||
visitor.on_attribute("end", m_end);
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::MergedInputDescription::MergedInputDescription(uint64_t input_index,
|
||||
uint64_t body_parameter_index,
|
||||
uint64_t body_value_index)
|
||||
@ -91,13 +73,6 @@ std::shared_ptr<op::util::SubGraphOp::InputDescription>
|
||||
m_input_index, m_body_parameter_index, m_body_value_index);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::MergedInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("body_value_index", m_body_value_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::InvariantInputDescription::InvariantInputDescription(
|
||||
uint64_t input_index, uint64_t body_parameter_index)
|
||||
: InputDescription(input_index, body_parameter_index)
|
||||
@ -110,12 +85,6 @@ std::shared_ptr<op::util::SubGraphOp::InputDescription>
|
||||
return std::make_shared<InvariantInputDescription>(m_input_index, m_body_parameter_index);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::InvariantInputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
InputDescription::visit_attributes(visitor);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index)
|
||||
: m_body_value_index(body_value_index)
|
||||
@ -123,13 +92,6 @@ op::util::SubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_i
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::OutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("body_value_index", m_body_value_index);
|
||||
visitor.on_attribute("output_index", m_output_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::ConcatOutputDescription::ConcatOutputDescription(uint64_t body_value_index,
|
||||
uint64_t output_index,
|
||||
int64_t start,
|
||||
@ -146,17 +108,6 @@ op::util::SubGraphOp::ConcatOutputDescription::ConcatOutputDescription(uint64_t
|
||||
{
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::ConcatOutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
OutputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("start", m_start);
|
||||
visitor.on_attribute("stride", m_stride);
|
||||
visitor.on_attribute("part_size", m_part_size);
|
||||
visitor.on_attribute("end", m_end);
|
||||
visitor.on_attribute("axis", m_axis);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::shared_ptr<op::util::SubGraphOp::OutputDescription>
|
||||
op::util::SubGraphOp::ConcatOutputDescription::copy() const
|
||||
{
|
||||
@ -178,13 +129,6 @@ std::shared_ptr<op::util::SubGraphOp::OutputDescription>
|
||||
return std::make_shared<BodyOutputDescription>(m_body_value_index, m_output_index, m_iteration);
|
||||
}
|
||||
|
||||
bool op::util::SubGraphOp::BodyOutputDescription::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
OutputDescription::visit_attributes(visitor);
|
||||
visitor.on_attribute("iteration", m_iteration);
|
||||
return true;
|
||||
}
|
||||
|
||||
op::util::SubGraphOp::SubGraphOp(const OutputVector& args)
|
||||
: Op(args)
|
||||
{
|
||||
@ -257,103 +201,9 @@ Input<Node> op::util::SubGraphOp::input_for_value(const Output<Node>& value)
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
template <>
|
||||
FactoryRegistry<op::util::SubGraphOp::InputDescription>&
|
||||
FactoryRegistry<op::util::SubGraphOp::InputDescription>::get()
|
||||
{
|
||||
static FactoryRegistry<op::util::SubGraphOp::InputDescription> registry;
|
||||
static std::mutex init_guard;
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_guard);
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
registry.register_factory<op::util::SubGraphOp::SliceInputDescription>();
|
||||
registry.register_factory<op::util::SubGraphOp::MergedInputDescription>();
|
||||
registry.register_factory<op::util::SubGraphOp::InvariantInputDescription>();
|
||||
}
|
||||
}
|
||||
return registry;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo
|
||||
AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::type_info;
|
||||
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
bool AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::
|
||||
visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
int64_t size = m_ref.size();
|
||||
visitor.on_attribute("size", size);
|
||||
if (size != m_ref.size())
|
||||
{
|
||||
m_ref.resize(size);
|
||||
}
|
||||
std::ostringstream index;
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
{
|
||||
index.str("");
|
||||
index << i;
|
||||
visitor.on_attribute(index.str(), m_ref[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
FactoryRegistry<op::util::SubGraphOp::OutputDescription>&
|
||||
FactoryRegistry<op::util::SubGraphOp::OutputDescription>::get()
|
||||
{
|
||||
static FactoryRegistry<op::util::SubGraphOp::OutputDescription> registry;
|
||||
static std::mutex init_guard;
|
||||
// TODO: Add a lock
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(init_guard);
|
||||
if (registry.m_factory_map.size() == 0)
|
||||
{
|
||||
registry.register_factory<op::util::SubGraphOp::ConcatOutputDescription>();
|
||||
registry.register_factory<op::util::SubGraphOp::BodyOutputDescription>();
|
||||
}
|
||||
}
|
||||
return registry;
|
||||
}
|
||||
|
||||
constexpr DiscreteTypeInfo AttributeAdapter<
|
||||
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::type_info;
|
||||
|
||||
constexpr DiscreteTypeInfo
|
||||
AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>::type_info;
|
||||
|
||||
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::
|
||||
AttributeAdapter(std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& ref)
|
||||
: m_ref(ref)
|
||||
{
|
||||
}
|
||||
|
||||
bool AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::
|
||||
visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
int64_t size = m_ref.size();
|
||||
visitor.on_attribute("size", size);
|
||||
if (size != m_ref.size())
|
||||
{
|
||||
m_ref.resize(size);
|
||||
}
|
||||
std::ostringstream index;
|
||||
for (int64_t i = 0; i < size; i++)
|
||||
{
|
||||
index.str("");
|
||||
index << i;
|
||||
visitor.on_attribute(index.str(), m_ref[i]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user