Visitor API TensorIterator implementation (#3576)

* TensorIterator deserialization. Introduce new on_adapter(Function) and add implementation in on_adapter(void) for Input and Output Descriptions. Remove factory adapter.

* Add comments to functions provided. Add missing on_adapter() after rebase.

* Apply formatting.

* Remove visit_attributes from SubGraphOp, remove declaration for createSubGraphLayer.

* Add port map parsing to address not consecutive order of external_port_id appearance.

* Remove header for factory_adapter.

* Add on_adapter() in V10Parse::parse() function.

* Add m_num_iterations initialization for concat output.

* Remove redundant lines, add doxygen comments.

* Change cpp/ie_cnn_network.h to local include, remove temporary map object from range for loop.

* Restore protected access for SubGraphOp.
This commit is contained in:
Szymon Durawa 2020-12-24 05:34:21 +01:00 committed by GitHub
parent 4f0720176c
commit 431485e4a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 577 additions and 590 deletions

View File

@ -27,7 +27,7 @@ convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function>& gr
const ICNNNetwork &ngraphNetwork,
CNNNetworkImpl* cnnNetworkImpl,
bool keep_constant_inputs = false);
// TODO: move ConstAllocatorWrapper class, shareWeights add addBlob into CNNLayerCreator when NodeConverter class is removed
class ConstAllocatorWrapper : public IAllocator {
public:

View File

@ -49,8 +49,11 @@
#include "transformations/utils/utils.hpp"
#include "transformations/rt_info/fused_names_attribute.hpp"
#include "transformations/rt_info/primitives_priority_attribute.hpp"
#include "cpp/ie_cnn_network.h"
#include "legacy/convert_function_to_cnn_network.hpp"
#include "legacy/graph_tools.hpp"
#include "legacy/net_pass.h"
#include "ie_legacy_itt.hpp"
#include "ie_cnn_layer_builder_ngraph.h"
@ -66,6 +69,210 @@ namespace details {
return nullptr;\
});\
/// \brief Creates legacy representation of CNNLayer for SubGraphOp.
/// \param layer node expected to be castable to ngraph::op::util::SubGraphOp
///        (TensorIterator / Loop).
/// \return pointer to CNNLayer with legacy representation of SubGraphOp.
/// \throws if the node is not a SubGraphOp or its body contains a cycle.
CNNLayer::Ptr createSubGraphLayer(const std::shared_ptr<ngraph::Node>& layer) {
    auto sub_graph = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp>(layer);
    if (!sub_graph) {
        THROW_IE_EXCEPTION << "Cannot cast layer to SubGraphOp.";
    }

    // inputs/outputs of TensorIterator (ngraph representation)
    auto parameters = sub_graph->get_function()->get_parameters();
    auto results = sub_graph->get_function()->get_results();

    // Convert body (ngraph representation) to CNNNetwork.
    // This network will contain nodes of type = "Input" and data nodes with wrong names.
    // IE TensorIterator doesn't include such nodes so we create CNNNetwork in a separate scope
    // to call the destructor and delete these "Input"/data nodes.
    TensorIterator::Body body;
    {
        InferenceEngine::CNNNetwork body_net(sub_graph->get_function());
        InferenceEngine::CNNNetwork net(InferenceEngine::details::convertFunctionToICNNNetwork(body_net.getFunction(), body_net));
        // Paranoid check for cycles
        bool res = CNNNetForestDFS(
            CNNNetGetAllInputLayers(net), [](const CNNLayerPtr& layer) {}, false);
        if (!res) {
            THROW_IE_EXCEPTION << "Loop detected. SubGraphOp body should not contain loops.";
        }

        // Get inputs/outputs of cnn network
        auto in_info_map_with_parameters = net.getInputsInfo();
        auto out_info_map = net.getOutputsInfo();

        IE_ASSERT(in_info_map_with_parameters.size() == parameters.size());
        IE_ASSERT(out_info_map.size() == results.size());

        InferenceEngine::TensorIterator::Body temp_body;
        temp_body.inputs.resize(in_info_map_with_parameters.size());
        temp_body.outputs.resize(out_info_map.size());

        // Fill inputs/outs in order aligned with ng representation
        uint64_t counter = 0;
        for (const auto& param : parameters) {
            auto info = in_info_map_with_parameters.at(param->get_friendly_name());
            temp_body.inputs[counter++] = info->getInputData();
        }

        // IE data names carry a ".<port>" suffix when the producing node has
        // multiple outputs; rebuild the expected name for each ngraph Result.
        // (shared_ptr taken by const& — no refcount churn per call)
        auto map_ng_result_to_ie_name = [] (const std::shared_ptr<ngraph::op::v0::Result>& res_op) {
            auto result = res_op->input(0).get_source_output();

            std::string name = result.get_node()->get_friendly_name();
            if (result.get_node()->get_output_size() > 1) {
                name += "." + std::to_string(result.get_index());
            }
            return name;
        };

        counter = 0;
        for (const auto& result : results) {
            auto data = out_info_map.at(map_ng_result_to_ie_name(result));
            temp_body.outputs[counter++] = data;
        }

        // This deep copy will hold all unreachable constants. See the comment in CopyTIBody function.
        body = InferenceEngine::NetPass::CopyTIBody(temp_body);

        // Check if data is really const layer holder
        // (DataPtr taken by const& — avoids a shared_ptr copy per call)
        auto is_constant_holder = [] (const DataPtr& data) {
            return data->getPrecision() == Precision::UNSPECIFIED;
        };

        // Strip unreached node holder from Inputs node.
        auto holder = body.inputs.back();
        if (is_constant_holder(holder)) {
            auto& holder_map = getInputTo(holder);

            for (auto it = holder_map.begin(); it != holder_map.end(); ) {
                if (it->second->type == "Input")
                    it = holder_map.erase(it);
                else
                    ++it;
            }
        }

        // TODO: Disable this WA after total switch onto Ngraph
        //   WA: Some plugins (like GPU) require matching of Data object name and producer Layer name.
        //       Data name is expected in format "[layer_name]" or "[layer_name].[port_idx]" in case
        //       of multiple inputs. We have to restore it if possible and ignore original names of
        //       Ngraph parameter and result ops.
        //       Will not change data name if:
        //        - data has several consumer layers
        //        - data has no consumer (example if data is straight used as output)
        for (auto& in : body.inputs) {
            if (is_constant_holder(in))
                continue;

            const auto input_to = getInputTo(in);
            if (input_to.size() != 1)
                continue;

            const auto consumer_layer = input_to.begin()->second;
            const auto consumer_in_port_set = consumer_layer->insData;
            const auto found = std::find_if(consumer_in_port_set.begin(), consumer_in_port_set.end(),
                [&in] (const DataWeakPtr& wptr) { return wptr.lock() == in; });
            IE_ASSERT(found != consumer_in_port_set.end());

            const auto consumer_port_idx = std::distance(consumer_in_port_set.begin(), found);
            auto new_name = consumer_layer->name;
            if (consumer_in_port_set.size() > 1) {
                new_name += '.' + std::to_string(consumer_port_idx);
            }
            in->setName(new_name);
        }

        // TODO: this WA restore original precisions of outputs.
        //       convertFunctionToICNNNetwork has internal fallback policy for unsupported
        //       precisions for inputs/outputs ports. Particular for U8 will be translated
        //       to FP32. However Loop body has strong requirements for continue_condition
        //       port, it should be BOOL(U8).
        // size_t index avoids the signed/unsigned comparison against results.size().
        for (size_t i = 0; i < results.size(); i++) {
            auto result = results[i];
            auto output = body.outputs[i];
            if (result->get_element_type() == ngraph::element::u8) {
                output->setPrecision(InferenceEngine::Precision::U8);
            }
        }
    }

    // Create Inference Engine representation of TensorIterator
    LayerParams params = {layer->get_friendly_name(), "TensorIterator",
                          details::convertPrecision(layer->get_output_element_type(0))};
    auto res = std::make_shared<InferenceEngine::TensorIterator>(params);
    res->body = body;

    // Port map: outputs
    for (const auto& desc : sub_graph->get_output_descriptions()) {
        auto body_output_idx = desc->m_body_value_index;

        std::string type_name = desc->get_type_info().name;
        if (type_name == "ConcatOutputDescription") {
            auto output_desc = ::ngraph::as_type_ptr<ngraph::op::util::SubGraphOp::ConcatOutputDescription>(desc);
            IE_ASSERT(output_desc != nullptr);

            res->output_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
                static_cast<int>(output_desc->m_output_index), static_cast<int>(body_output_idx),
                static_cast<int>(output_desc->m_axis), static_cast<int>(output_desc->m_stride),
                static_cast<int>(output_desc->m_start), static_cast<int>(output_desc->m_end),
                static_cast<int>(output_desc->m_part_size)});
        } else if (type_name == "BodyOutputDescription") {
            auto output_desc = ::ngraph::as_type_ptr<ngraph::op::util::SubGraphOp::BodyOutputDescription>(desc);
            IE_ASSERT(output_desc != nullptr);

            // axis = -1, stride/start/end/part_size defaults: plain (non-concat) output
            res->output_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
                static_cast<int>(output_desc->m_output_index), static_cast<int>(body_output_idx), -1, 1, 0, -1, 1});
        } else {
            THROW_IE_EXCEPTION << "Incorrect type of the output description.";
        }
    }

    // Port map : inputs and back edges
    for (const auto& desc : sub_graph->get_input_descriptions()) {
        auto body_input_index = desc->m_body_parameter_index;

        if (const auto slice_desc = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp::SliceInputDescription>(desc)) {
            res->input_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
                static_cast<int>(slice_desc->m_input_index), static_cast<int>(body_input_index),
                static_cast<int>(slice_desc->m_axis), static_cast<int>(slice_desc->m_stride),
                static_cast<int>(slice_desc->m_start), static_cast<int>(slice_desc->m_end),
                static_cast<int>(slice_desc->m_part_size)});
        } else if (const auto merge_desc = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp::MergedInputDescription>(desc)) {
            res->input_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
                static_cast<int>(merge_desc->m_input_index), static_cast<int>(body_input_index), -1, 1, 0, -1, 1});

            // Merged input == back edge: body result feeds the body parameter on the next iteration.
            auto body_output_idx = merge_desc->m_body_value_index;
            res->back_edges.emplace_back(InferenceEngine::TensorIterator::PortMap {
                static_cast<int>(body_output_idx), static_cast<int>(body_input_index), -1, 1, 0, -1, 1});
        } else if (const auto inv_desc = std::dynamic_pointer_cast<ngraph::op::util::SubGraphOp::InvariantInputDescription>(desc)) {
            res->input_port_map.emplace_back(InferenceEngine::TensorIterator::PortMap {
                static_cast<int>(inv_desc->m_input_index), static_cast<int>(body_input_index), -1, 1, 0, -1, 1});
        } else {
            THROW_IE_EXCEPTION << "Incorrect type of the input description.";
        }
    }

    // Loop-only metadata: special body ports plus the fixed positions of the
    // trip-count and execution-condition inputs.
    if (const auto loop_op = std::dynamic_pointer_cast<const ngraph::opset5::Loop>(layer)) {
        auto spec_port = loop_op->get_special_body_ports();
        if (spec_port.current_iteration_input_idx != -1) {
            auto ie_port_idx = spec_port.current_iteration_input_idx;
            res->params["loop_body_current_iteration_idx"] = std::to_string(ie_port_idx);
        }
        if (spec_port.body_condition_output_idx != -1) {
            auto body_output_idx = spec_port.body_condition_output_idx;
            res->params["loop_body_condition_output_idx"] = std::to_string(body_output_idx);
        }
        res->params["loop_trip_count_idx"] = "0";
        res->params["loop_execution_condition_idx"] = "1";
    }

    return res;
}
/**
* @brief Creator for CNNLayer from nGraph op
*/
@ -134,6 +341,9 @@ public:
params[name] = joinVec(data);
}
// Intentionally a no-op: the nested ngraph::Function (e.g. a TensorIterator
// body) is not consumed through the attribute visitor here — presumably the
// subgraph is converted by the dedicated SubGraphOp path instead; TODO confirm.
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) override {
}
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void>& adapter) override;
void on_adapter(const std::string& name, ::ngraph::ValueAccessor<void*>& adapter) override {
@ -171,6 +381,10 @@ void InferenceEngine::details::CNNLayerCreator::on_adapter(const std::string& na
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::vector<size_t>>>(& adapter)) {
auto data = a->get();
params[name] = joinVec(data);
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::vector<std::shared_ptr<
ngraph::op::util::SubGraphOp::InputDescription>>>>(& adapter)) {
} else if (auto a = ::ngraph::as_type<::ngraph::AttributeAdapter<std::vector<std::shared_ptr<
ngraph::op::util::SubGraphOp::OutputDescription>>>>(& adapter)) {
} else {
THROW_IE_EXCEPTION << "Error converting ngraph to CNN network. "
"Attribute adapter can not be found for " << name << " parameter";
@ -1300,6 +1514,12 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
res->params["ctc_merge_repeated"] = res->getBoolStrParamAsIntStr("ctc_merge_repeated");
return res;
});
// TensorIterator goes through the shared SubGraphOp conversion; the legacy
// layer type string is set explicitly afterwards.
addSpecificCreator({"TensorIterator"}, [](const std::shared_ptr<::ngraph::Node>& node, const std::map<std::string, std::string>& params) -> CNNLayerPtr {
auto res = createSubGraphLayer(node);
res->type = "TensorIterator";
return res;
});
}
CNNLayerPtr InferenceEngine::details::CNNLayerCreator::create() {
@ -1344,7 +1564,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
std::make_shared<Builder::NodeConverter<::ngraph::op::ScaleShiftIE>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::SquaredDifference>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::VariadicSplit>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::TensorIterator>>(),
std::make_shared<Builder::NodeConverter<::ngraph::opset5::Loop>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::ShuffleChannels>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::v4::Interpolate>>(),

View File

@ -23,6 +23,7 @@
#include <ngraph/opsets/opset5.hpp>
#include <ngraph/opsets/opset6.hpp>
#include <ngraph/variant.hpp>
#include <ngraph/op/util/sub_graph_base.hpp>
#include <cpp/ie_cnn_network.h>
#include "ie_blob_stream.hpp"
@ -46,44 +47,223 @@ IRParser::IRParser(size_t version, const std::vector<InferenceEngine::IExtension
}
}
std::shared_ptr<ICNNNetwork> IRParser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
return parser->parse(root, weights);
}
void V10Parser::XmlDeserializer::map_type_in_function(const pugi::xml_node& node,
const std::string map_type, std::map<uint64_t, uint64_t>& type_id_in_function) {
uint64_t map_type_number = 0;
auto body_node = node.child("body");
/**
* Hold original blob in order to avoid situations when original blob is allocated on stack
*/
class WeightsHolderBlob : public TBlob<uint8_t> {
Blob::CPtr originBlob;
if (body_node.empty()) {
THROW_IE_EXCEPTION << "Missing body part.";
}
public:
explicit WeightsHolderBlob(const Blob::CPtr& weights) :
TBlob<uint8_t>(weights->getTensorDesc(),
weights->cbuffer().as<uint8_t*>()),
originBlob(weights) { }
};
// Fill map: parameter/result id to parameter/result number in Function
FOREACH_CHILD(_layer, body_node.child("layers"), "layer") {
auto type = XMLParseUtils::GetStrAttr(_layer, "type");
V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
// Load default opsets
opsets["opset1"] = ngraph::get_opset1();
opsets["opset2"] = ngraph::get_opset2();
opsets["opset3"] = ngraph::get_opset3();
opsets["opset4"] = ngraph::get_opset4();
opsets["opset5"] = ngraph::get_opset5();
opsets["opset6"] = ngraph::get_opset6();
// Load custom opsets
for (const auto& ext : exts) {
std::map<std::string, ngraph::OpSet> extOpsets = ext->getOpSets();
for (const auto& it : extOpsets) {
if (opsets.find(it.first) != opsets.end())
THROW_IE_EXCEPTION << "Cannot add opset with name: " << it.first << ". Opset with the same name already exists.";
opsets[it.first] = it.second;
if (type == map_type) {
auto id = XMLParseUtils::GetUIntAttr(_layer, "id");
type_id_in_function[id] = map_type_number;
map_type_number++;
}
}
}
std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
/// \brief Parses <port_map> inputs (and <back_edges>) of a TensorIterator/Loop
///        layer into nGraph InputDescription objects.
/// \param node XML node of the layer; must contain <body> and <port_map>.
/// \return InputDescriptions ordered by ascending external_port_id.
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> V10Parser::XmlDeserializer::parseInputDescription(const pugi::xml_node& node) {
    std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> inputs;
    std::map<uint64_t, uint64_t> param_id_in_function;
    std::map<uint64_t, uint64_t> result_id_in_function;
    map_type_in_function(node, "Parameter", param_id_in_function);
    map_type_in_function(node, "Result", result_id_in_function);

    // Parse PortMap: external_port_id for inputs does not always appear in consecutive order.
    // The key is signed on purpose: external_port_id may be -1 (body Parameter not
    // connected to TensorIterator inputs — see below); an unsigned key would wrap -1
    // to a huge value and sort it last instead of first.
    std::map<int64_t, pugi::xml_node> input_map;
    FOREACH_CHILD(_input, node.child("port_map"), "input") {
        int64_t ext_port_id = GetInt64Attr(_input, "external_port_id");
        input_map[ext_port_id] = _input;
    }

    for (const auto& input : input_map) {
        auto& _input = input.second;
        auto axis_attr = _input.attribute("axis");
        int64_t ti_input_index = XMLParseUtils::GetInt64Attr(_input, "external_port_id");
        size_t body_parameter_index = XMLParseUtils::GetUIntAttr(_input, "internal_layer_id");

        // if axis is set, then slicing is enabled. Create ngraph::TensorIterator::SlicedInput.
        if (!axis_attr.empty()) {
            size_t axis = XMLParseUtils::GetUIntAttr(_input, "axis");
            int64_t start = XMLParseUtils::GetInt64Attr(_input, "start", 0);
            int64_t stride = XMLParseUtils::GetInt64Attr(_input, "stride", 1);
            int64_t end = XMLParseUtils::GetInt64Attr(_input, "end", -1);
            int64_t part_size = XMLParseUtils::GetInt64Attr(_input, "part_size", 1);
            inputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::SliceInputDescription>
                            (ti_input_index,
                             param_id_in_function[body_parameter_index],
                             start,
                             stride,
                             part_size,
                             end,
                             axis));
        } else {
            // otherwise find corresponding back edge and create ngraph::TensorIterator::MergedInput
            bool is_back_edge_exist = false;
            FOREACH_CHILD(_edge, node.child("back_edges"), "edge") {
                size_t to_layer = XMLParseUtils::GetUIntAttr(_edge, "to-layer");
                if (to_layer == body_parameter_index) {
                    size_t from_layer = XMLParseUtils::GetUIntAttr(_edge, "from-layer");
                    inputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::MergedInputDescription>
                                    (ti_input_index,
                                     param_id_in_function[body_parameter_index],
                                     result_id_in_function[from_layer]));
                    is_back_edge_exist = true;
                    break;
                }
            }

            // ti_input_index = -1 means that Parameter of the body is not connected to inputs of TensorIterator
            // and is used only for internal needs.
            if (!is_back_edge_exist && ti_input_index >= 0) {
                inputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::InvariantInputDescription>
                                (ti_input_index,
                                 param_id_in_function[body_parameter_index]));
            }
        }
    }
    return inputs;
}
/// \brief Parses <port_map> outputs of a TensorIterator/Loop layer into nGraph
///        OutputDescription objects.
/// \param node XML node of the layer; must contain <body> and <port_map>.
/// \return OutputDescriptions ordered by ascending external_port_id; the
///         running output_number assigned here becomes the op's output index.
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> V10Parser::XmlDeserializer::parseOutputDescription(const pugi::xml_node& node) {
    std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> outputs;
    std::map<uint64_t, uint64_t> result_id_in_function;
    map_type_in_function(node, "Result", result_id_in_function);

    // Parse PortMap: outputs. The signed ordered map restores ascending
    // external_port_id order even when entries appear out of order in the XML.
    std::map<int64_t, pugi::xml_node> output_map;
    FOREACH_CHILD(_output, node.child("port_map"), "output") {
        int64_t ext_port_id = GetInt64Attr(_output, "external_port_id");
        output_map[ext_port_id] = _output;
    }

    uint64_t output_number = 0;
    for (const auto& output : output_map) {
        auto& _output = output.second;
        auto axis_attr = _output.attribute("axis");
        size_t body_result_index = XMLParseUtils::GetUIntAttr(_output, "internal_layer_id");

        // if axis is set, then concatenation is enabled. Create ngraph::TensorIterator::ConcatOutput.
        if (!axis_attr.empty()) {
            int64_t axis = XMLParseUtils::GetInt64Attr(_output, "axis");
            int64_t start = XMLParseUtils::GetInt64Attr(_output, "start", 0);
            int64_t stride = XMLParseUtils::GetInt64Attr(_output, "stride", 1);
            int64_t end = XMLParseUtils::GetInt64Attr(_output, "end", -1);
            int64_t part_size = XMLParseUtils::GetInt64Attr(_output, "part_size", 1);
            outputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::ConcatOutputDescription>
                            (result_id_in_function[body_result_index],
                             output_number,
                             start,
                             stride,
                             part_size,
                             end,
                             axis));
        } else {
            // otherwise create ngraph::TensorIterator::BodyOutput. -1 means last iteration.
            outputs.push_back(std::make_shared<ngraph::op::util::SubGraphOp::BodyOutputDescription>
                            (result_id_in_function[body_result_index],
                             output_number,
                             -1));
        }
        output_number++;
    }
    return outputs;
}
// Generic attribute deserialization: dispatches on the concrete AttributeAdapter
// type and fills it from the layer's <data> XML attributes. SubGraphOp
// attributes (Input/OutputDescription vectors) are special-cased first because
// they live under <port_map>, not <data>.
void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) {
std::string val;
// for TensorIterator look for 'port_map' as 'data' does not exist
if (node.child("port_map")) {
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<std::shared_ptr
<ngraph::op::util::SubGraphOp::InputDescription>>>>(&adapter)) {
a->set(parseInputDescription(node));
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<std::shared_ptr
<ngraph::op::util::SubGraphOp::OutputDescription>>>>(&adapter)) {
a->set(parseOutputDescription(node));
}
}
// Attribute missing from <data>: keep the adapter's current (default) value.
if (!getStrAttribute(node.child("data"), name, val)) return;
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::element::Type>>(&adapter)) {
static_cast<ngraph::element::Type&>(*a) = details::convertPrecision(val);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::PartialShape>>(&adapter)) {
std::vector<int64_t> shape;
std::vector<ngraph::Dimension> dims;
if (!getParameters<int64_t>(node.child("data"), name, shape)) return;
for (const auto& dim : shape) dims.emplace_back(dim);
static_cast<ngraph::PartialShape&>(*a) = ngraph::PartialShape(dims);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Shape>>(&adapter)) {
std::vector<size_t> shape;
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
static_cast<ngraph::Shape&>(*a) = ngraph::Shape(shape);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Strides>>(&adapter)) {
std::vector<size_t> shape;
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
static_cast<ngraph::Strides&>(*a) = ngraph::Strides(shape);
// NOTE(review): macOS branch assigns through the cast while the generic branch
// calls set() — presumably a platform/toolchain quirk; confirm before unifying.
#ifdef __APPLE__
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
std::vector<size_t> result;
if (!getParameters<size_t>(node.child("data"), name, result)) return;
static_cast<std::vector<size_t>&>(*a) = result;
#else
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
std::vector<size_t> result;
if (!getParameters<size_t>(node.child("data"), name, result)) return;
a->set(result);
#endif
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::AxisSet>>(&adapter)) {
std::vector<size_t> axes;
if (!getParameters<size_t>(node.child("data"), name, axes)) return;
static_cast<ngraph::AxisSet&>(*a) = ngraph::AxisSet(axes);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKSortType>>(&adapter)) {
if (!getStrAttribute(node.child("data"), name, val)) return;
static_cast<ngraph::op::TopKSortType&>(*a) = ngraph::as_enum<ngraph::op::TopKSortType>(val);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
if (!getStrAttribute(node.child("data"), name, val)) return;
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::CoordinateDiff>>(&adapter)) {
std::vector<size_t> shape;
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
std::vector<std::ptrdiff_t> coord_diff(shape.begin(), shape.end());
static_cast<ngraph::CoordinateDiff&>(*a) = ngraph::CoordinateDiff(coord_diff);
} else {
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
<< " parameter";
}
}
void V10Parser::XmlDeserializer::on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) {
std::shared_ptr<ngraph::Function> ngraph_function;
if (!name.compare("body")) {
auto body_node = node.child(name.c_str());
if (body_node.empty()) {
THROW_IE_EXCEPTION << "TensorIterator has no body.";
}
ngraph_function = parse_function(node.child(name.c_str()), weights);
} else if (!name.compare("net")) {
ngraph_function = parse_function(node, weights);
} else {
THROW_IE_EXCEPTION << "Error: not recognized adapter name: " << name << ".";
}
// Disabled reshape for generic operations in the TI body
ngraph::op::GenericIE::DisableReshape noReshape(ngraph_function);
adapter.set(ngraph_function);
}
std::shared_ptr<ngraph::Function> V10Parser::XmlDeserializer::parse_function(const pugi::xml_node& root, const Blob::CPtr& weights) {
OV_ITT_TASK_CHAIN(taskChain, itt::domains::V10Reader_RT, "V10Parser", "Parse");
using node_params = struct {
@ -202,8 +382,53 @@ std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, const
OV_ITT_TASK_NEXT(taskChain, "ConstructCNNNetwork");
CNNNetwork net(function, _exts);
return function;
}
// Thin dispatch to the version-specific parser chosen in the IRParser ctor.
std::shared_ptr<ICNNNetwork> IRParser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
return parser->parse(root, weights);
}
/**
 * Hold original blob in order to avoid situations when original blob is allocated on stack
 */
class WeightsHolderBlob : public TBlob<uint8_t> {
// Shared ownership of the source blob keeps its buffer alive for this blob's lifetime.
Blob::CPtr originBlob;
public:
// Wraps the weights buffer without copying: the TBlob views the original
// memory directly via cbuffer(), while originBlob pins it.
explicit WeightsHolderBlob(const Blob::CPtr& weights) :
TBlob<uint8_t>(weights->getTensorDesc(),
weights->cbuffer().as<uint8_t*>()),
originBlob(weights) { }
};
V10Parser::V10Parser(const std::vector<IExtensionPtr>& exts) : _exts(exts) {
    // Register the operation sets shipped with nGraph.
    opsets["opset1"] = ngraph::get_opset1();
    opsets["opset2"] = ngraph::get_opset2();
    opsets["opset3"] = ngraph::get_opset3();
    opsets["opset4"] = ngraph::get_opset4();
    opsets["opset5"] = ngraph::get_opset5();
    opsets["opset6"] = ngraph::get_opset6();

    // Register opsets provided by user extensions; a name clash is an error.
    for (const auto& extension : exts) {
        for (const auto& entry : extension->getOpSets()) {
            if (opsets.count(entry.first) != 0)
                THROW_IE_EXCEPTION << "Cannot add opset with name: " << entry.first << ". Opset with the same name already exists.";
            opsets[entry.first] = entry.second;
        }
    }
}
std::shared_ptr<ICNNNetwork> V10Parser::parse(const pugi::xml_node& root, const Blob::CPtr& weights) {
OV_ITT_TASK_CHAIN(taskChain, itt::domains::V10Reader_RT, "V10Parser", "Parse");
std::shared_ptr<ngraph::Function> function;
XmlDeserializer visitor(root, weights, opsets);
visitor.on_attribute("net", function);
CNNNetwork net(function, _exts);
parsePreProcess(net, root, weights);
return net;
@ -335,7 +560,7 @@ void V10Parser::parsePreProcess(CNNNetwork& network, const pugi::xml_node& root,
}
}
V10Parser::GenericLayerParams V10Parser::parseGenericParams(const pugi::xml_node& node) {
V10Parser::GenericLayerParams V10Parser::XmlDeserializer::parseGenericParams(const pugi::xml_node& node) {
const auto parsePort = [](const pugi::xml_node& parentNode,
const GenericLayerParams& params,
bool input) -> GenericLayerParams::LayerPortData {
@ -392,8 +617,10 @@ bool V10Parser::LayerBaseCreator::shouldCreate(const std::string& nodeType) cons
return comparator(nodeType, type);
}
std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Output<ngraph::Node>>& inputs,
const pugi::xml_node& node, const Blob::CPtr& weights,
std::shared_ptr<ngraph::Node> V10Parser::XmlDeserializer::createNode(
const std::vector<ngraph::Output<ngraph::Node>>& inputs,
const pugi::xml_node& node,
const Blob::CPtr& weights,
const GenericLayerParams& params) {
static std::vector<std::shared_ptr<LayerBaseCreator>> creators = {
std::make_shared<LayerCreator<ngraph::op::v1::DeformableConvolution>>("DeformableConvolution"),
@ -412,7 +639,6 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
std::make_shared<LayerCreator<ngraph::op::Result>>("Result"),
std::make_shared<LayerCreator<ngraph::op::PSROIPooling>>("PSROIPooling"),
std::make_shared<LayerCreator<ngraph::op::VariadicSplit>>("VariadicSplit"),
std::make_shared<LayerCreator<ngraph::op::TensorIterator>>("TensorIterator"),
std::make_shared<LayerCreator<ngraph::opset5::Loop>>("Loop"),
std::make_shared<LayerCreator<ngraph::op::v1::LogicalAnd>>("LogicalAnd"),
std::make_shared<LayerCreator<ngraph::op::v1::LogicalOr>>("LogicalOr"),
@ -483,7 +709,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
ngraphNode = std::shared_ptr<ngraph::Node>(opset.create_insensitive(type));
ngraphNode->set_friendly_name(params.name);
ngraphNode->set_arguments(inputs);
XmlDeserializer visitor(node, weights);
XmlDeserializer visitor(node, weights, opsets);
if (ngraphNode->visit_attributes(visitor))
ngraphNode->constructor_validate_and_infer_types();
}

View File

@ -8,6 +8,7 @@
# include <ngraph/node.hpp>
# include <ngraph/op/util/sub_graph_base.hpp>
# include <ie_ngraph_utils.hpp>
# include <ngraph/opsets/opset.hpp>
#endif // IR_READER_V10
#include <ie_blob.h>
@ -51,7 +52,6 @@ public:
};
#ifdef IR_READER_V10
class V10Parser : public IParser {
public:
explicit V10Parser(const std::vector<IExtensionPtr>& exts);
@ -171,10 +171,6 @@ private:
}
};
std::shared_ptr<ngraph::Node> createNode(const ngraph::OutputVector& inputs, const pugi::xml_node& node,
const Blob::CPtr& weights, const GenericLayerParams& params);
GenericLayerParams parseGenericParams(const pugi::xml_node& node);
void parsePreProcess(CNNNetwork& network, const pugi::xml_node& root, const Blob::CPtr& weights);
std::map<std::string, DataPtr> portsToData;
@ -182,7 +178,8 @@ private:
class XmlDeserializer : public ngraph::AttributeVisitor {
public:
explicit XmlDeserializer(const pugi::xml_node& node, const Blob::CPtr& weights): node(node), weights(weights) {}
explicit XmlDeserializer(const pugi::xml_node& node, const Blob::CPtr& weights,
const std::map<std::string, ngraph::OpSet>& opsets) : node(node), weights(weights), opsets(opsets) {}
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::string>& value) override {
std::string val;
if (!getStrAttribute(node.child("data"), name, val)) return;
@ -203,56 +200,7 @@ private:
if (!is_true && !is_false) return;
value.set(is_true);
}
// Generic attribute deserialization: dispatches on the concrete AttributeAdapter
// type and fills it from the layer's <data> XML attributes; unknown adapter
// types are an error.
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override {
std::string val;
// Attribute missing from <data>: keep the adapter's current (default) value.
if (!getStrAttribute(node.child("data"), name, val)) return;
if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::element::Type>>(&adapter)) {
static_cast<ngraph::element::Type&>(*a) = details::convertPrecision(val);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::PartialShape>>(&adapter)) {
std::vector<int64_t> shape;
std::vector<ngraph::Dimension> dims;
if (!getParameters<int64_t>(node.child("data"), name, shape)) return;
for (const auto& dim : shape) dims.emplace_back(dim);
static_cast<ngraph::PartialShape&>(*a) = ngraph::PartialShape(dims);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Shape>>(&adapter)) {
std::vector<size_t> shape;
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
static_cast<ngraph::Shape&>(*a) = ngraph::Shape(shape);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::Strides>>(&adapter)) {
std::vector<size_t> shape;
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
static_cast<ngraph::Strides&>(*a) = ngraph::Strides(shape);
// NOTE(review): macOS branch assigns through the cast while the generic branch
// calls set() — presumably a platform/toolchain quirk; confirm before unifying.
#ifdef __APPLE__
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
std::vector<size_t> result;
if (!getParameters<size_t>(node.child("data"), name, result)) return;
static_cast<std::vector<size_t>&>(*a) = result;
#else
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<std::vector<size_t>>>(&adapter)) {
std::vector<size_t> result;
if (!getParameters<size_t>(node.child("data"), name, result)) return;
a->set(result);
#endif
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::AxisSet>>(&adapter)) {
std::vector<size_t> axes;
if (!getParameters<size_t>(node.child("data"), name, axes)) return;
static_cast<ngraph::AxisSet&>(*a) = ngraph::AxisSet(axes);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKSortType>>(&adapter)) {
if (!getStrAttribute(node.child("data"), name, val)) return;
static_cast<ngraph::op::TopKSortType&>(*a) = ngraph::as_enum<ngraph::op::TopKSortType>(val);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::op::TopKMode>>(&adapter)) {
if (!getStrAttribute(node.child("data"), name, val)) return;
static_cast<ngraph::op::TopKMode&>(*a) = ngraph::as_enum<ngraph::op::TopKMode>(val);
} else if (auto a = ngraph::as_type<ngraph::AttributeAdapter<ngraph::CoordinateDiff>>(&adapter)) {
std::vector<size_t> shape;
if (!getParameters<size_t>(node.child("data"), name, shape)) return;
std::vector<std::ptrdiff_t> coord_diff(shape.begin(), shape.end());
static_cast<ngraph::CoordinateDiff&>(*a) = ngraph::CoordinateDiff(coord_diff);
} else {
THROW_IE_EXCEPTION << "Error IR reading. Attribute adapter can not be found for " << name
<< " parameter";
}
}
void on_adapter(const std::string& name, ngraph::ValueAccessor<void>& adapter) override;
void on_adapter(const std::string& name, ngraph::ValueAccessor<double>& adapter) override {
std::string val;
if (!getStrAttribute(node.child("data"), name, val))
@ -307,6 +255,8 @@ private:
adapter.set(value);
}
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter) override;
void on_adapter(const std::string& name, ngraph::ValueAccessor<std::vector<int32_t>>& adapter) override {
std::vector<int32_t> value;
if (!getParameters<int32_t>(node.child("data"), name, value)) return;
@ -334,6 +284,32 @@ private:
private:
const pugi::xml_node node;
const Blob::CPtr& weights;
const std::map<std::string, ngraph::OpSet>& opsets;
/// \brief Traverses port_map in order to create vector of InputDescription shared_ptrs.
/// Shall be used only for ops which have port_map attribute.
/// \param node xml op representation
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>> parseInputDescription(
const pugi::xml_node& node);
/// \brief Traverses port_map in order to create vector of OutputDescription shared_ptrs.
/// Shall be used only for ops which have port_map attribute.
/// \param node xml op representation
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>> parseOutputDescription(
const pugi::xml_node& node);
/// \brief Traverses nGraph body function for specified op type and creates a map of all
/// op iterations. The map contains the type id and a consecutive number assigned to it, starting from 0.
/// \param node xml op representation
/// \param type op type name to find
/// \param type_id_in_function map container
void map_type_in_function(const pugi::xml_node& node, std::string type, std::map<uint64_t, uint64_t>& type_id_in_function);
/// \brief Traverses xml node representation in order to create nGraph function for it.
/// \param node xml node representation
/// \param weights weights attached to current node
/// \return shared pointer to function representing input node
std::shared_ptr<ngraph::Function> parse_function(const pugi::xml_node& root, const Blob::CPtr& weights);
GenericLayerParams parseGenericParams(const pugi::xml_node& node);
std::shared_ptr<ngraph::Node> createNode(const ngraph::OutputVector& inputs, const pugi::xml_node& node,
const Blob::CPtr& weights, const GenericLayerParams& params);
bool getStrAttribute(const pugi::xml_node& node, const std::string& name, std::string& value) {
if (!node) return false;

View File

@ -30,6 +30,7 @@ namespace ngraph
class ValueAccessor;
class VisitorAdapter;
class Node;
class Function;
/// \brief Visits the attributes of a node, primarily for serialization-like tasks.
///
@ -116,6 +117,12 @@ namespace ngraph
/// \brief Hook for adapters that need visitor access
virtual void on_adapter(const std::string& name, VisitorAdapter& adapter);
/// \brief Provides API to handle nGraph Function attribute type, accessed as ValueAccessor
/// \param name attribute name
/// \param adapter reference to a ValueAccessor<std::shared_ptr<Function>>
virtual void on_adapter(const std::string& name,
ValueAccessor<std::shared_ptr<Function>>& adapter);
/// The generic visitor. There must be a definition of AttributeAdapter<T> that can convert
/// to a ValueAccessor<U> for one of the on_adapter methods.
template <typename AT>

View File

@ -190,16 +190,17 @@ namespace ngraph
};
template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>> : public VisitorAdapter
class NGRAPH_API AttributeAdapter<std::shared_ptr<Function>>
: public DirectValueAccessor<std::shared_ptr<Function>>
{
public:
AttributeAdapter(std::shared_ptr<Function>& ref);
AttributeAdapter(std::shared_ptr<Function>& value)
: DirectValueAccessor<std::shared_ptr<Function>>(value)
{
}
bool visit_attributes(AttributeVisitor& visitor) override;
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<shared_ptr<Function>>", 0};
static constexpr DiscreteTypeInfo type_info{"AttributeAdapter<std::shared_ptr<Function>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
protected:
std::shared_ptr<Function>& m_ref;
};
}

View File

@ -17,7 +17,6 @@
#pragma once
#include <ngraph/op/parameter.hpp>
#include "ngraph/factory_adapter.hpp"
#include "ngraph/op/op.hpp"
namespace ngraph
@ -50,7 +49,6 @@ namespace ngraph
virtual std::shared_ptr<InputDescription> copy() const = 0;
virtual const type_info_t& get_type_info() const = 0;
virtual bool visit_attributes(AttributeVisitor& visitor);
uint64_t m_input_index{0};
uint64_t m_body_parameter_index{0};
@ -85,7 +83,6 @@ namespace ngraph
int64_t axis);
SliceInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t m_start{0};
int64_t m_stride{0};
int64_t m_part_size{0};
@ -118,7 +115,6 @@ namespace ngraph
uint64_t body_value_index);
MergedInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
uint64_t m_body_value_index{0};
};
@ -140,7 +136,6 @@ namespace ngraph
InvariantInputDescription(uint64_t input_index, uint64_t body_parameter_index);
InvariantInputDescription() = default;
std::shared_ptr<InputDescription> copy() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
};
/// \brief Describes how a SubGraphOp output is produced from the body.
@ -160,7 +155,6 @@ namespace ngraph
using type_info_t = DiscreteTypeInfo;
virtual ~OutputDescription() = default;
virtual std::shared_ptr<OutputDescription> copy() const = 0;
virtual bool visit_attributes(AttributeVisitor& visitor);
virtual const type_info_t& get_type_info() const = 0;
uint64_t m_body_value_index{0};
@ -194,7 +188,6 @@ namespace ngraph
ConcatOutputDescription() = default;
std::shared_ptr<OutputDescription> copy() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t m_start{0};
int64_t m_stride{0};
int64_t m_part_size{0};
@ -221,7 +214,6 @@ namespace ngraph
int64_t iteration);
BodyOutputDescription() = default;
std::shared_ptr<OutputDescription> copy() const override;
bool visit_attributes(AttributeVisitor& visitor) override;
int64_t m_iteration{0};
};
@ -347,79 +339,48 @@ namespace ngraph
using OutputDescriptionVector = std::vector<OutputDescriptionPtr>;
}
}
template class NGRAPH_API FactoryRegistry<op::util::SubGraphOp::InputDescription>;
template <>
FactoryRegistry<op::util::SubGraphOp::InputDescription>&
FactoryRegistry<op::util::SubGraphOp::InputDescription>::get();
template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>
: public FactoryAttributeAdapter<op::util::SubGraphOp::InputDescription>
class NGRAPH_API AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
: public DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>
{
public:
using FactoryAttributeAdapter::FactoryAttributeAdapter;
AttributeAdapter(
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>& value)
: DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::InputDescription>>>(
value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>"
">>",
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
"InputDescription>>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
template <>
class NGRAPH_API
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>
: public VisitorAdapter
class NGRAPH_API AttributeAdapter<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
: public DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>
{
public:
explicit AttributeAdapter(
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& ref);
AttributeAdapter(
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>& value)
: DirectValueAccessor<
std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::OutputDescription>>>(
value)
{
}
bool visit_attributes(AttributeVisitor& visitor) override;
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>"
">>",
"AttributeAdapter<std::vector<std::shared_ptr<ngraph::op::util::SubGraphOp::"
"OutputDescription>>>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
protected:
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& m_ref;
};
template class NGRAPH_API FactoryRegistry<op::util::SubGraphOp::OutputDescription>;
template <>
FactoryRegistry<op::util::SubGraphOp::OutputDescription>&
FactoryRegistry<op::util::SubGraphOp::OutputDescription>::get();
template <>
class NGRAPH_API AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>
: public FactoryAttributeAdapter<op::util::SubGraphOp::OutputDescription>
{
public:
using FactoryAttributeAdapter::FactoryAttributeAdapter;
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>"
">>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
template <>
class NGRAPH_API
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>
: public VisitorAdapter
{
public:
explicit AttributeAdapter(
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& ref);
bool visit_attributes(AttributeVisitor& visitor) override;
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>"
">>",
0};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
protected:
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& m_ref;
};
}

View File

@ -170,6 +170,12 @@ void AttributeVisitor::on_adapter(const string& name, ValueAccessor<std::vector<
on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
// Default handler for Function-valued attributes: degrades to the generic
// ValueAccessor<void> overload so visitors that do not implement the
// Function-specific hook still receive the attribute.
void AttributeVisitor::on_adapter(const string& name,
                                  ValueAccessor<std::shared_ptr<ngraph::Function>>& adapter)
{
    on_adapter(name, static_cast<ValueAccessor<void>&>(adapter));
}
const AttributeVisitor::node_id_t AttributeVisitor::invalid_node_id = "";
void AttributeVisitor::register_node(const std::shared_ptr<Node>& node, node_id_t id)

View File

@ -19,7 +19,6 @@
#include <memory>
#include "itt.hpp"
#include "ngraph/factory_adapter.hpp"
#include "ngraph/function.hpp"
#include "ngraph/graph_util.hpp"
#include "ngraph/log.hpp"
@ -407,269 +406,3 @@ void Function::remove_result(const std::shared_ptr<op::Result>& result)
}
constexpr DiscreteTypeInfo AttributeAdapter<shared_ptr<Function>>::type_info;
// The adapter stores a reference (not a copy) to the caller's Function
// pointer, so visiting can both read it (serialization) and replace it
// (deserialization).
AttributeAdapter<shared_ptr<Function>>::AttributeAdapter(shared_ptr<Function>& ref)
    : m_ref(ref)
{
}
// Visits a single Node by id. Used by AttributeAdapter<shared_ptr<Function>>
// to (de)serialize each node of a Function graph. A null wrapped node acts as
// the end-of-list sentinel in the serialized stream.
class NodeAttributeAdapter : public FactoryAttributeAdapter<Node>
{
public:
    using FactoryAttributeAdapter::FactoryAttributeAdapter;
    // Called before the node body is visited. Emits/reads the node id; a
    // false return (null node with invalid id) stops visiting the body.
    bool on_start(AttributeVisitor& visitor) override
    {
        // Indicate that there is a node following
        m_id = visitor.get_registered_node_id(m_ref);
        // m_set_id: we are deserializing (no node yet), so the id read from
        // the stream becomes the node's friendly name in on_finish().
        m_set_id = (m_ref == nullptr);
        visitor.on_attribute("id", m_id);
        return m_ref == nullptr || m_id != AttributeVisitor::invalid_node_id;
    }
    // Called after the node body is visited; names a freshly-deserialized node.
    bool on_finish(AttributeVisitor&) override
    {
        if (m_set_id && m_ref)
        {
            m_ref->set_friendly_name(m_id);
        }
        return true;
    }
    // Visits this node as a nested structure keyed by `id`.
    void visit(AttributeVisitor& visitor, const std::string& id)
    {
        visitor.start_structure(id);
        visitor.on_adapter(id, *this);
        visitor.finish_structure();
    }
    static constexpr DiscreteTypeInfo type_info{"Lambda.NodeAttributeAdapter", 0};
    const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    string m_id;     // node id within the serialized stream
    bool m_set_id;   // true when deserializing: assign m_id as friendly name
};
constexpr DiscreteTypeInfo NodeAttributeAdapter::type_info;
// (De)serializes the wrapped Function through the visitor API.
// Direction is inferred from m_ref: a Function that already has results is
// written out; an empty Function is populated from the visitor's stream.
// The visited layout is:
//   "nodes"   - each node via NodeAttributeAdapter, terminated by a null-node
//               sentinel
//   "edges"   - input/output connections, terminated by an
//               invalid-input-node sentinel
//   "control" - control dependencies, terminated by an invalid-node sentinel
//   "value"   - the Function object's own attributes
// The attribute order here IS the serialized format; do not reorder.
bool AttributeAdapter<shared_ptr<Function>>::visit_attributes(AttributeVisitor& visitor)
{
    if (m_ref->get_results().size() > 0)
    {
        // Serialization path: walk the existing graph and emit it.
        NodeVector serialized_nodes;
        {
            // Start with all nodes not already serialized
            visitor.start_structure("nodes");
            NodeVector results;
            for (auto result : m_ref->get_results())
            {
                results.push_back(result);
            }
            for (auto sink : m_ref->get_sinks())
            {
                results.push_back(sink);
            }

            int64_t i = 0;
            ostringstream index;
            traverse_nodes(
                results, [&i, &index, &visitor, &serialized_nodes](shared_ptr<Node> node) -> void {
                    if (AttributeVisitor::invalid_node_id == visitor.get_registered_node_id(node))
                    {
                        // This node hasn't been seen before
                        visitor.register_node(node);
                        index.str("");
                        index << i++;
                        string id = index.str();
                        NodeAttributeAdapter adapter(node);
                        adapter.visit(visitor, id);
                        serialized_nodes.push_back(node);
                    }
                });
            {
                // Sentinel at end
                index.str("");
                index << i++;
                string id = index.str();
                shared_ptr<Node> null_node;
                NodeAttributeAdapter adapter(null_node);
                adapter.visit(visitor, id);
            }
            visitor.finish_structure();
        }
        {
            // Now do all the edges
            visitor.start_structure("edges");
            int64_t i = 0;
            ostringstream index;
            for (auto node : serialized_nodes)
            {
                for (auto input : node->inputs())
                {
                    index.str("");
                    index << i++;
                    string id = index.str();
                    visitor.start_structure(id);
                    string input_node_id = visitor.get_registered_node_id(node);
                    uint64_t input_index = input.get_index();
                    visitor.on_attribute("input_node", input_node_id);
                    visitor.on_attribute("input_index", input_index);
                    auto output = input.get_source_output();
                    string output_node_id =
                        visitor.get_registered_node_id(output.get_node_shared_ptr());
                    uint64_t output_index = output.get_index();
                    visitor.on_attribute("output_node", output_node_id);
                    visitor.on_attribute("output_index", output_index);
                    visitor.finish_structure();
                }
            }
            {
                // Add a sentinel
                index.str("");
                index << i++;
                string id = index.str();
                visitor.start_structure(id);
                string input_node_id = AttributeVisitor::invalid_node_id;
                visitor.on_attribute("input_node", input_node_id);
                visitor.finish_structure();
            }
            visitor.finish_structure();
        }
        {
            // Control dependencies
            visitor.start_structure("control");
            int64_t i = 0;
            ostringstream index;
            for (auto node : serialized_nodes)
            {
                for (auto control : node->get_control_dependencies())
                {
                    index.str("");
                    index << i++;
                    string id = index.str();
                    visitor.start_structure(id);
                    string node_id = visitor.get_registered_node_id(node);
                    string dependency_id = visitor.get_registered_node_id(control);
                    visitor.on_attribute("node", node_id);
                    visitor.on_attribute("dependency", dependency_id);
                    visitor.finish_structure();
                }
            }
            {
                // Add a sentinel
                index.str("");
                index << i++;
                string id = index.str();
                visitor.start_structure(id);
                string node_id = AttributeVisitor::invalid_node_id;
                visitor.on_attribute("node", node_id);
                visitor.finish_structure();
            }
            visitor.finish_structure();
        }
    }
    else
    {
        // Deserialization path: read nodes/edges/control and rebuild the graph.
        NodeVector deserialized_nodes;
        {
            // Read the graph
            visitor.start_structure("nodes");
            int64_t i = 0;
            ostringstream index;
            while (true)
            {
                index.str("");
                index << i++;
                string id = index.str();
                shared_ptr<Node> node;
                NodeAttributeAdapter adapter(node);
                adapter.visit(visitor, id);
                if (node)
                {
                    visitor.register_node(node);
                    deserialized_nodes.push_back(node);
                }
                else
                {
                    // A null node is the end-of-list sentinel.
                    break;
                }
            }
            visitor.finish_structure();
        }
        {
            visitor.start_structure("edges");
            // Connect the nodes
            int64_t i = 0;
            ostringstream index;
            bool more_edges = true;
            while (more_edges)
            {
                index.str("");
                index << i++;
                string id = index.str();
                visitor.start_structure(id);
                string input_node_id;
                visitor.on_attribute("input_node", input_node_id);
                if (!input_node_id.empty())
                {
                    shared_ptr<Node> input_node = visitor.get_registered_node(input_node_id);
                    NGRAPH_CHECK(input_node, "input node of edge not known");
                    uint64_t input_index;
                    string output_node_id;
                    uint64_t output_index;
                    visitor.on_attribute("input_index", input_index);
                    visitor.on_attribute("output_node", output_node_id);
                    visitor.on_attribute("output_index", output_index);
                    shared_ptr<Node> output_node = visitor.get_registered_node(output_node_id);
                    NGRAPH_CHECK(output_node, "output_node of edge not known");
                    input_node->set_argument(input_index, output_node->output(output_index));
                }
                else
                {
                    // An empty id is the sentinel written during serialization.
                    more_edges = false;
                }
                visitor.finish_structure();
            }
            visitor.finish_structure();
        }
        {
            // Control dependencies
            visitor.start_structure("control");
            int64_t i = 0;
            ostringstream index;
            bool more_control = true;
            while (more_control)
            {
                index.str("");
                index << i++;
                string id = index.str();
                visitor.start_structure(id);
                string node_id;
                visitor.on_attribute("node", node_id);
                if (!node_id.empty())
                {
                    shared_ptr<Node> node = visitor.get_registered_node(node_id);
                    NGRAPH_CHECK(node, "node of control edge not known");
                    string dependency_id;
                    visitor.on_attribute("dependency", dependency_id);
                    shared_ptr<Node> dependency = visitor.get_registered_node(dependency_id);
                    NGRAPH_CHECK(dependency, "dependency of control edge not known");
                    node->add_control_dependency(dependency);
                }
                else
                {
                    more_control = false;
                }
                visitor.finish_structure();
            }
            visitor.finish_structure();
        }
        // Re-run shape/type inference over the rebuilt graph in dependency order.
        for (auto node : topological_sort(deserialized_nodes))
        {
            node->validate_and_infer_types();
        }
    }
    {
        // Finally visit the object attributes
        visitor.start_structure("value");
        m_ref->visit_attributes(visitor);
        visitor.finish_structure();
    }
    return true;
}

View File

@ -35,7 +35,15 @@ bool op::v0::TensorIterator::visit_attributes(AttributeVisitor& visitor)
visitor.on_attribute("input_descriptions", m_input_descriptions);
visitor.on_attribute("output_descriptions", m_output_descriptions);
return false;
for (const auto& output_description : m_output_descriptions)
{
if (auto concat = as_type_ptr<ConcatOutputDescription>(output_description))
{
m_num_iterations = ((std::abs(concat->m_end - concat->m_start)) / concat->m_part_size);
}
}
return true;
}
void op::v0::TensorIterator::revalidate_and_infer_types_for_body_ops()
@ -88,10 +96,6 @@ void op::v0::TensorIterator::validate_and_infer_types()
get_input_size() == m_input_descriptions.size(),
"Number of inputs must be the same as number of input descriptions");
NODE_VALIDATION_CHECK(this,
get_output_size() == m_output_descriptions.size(),
"Number of outputs must be the same as number of output descriptions");
std::vector<std::shared_ptr<Node>> ends;
auto make_positive = [](int64_t value, uint64_t dim_size) -> int64_t {
@ -226,6 +230,10 @@ void op::v0::TensorIterator::validate_and_infer_types()
set_output_type(index, body_value.get_element_type(), body_value.get_partial_shape());
}
}
NODE_VALIDATION_CHECK(this,
get_output_size() == m_output_descriptions.size(),
"Number of outputs must be the same as number of output descriptions");
}
std::shared_ptr<Function> op::v0::TensorIterator::get_function()

View File

@ -35,13 +35,6 @@ op::util::SubGraphOp::InputDescription::InputDescription(uint64_t input_index,
{
}
// Visits the attributes common to every InputDescription kind. The attribute
// order defines the serialized format; do not reorder.
bool op::util::SubGraphOp::InputDescription::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("input_index", m_input_index);
    visitor.on_attribute("body_parameter_index", m_body_parameter_index);
    return true;
}
op::util::SubGraphOp::SliceInputDescription::SliceInputDescription(uint64_t input_index,
uint64_t body_parameter_index,
int64_t start,
@ -65,17 +58,6 @@ std::shared_ptr<op::util::SubGraphOp::InputDescription>
m_input_index, m_body_parameter_index, m_start, m_stride, m_part_size, m_end, m_axis);
}
// Visits the base-class attributes followed by the slicing parameters
// (start/stride/part_size/end/axis). Order defines the serialized format.
bool op::util::SubGraphOp::SliceInputDescription::visit_attributes(AttributeVisitor& visitor)
{
    InputDescription::visit_attributes(visitor);
    visitor.on_attribute("start", m_start);
    visitor.on_attribute("stride", m_stride);
    visitor.on_attribute("part_size", m_part_size);
    visitor.on_attribute("end", m_end);
    visitor.on_attribute("axis", m_axis);
    return true;
}
op::util::SubGraphOp::MergedInputDescription::MergedInputDescription(uint64_t input_index,
uint64_t body_parameter_index,
uint64_t body_value_index)
@ -91,13 +73,6 @@ std::shared_ptr<op::util::SubGraphOp::InputDescription>
m_input_index, m_body_parameter_index, m_body_value_index);
}
// Visits the base-class attributes plus the body value index that feeds back
// into this input on the next iteration.
bool op::util::SubGraphOp::MergedInputDescription::visit_attributes(AttributeVisitor& visitor)
{
    InputDescription::visit_attributes(visitor);
    visitor.on_attribute("body_value_index", m_body_value_index);
    return true;
}
op::util::SubGraphOp::InvariantInputDescription::InvariantInputDescription(
uint64_t input_index, uint64_t body_parameter_index)
: InputDescription(input_index, body_parameter_index)
@ -110,12 +85,6 @@ std::shared_ptr<op::util::SubGraphOp::InputDescription>
return std::make_shared<InvariantInputDescription>(m_input_index, m_body_parameter_index);
}
// An invariant input adds no attributes beyond the InputDescription base.
bool op::util::SubGraphOp::InvariantInputDescription::visit_attributes(AttributeVisitor& visitor)
{
    InputDescription::visit_attributes(visitor);
    return true;
}
op::util::SubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_index,
uint64_t output_index)
: m_body_value_index(body_value_index)
@ -123,13 +92,6 @@ op::util::SubGraphOp::OutputDescription::OutputDescription(uint64_t body_value_i
{
}
// Visits the attributes common to every OutputDescription kind. The attribute
// order defines the serialized format; do not reorder.
bool op::util::SubGraphOp::OutputDescription::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("body_value_index", m_body_value_index);
    visitor.on_attribute("output_index", m_output_index);
    return true;
}
op::util::SubGraphOp::ConcatOutputDescription::ConcatOutputDescription(uint64_t body_value_index,
uint64_t output_index,
int64_t start,
@ -146,17 +108,6 @@ op::util::SubGraphOp::ConcatOutputDescription::ConcatOutputDescription(uint64_t
{
}
// Visits the base-class attributes followed by the concatenation parameters
// (start/stride/part_size/end/axis). Order defines the serialized format.
bool op::util::SubGraphOp::ConcatOutputDescription::visit_attributes(AttributeVisitor& visitor)
{
    OutputDescription::visit_attributes(visitor);
    visitor.on_attribute("start", m_start);
    visitor.on_attribute("stride", m_stride);
    visitor.on_attribute("part_size", m_part_size);
    visitor.on_attribute("end", m_end);
    visitor.on_attribute("axis", m_axis);
    return true;
}
std::shared_ptr<op::util::SubGraphOp::OutputDescription>
op::util::SubGraphOp::ConcatOutputDescription::copy() const
{
@ -178,13 +129,6 @@ std::shared_ptr<op::util::SubGraphOp::OutputDescription>
return std::make_shared<BodyOutputDescription>(m_body_value_index, m_output_index, m_iteration);
}
// Visits the base-class attributes plus the iteration whose body value is
// exposed as this output.
bool op::util::SubGraphOp::BodyOutputDescription::visit_attributes(AttributeVisitor& visitor)
{
    OutputDescription::visit_attributes(visitor);
    visitor.on_attribute("iteration", m_iteration);
    return true;
}
op::util::SubGraphOp::SubGraphOp(const OutputVector& args)
: Op(args)
{
@ -257,103 +201,9 @@ Input<Node> op::util::SubGraphOp::input_for_value(const Output<Node>& value)
namespace ngraph
{
// Returns the process-wide factory registry for SubGraphOp::InputDescription
// subtypes, lazily populated with the three concrete kinds on first use.
//
// Fix: the previous hand-rolled double-checked locking read
// registry.m_factory_map.size() WITHOUT holding init_guard, which is a data
// race (another thread could be mutating the map inside the locked region).
// std::call_once provides the one-shot, fully-synchronized initialization the
// original intended. <mutex> (which also declares std::once_flag/call_once)
// was already required by the original implementation.
template <>
FactoryRegistry<op::util::SubGraphOp::InputDescription>&
    FactoryRegistry<op::util::SubGraphOp::InputDescription>::get()
{
    static FactoryRegistry<op::util::SubGraphOp::InputDescription> registry;
    static std::once_flag init_flag;
    std::call_once(init_flag, []() {
        // `registry` has static storage duration, so no capture is needed.
        registry.register_factory<op::util::SubGraphOp::SliceInputDescription>();
        registry.register_factory<op::util::SubGraphOp::MergedInputDescription>();
        registry.register_factory<op::util::SubGraphOp::InvariantInputDescription>();
    });
    return registry;
}
constexpr DiscreteTypeInfo
AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::InputDescription>>::type_info;
constexpr DiscreteTypeInfo AttributeAdapter<
std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::type_info;
// The adapter keeps a reference to the caller's vector so visiting can both
// read it (serialization) and resize/fill it (deserialization).
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::
    AttributeAdapter(std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>& ref)
    : m_ref(ref)
{
}
// (De)serializes the vector as a "size" attribute followed by one attribute
// per element, keyed "0", "1", ... in order.
bool AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::InputDescription>>>::
    visit_attributes(AttributeVisitor& visitor)
{
    int64_t size = m_ref.size();
    visitor.on_attribute("size", size);
    // When deserializing, the visitor overwrites `size`; grow the vector to
    // match so the per-element visits below have slots to fill.
    // NOTE(review): comparison is int64_t vs size_t — presumably benign for
    // realistic sizes, but an implicit sign conversion; verify if touched.
    if (size != m_ref.size())
    {
        m_ref.resize(size);
    }
    std::ostringstream index;
    for (int64_t i = 0; i < size; i++)
    {
        index.str("");
        index << i;
        visitor.on_attribute(index.str(), m_ref[i]);
    }
    return true;
}
// Returns the process-wide factory registry for SubGraphOp::OutputDescription
// subtypes, lazily populated with the two concrete kinds on first use.
//
// Fix: the previous hand-rolled double-checked locking read
// registry.m_factory_map.size() WITHOUT holding init_guard — a data race —
// and still carried a stale "TODO: Add a lock" comment. std::call_once gives
// the one-shot, fully-synchronized initialization the original intended.
// <mutex> (which also declares std::once_flag/call_once) was already required
// by the original implementation.
template <>
FactoryRegistry<op::util::SubGraphOp::OutputDescription>&
    FactoryRegistry<op::util::SubGraphOp::OutputDescription>::get()
{
    static FactoryRegistry<op::util::SubGraphOp::OutputDescription> registry;
    static std::once_flag init_flag;
    std::call_once(init_flag, []() {
        // `registry` has static storage duration, so no capture is needed.
        registry.register_factory<op::util::SubGraphOp::ConcatOutputDescription>();
        registry.register_factory<op::util::SubGraphOp::BodyOutputDescription>();
    });
    return registry;
}
constexpr DiscreteTypeInfo AttributeAdapter<
std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::type_info;
constexpr DiscreteTypeInfo
AttributeAdapter<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>::type_info;
// The adapter keeps a reference to the caller's vector so visiting can both
// read it (serialization) and resize/fill it (deserialization).
AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::
    AttributeAdapter(std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>& ref)
    : m_ref(ref)
{
}
// (De)serializes the vector as a "size" attribute followed by one attribute
// per element, keyed "0", "1", ... in order.
bool AttributeAdapter<std::vector<std::shared_ptr<op::util::SubGraphOp::OutputDescription>>>::
    visit_attributes(AttributeVisitor& visitor)
{
    int64_t size = m_ref.size();
    visitor.on_attribute("size", size);
    // When deserializing, the visitor overwrites `size`; grow the vector to
    // match so the per-element visits below have slots to fill.
    // NOTE(review): comparison is int64_t vs size_t — presumably benign for
    // realistic sizes, but an implicit sign conversion; verify if touched.
    if (size != m_ref.size())
    {
        m_ref.resize(size);
    }
    std::ostringstream index;
    for (int64_t i = 0; i < size; i++)
    {
        index.str("");
        index << i;
        visitor.on_attribute(index.str(), m_ref[i]);
    }
    return true;
}
}