Replaced copy_with_new_args() to clone_with_new_inputs() (#1395)

This commit is contained in:
Ilya Churaev 2020-07-22 13:44:22 +03:00 committed by GitHub
parent 821a3dae32
commit 141b24cf44
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 64 additions and 151 deletions

View File

@ -12,7 +12,7 @@ To add your custom nGraph operation, create a new class that extends `ngraph::Op
3. Override the shape inference method `validate_and_infer_types`. This method is called multiple times during graph manipulations to determine the shapes and element types of the outputs of the operations. You can access the input shapes through the `get_input_partial_shape()` method and input element types through the `get_input_element_type()` method of `ngraph::Node`. Set the inferred shape and element type of the output using `set_output_type`.
4. Override the `copy_with_new_args` method, which allows graph manipulation routines to create copies of this operation and connect it to different nodes during optimization.
4. Override the `clone_with_new_inputs` method, which allows graph manipulation routines to create copies of this operation and connect it to different nodes during optimization.
5. Override the `visit_attributes` method, which allows serialization and deserialization of attributes. An `AttributeVisitor` is passed to the method, and the implementation is expected to walk over all the attributes in the op using the type-aware `on_attribute` helper. Helpers are already implemented for standard C++ types like `int64_t`, `float`, `bool`, `vector` and for existing nGraph defined types.
@ -39,9 +39,9 @@ nGraph operation contains two constructors: a default constructor, which allows
@snippet op.cpp op:validate
### `copy_with_new_args()`
### `clone_with_new_inputs()`
`ngraph::Node::copy_with_new_args` method creates a copy of nGraph operation with new inputs.
`ngraph::Node::clone_with_new_inputs` method creates a copy of nGraph operation with new inputs.
@snippet op.cpp op:copy

View File

@ -21,7 +21,7 @@ void Operation::validate_and_infer_types() {
//! [op:validate]
//! [op:copy]
std::shared_ptr<ngraph::Node> Operation::copy_with_new_args(const ngraph::NodeVector &new_args) const {
std::shared_ptr<ngraph::Node> Operation::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
if (new_args.size() != 1) {
throw ngraph::ngraph_error("Incorrect number of new arguments");
}

View File

@ -17,7 +17,7 @@ public:
Operation() = default;
Operation(const ngraph::Output<ngraph::Node>& arg, int64_t add);
void validate_and_infer_types() override;
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override;
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
int64_t getAddAttr() { return add; }

View File

@ -107,7 +107,7 @@ public:
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
}
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
if (new_args.size() != 1) {
throw ngraph::ngraph_error("Incorrect number of new arguments");
}

View File

@ -57,11 +57,6 @@ std::vector<InferenceEngine::IShapeInferExtensionPtr> ngraph::op::GenericIE::get
return extensions;
}
ngraph::op::GenericIE::GenericIE(const ngraph::NodeVector& inputs,
const std::map<std::string, InferenceEngine::Parameter>& params,
const std::string type, const std::vector<PortIE>& outputs)
: GenericIE(as_output_vector(inputs), params, type, outputs) {}
ngraph::op::GenericIE::GenericIE(const ngraph::OutputVector& inputs,
const std::map<std::string, InferenceEngine::Parameter>& params_,
const std::string type_, const std::vector<PortIE>& outputs_)
@ -69,7 +64,7 @@ ngraph::op::GenericIE::GenericIE(const ngraph::OutputVector& inputs,
constructor_validate_and_infer_types();
}
std::shared_ptr<ngraph::Node> ngraph::op::GenericIE::copy_with_new_args(const ngraph::NodeVector& new_args) const {
std::shared_ptr<ngraph::Node> ngraph::op::GenericIE::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
auto genNode = std::make_shared<GenericIE>(new_args, params, type, outputs);
genNode->extensions = extensions;
genNode->reshape = reshape;

View File

@ -87,11 +87,6 @@ public:
* @param type string with original layer type
* @param outputs information about output ports from IR
*/
GenericIE(const NodeVector& inputs,
const std::map<std::string, InferenceEngine::Parameter>& params,
const std::string type,
const std::vector<PortIE>& outputs);
GenericIE(const OutputVector& inputs,
const std::map<std::string, InferenceEngine::Parameter>& params,
const std::string type,
@ -99,7 +94,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
static void addExtension(std::shared_ptr<const ngraph::Lambda> func, const InferenceEngine::IShapeInferExtensionPtr& ext);
static std::vector<InferenceEngine::IShapeInferExtensionPtr> getExtensions(std::shared_ptr<const ngraph::Function> func);

View File

@ -26,7 +26,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
std::vector<int64_t> axes, dim, offset;
};

View File

@ -26,7 +26,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
ELTWISE_TYPE eltwise_type;
};

View File

@ -33,7 +33,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
size_t get_out_size() { return m_output_size; }

View File

@ -33,7 +33,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace op

View File

@ -33,7 +33,7 @@ public:
GRUCellIE() = delete;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
std::size_t get_hidden_size() { return m_hidden_size; }

View File

@ -25,7 +25,7 @@ public:
float alpha,
float beta);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
float get_alpha() const { return m_alpha; }

View File

@ -37,7 +37,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
InterpolateIEAttrs get_attrs() { return m_attrs; }
@ -65,7 +65,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
ResampleIEAttrs get_attrs() { return m_attrs; }
private:

View File

@ -28,7 +28,7 @@ public:
size_t size,
std::string region);
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
double get_alpha() const { return m_alpha; }

View File

@ -33,7 +33,7 @@ public:
LSTMCellIE() = delete;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
std::size_t get_hidden_size() { return m_hidden_size; }

View File

@ -33,7 +33,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
protected:
float m_eps;

View File

@ -26,7 +26,7 @@ public:
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
int get_axis() { return m_axis; }
int get_depth() { return m_depth; }

View File

@ -26,7 +26,7 @@ public:
size_t get_version() const override { return 1; }
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
PadMode get_pad_mode() { return m_pad_mode; }
CoordinateDiff get_pads_begin() { return m_pads_begin; }

View File

@ -23,7 +23,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
float scale, power, shift;
};

View File

@ -33,7 +33,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
const ProposalAttrs& get_attrs() const { return m_attrs; }

View File

@ -22,7 +22,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
float get_slope() { return m_negative_slope; }

View File

@ -32,7 +32,7 @@ public:
RNNCellIE() = delete;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
void validate_and_infer_types() override;
std::size_t get_hidden_size() { return m_hidden_size; }

View File

@ -24,7 +24,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace op

View File

@ -24,7 +24,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
float gamma, alpha;
};

View File

@ -24,7 +24,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
int64_t axis, tiles;
};

View File

@ -22,7 +22,7 @@ op::CropIE::CropIE(const Output<Node>& data, std::vector<int64_t> axes, std::vec
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::CropIE::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::CropIE::clone_with_new_inputs(const OutputVector& new_args) const {
if (new_args.size() != 1) {
throw ngraph_error("Incorrect number of new arguments");
}

View File

@ -20,7 +20,7 @@ op::Eltwise::Eltwise(const Output<Node>& data1, const Output<Node>& data2, const
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::Eltwise::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::Eltwise::clone_with_new_inputs(const OutputVector& new_args) const {
if (new_args.size() != 2) {
throw ngraph_error("Incorrect number of new arguments");
}

View File

@ -21,7 +21,7 @@ op::FullyConnected::FullyConnected(const Output<Node>& A, const Output<Node>& B,
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::FullyConnected::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::FullyConnected::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<FullyConnected>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_shape);
}

View File

@ -20,7 +20,7 @@ op::GatherTreeIE::GatherTreeIE(const Output<Node>& step_ids,
constructor_validate_and_infer_types();
}
shared_ptr<Node> op::GatherTreeIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::GatherTreeIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<GatherTreeIE>(
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));

View File

@ -40,7 +40,7 @@ void op::GRUCellIE::validate_and_infer_types() {
set_output_type(0, arg_type, output_shape);
}
shared_ptr<Node> op::GRUCellIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::GRUCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::GRUCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
m_hidden_size, m_activations, m_activations_alpha, m_activations_beta, m_clip,

View File

@ -31,7 +31,7 @@ void op::HardSigmoid_IE::validate_and_infer_types() {
set_output_type(0, arg_type, arg_shape);
}
shared_ptr<Node> op::HardSigmoid_IE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::HardSigmoid_IE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::HardSigmoid_IE>(new_args.at(0), m_alpha, m_beta);
}

View File

@ -61,7 +61,7 @@ void op::Interp::validate_and_infer_types() {
}
}
shared_ptr<Node> op::Interp::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::Interp::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<Interp>(new_args.at(0), m_attrs);
}
@ -101,7 +101,7 @@ void op::ResampleV2::validate_and_infer_types() {
}
}
shared_ptr<Node> op::ResampleV2::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::ResampleV2::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<ResampleV2>(new_args.at(0), new_args.at(1), m_attrs);
}

View File

@ -28,7 +28,7 @@ void op::LRN_IE::validate_and_infer_types() {
set_output_type(0, arg_type, arg_shape);
}
shared_ptr<Node> op::LRN_IE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::LRN_IE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::LRN_IE>(new_args.at(0), m_alpha, m_beta, m_bias, m_size, m_region);
}

View File

@ -37,7 +37,7 @@ void op::LSTMCellIE::validate_and_infer_types() {
set_output_type(1, arg_type, output_shape);
}
shared_ptr<Node> op::LSTMCellIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::LSTMCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::LSTMCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4),
m_hidden_size, m_activations, m_activations_alpha, m_activations_beta, m_clip);

View File

@ -32,7 +32,7 @@ void op::NormalizeIE::validate_and_infer_types() {
"Argument must have rank >= 2 and <= 4 (argument shape: ", input_shape, ").");
}
shared_ptr<Node> op::NormalizeIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::NormalizeIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::NormalizeIE>(new_args.at(0), new_args.at(1), m_eps, m_across_spatial, m_channel_shared);
}

View File

@ -31,7 +31,7 @@ void op::OneHotIE::validate_and_infer_types() {
}
}
shared_ptr<Node> op::OneHotIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::OneHotIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::OneHotIE>(new_args.at(0), m_axis, m_depth, m_on_value, m_off_value, m_type);
}

View File

@ -41,6 +41,6 @@ void op::PadIE::validate_and_infer_types() {
set_output_type(0, get_input_element_type(0), m_output_shape);
}
shared_ptr<Node> op::PadIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::PadIE::clone_with_new_inputs(const OutputVector& new_args) const {
return nullptr;
}

View File

@ -19,7 +19,7 @@ op::PowerIE::PowerIE(const Output<ngraph::Node>& data_batch, const float power,
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::PowerIE::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::PowerIE::clone_with_new_inputs(const OutputVector& new_args) const {
if (new_args.size() != 1) {
throw ngraph_error("Incorrect number of new arguments");
}

View File

@ -53,7 +53,7 @@ void op::ProposalIE::validate_and_infer_types() {
}
}
shared_ptr<Node> op::ProposalIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::ProposalIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<ProposalIE>(new_args.at(0), new_args.at(1), new_args.at(2), m_attrs);
}

View File

@ -20,7 +20,7 @@ op::ReLUIE::ReLUIE(const Output<Node>& data, const float& negative_slope)
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::ReLUIE::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::ReLUIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<ReLUIE>(new_args.at(0), m_negative_slope);
}

View File

@ -37,7 +37,7 @@ void op::RNNCellIE::validate_and_infer_types() {
set_output_type(0, arg_type, output_shape);
}
shared_ptr<Node> op::RNNCellIE::copy_with_new_args(const NodeVector& new_args) const {
shared_ptr<Node> op::RNNCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::RNNCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
m_hidden_size, m_activations, m_activations_alpha, m_activations_beta, m_clip);
@ -50,4 +50,4 @@ bool op::RNNCellIE::visit_attributes(AttributeVisitor &visitor) {
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
return true;
}
}

View File

@ -19,7 +19,7 @@ op::ScaleShiftIE::ScaleShiftIE(const Output<Node>& data_batch, const Output<Node
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::ScaleShiftIE::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::ScaleShiftIE::clone_with_new_inputs(const OutputVector& new_args) const {
if (new_args.size() != 3) {
throw ngraph_error("Incorrect number of new arguments");
}

View File

@ -22,7 +22,7 @@ op::SeluIE::SeluIE(const Output<Node> & input,
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::SeluIE::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::SeluIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<SeluIE>(new_args.at(0), alpha, gamma);
}

View File

@ -20,7 +20,7 @@ op::TileIE::TileIE(const Output<ngraph::Node>& data1, const int64_t axis, const
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> op::TileIE::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> op::TileIE::clone_with_new_inputs(const OutputVector& new_args) const {
if (new_args.size() != 1) {
throw ngraph_error("Incorrect number of new arguments");
}

View File

@ -28,7 +28,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;

View File

@ -22,7 +22,7 @@ public:
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;

View File

@ -18,7 +18,7 @@ DynamicShapeResolver::DynamicShapeResolver(
constructor_validate_and_infer_types();
}
std::shared_ptr<Node> DynamicShapeResolver::copy_with_new_args(const NodeVector& new_args) const {
std::shared_ptr<Node> DynamicShapeResolver::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<DynamicShapeResolver>(new_args.at(0), new_args.at(1), m_mode);
}

View File

@ -40,8 +40,8 @@ void StaticShapeNonZero::validate_and_infer_types() {
set_output_type(1, m_output_type, {Dimension(2)});
}
std::shared_ptr<Node> StaticShapeNonZero::copy_with_new_args(
const NodeVector& new_args) const {
std::shared_ptr<Node> StaticShapeNonZero::clone_with_new_inputs(
const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return std::make_shared<StaticShapeNonZero>(new_args.at(0), m_output_type);
}

View File

@ -33,7 +33,7 @@ public:
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
}
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
if (new_args.size() != 1) {
throw ngraph::ngraph_error("Incorrect number of new arguments");
}

View File

@ -20,7 +20,7 @@ public:
void validate_and_infer_types() override {
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
return std::make_shared<FakeAbs>(new_args.at(0));
}
bool visit_attributes(ngraph::AttributeVisitor& visitor) override {

View File

@ -175,7 +175,7 @@ public:
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
}
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
if (new_args.size() != 1) {
throw ngraph::ngraph_error("Incorrect number of new arguments");
}

View File

@ -131,7 +131,7 @@ public:
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
}
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
if (new_args.size() != 1) {
throw ngraph::ngraph_error("Incorrect number of new arguments");
}

View File

@ -43,7 +43,6 @@
* `Parameters` is now `ParameterVector`
* `NodeVector`, `ParameterVector`, `AxisVector`, `AxisSet`, `Shape`, `Stride`, `Coordinate`, and `CoordinateDiff` are now classes, not type aliases.
* `PrimaryTensorView` is now `TensorView` (and will merge into `Tensor`)
* `copy_with_new_args` is protected; use `copy_with_new_inputs` which takes an `OutputVector` as an argument and preserves control dependencies.
## Changes to ops

View File

@ -25,7 +25,8 @@ namespace ngraph
{
constexpr NodeTypeInfo NullNode::type_info;
std::shared_ptr<Node> NullNode::copy_with_new_args(const NodeVector& /* new_args */) const
std::shared_ptr<Node>
NullNode::clone_with_new_inputs(const OutputVector& /* new_args */) const
{
return std::make_shared<NullNode>();
}

View File

@ -49,7 +49,7 @@ namespace ngraph
NullNode() = default;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
clone_with_new_inputs(const OutputVector& new_args) const override;
};
} // namespace onnx_import
} // namespace ngraph

View File

@ -124,31 +124,6 @@ std::shared_ptr<Node>
return clone;
}
std::shared_ptr<Node> Node::copy_with_new_args(const NodeVector& args) const
{
NODE_VALIDATION_CHECK(
this, false, "Internal error: copy_with_new_args not replaced by clone_with_new_inputs");
}
std::shared_ptr<Node> Node::clone_with_new_inputs(const OutputVector& inputs) const
{
NodeVector args;
for (const Output<Node>& input : inputs)
{
args.push_back(get_output_element(input));
}
std::shared_ptr<Node> clone = copy_with_new_args(args);
// Remove the inserted GOEs
for (size_t i = 0; i < inputs.size(); ++i)
{
if (clone->input_value(i) != inputs.at(i))
{
clone->set_argument(i, inputs.at(i));
}
}
return clone;
}
void Node::safe_delete(NodeVector& nodes, bool recurse)
{
for (auto& input : m_inputs)

View File

@ -397,15 +397,8 @@ namespace ngraph
std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
Output<Node> get_input_source_output(size_t i) const;
protected:
// Will be replaced with clone_with_new_inputs
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const
NGRAPH_DEPRECATED("use copy_with_new_inputs instead");
public:
// TODO: When all copy_with_new_args have been replaced with copy_with_new_inputs, make
// this pure and remove copy_with_new_args
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const;
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const = 0;
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
@ -625,7 +618,7 @@ namespace ngraph
{
NODE_VALIDATION_CHECK(node,
new_args.size() == node->input_values().size(),
"copy_with_new_args() expected ",
"clone_with_new_inputs() expected ",
node->input_values().size(),
" argument",
(node->input_values().size() == 1 ? "" : "s"),

View File

@ -160,18 +160,6 @@ NodeVector op::FakeQuantize::decompose_op() const
return {dequantized_data + output_low};
}
shared_ptr<Node> op::FakeQuantize::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<FakeQuantize>(new_args.at(0), // X
new_args.at(1), // input_low
new_args.at(2), // input_high
new_args.at(3), // output_low
new_args.at(4), // output_high
m_levels,
m_auto_broadcast);
}
shared_ptr<Node> op::FakeQuantize::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);

View File

@ -70,12 +70,6 @@ namespace ngraph
virtual NodeVector decompose_op() const override;
virtual void validate_and_infer_types() override;
// This is a hack to work around dldt directly calling copy_with_new_args
// When that code is replace with clone_with_new_inputs then remove this
// method.
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;

View File

@ -78,12 +78,6 @@ shared_ptr<Node> op::v1::Transpose::clone_with_new_inputs(const OutputVector& ne
return make_shared<v1::Transpose>(new_args[0], new_args[1]);
}
shared_ptr<Node> op::v1::Transpose::copy_with_new_args(const NodeVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<v1::Transpose>(new_args[0], new_args[1]);
}
namespace
{
template <element::Type_t ET>

View File

@ -50,9 +50,6 @@ namespace ngraph
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& new_args) const override;
bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};

View File

@ -78,7 +78,7 @@ namespace ngraph
}
virtual std::shared_ptr<Node>
copy_with_new_args(const NodeVector& /* new_args */) const override
clone_with_new_inputs(const OutputVector& /* new_args */) const override
{
throw ngraph_error("Uncopyable");
}

View File

@ -114,21 +114,3 @@ TEST(op, variant)
Ship& node_ship = as_type_ptr<VariantWrapper<Ship>>(node_var_ship)->get();
EXPECT_EQ(&node_ship, &ship);
}
// TODO: Need to mock Node, Op etc to be able to unit test functions like replace_node().
// Mocking them directly isn't possible because google test requires methods to be
// non-virtual. For non-virtual methods we will need to templatize these classes and call using
// different template argument between testing and production.
/*
TEST(op, provenance_replace_node)
{
class MockOp: public op::Op
{
MOCK_CONST_METHOD1(copy_with_new_args, std::shared_ptr<Node>(const NodeVector& new_args));
MOCK_CONST_METHOD1(get_users, NodeVector (bool check_is_used)); // This can't be mocked as
// it's non-virtual
};
::testing::NiceMock<MockOp> mock_op;
}
*/