Replaced copy_with_new_args() to clone_with_new_inputs() (#1395)
This commit is contained in:
parent
821a3dae32
commit
141b24cf44
@ -12,7 +12,7 @@ To add your custom nGraph operation, create a new class that extends `ngraph::Op
|
||||
|
||||
3. Override the shape inference method `validate_and_infer_types`. This method is called multiple times during graph manipulations to determine the shapes and element types of the outputs of the operations. You can access the input shapes through the `get_input_partial_shape()` method and input element types through the `get_input_element_type()` method of `ngraph::Node`. Set the inferred shape and element type of the output using `set_output_type`.
|
||||
|
||||
4. Override the `copy_with_new_args` method, which allows graph manipulation routines to create copies of this operation and connect it to different nodes during optimization.
|
||||
4. Override the `clone_with_new_inputs` method, which allows graph manipulation routines to create copies of this operation and connect it to different nodes during optimization.
|
||||
|
||||
5. Override the `visit_attributes` method, which allows serialization and deserialization of attributes. An `AttributeVisitor` is passed to the method, and the implementation is expected to walk over all the attributes in the op using the type-aware `on_attribute` helper. Helpers are already implemented for standard C++ types like `int64_t`, `float`, `bool`, `vector` and for existing nGraph defined types.
|
||||
|
||||
@ -39,9 +39,9 @@ nGraph operation contains two constructors: a default constructor, which allows
|
||||
|
||||
@snippet op.cpp op:validate
|
||||
|
||||
### `copy_with_new_args()`
|
||||
### `clone_with_new_inputs()`
|
||||
|
||||
`ngraph::Node::copy_with_new_args` method creates a copy of nGraph operation with new inputs.
|
||||
`ngraph::Node::clone_with_new_inputs` method creates a copy of nGraph operation with new inputs.
|
||||
|
||||
@snippet op.cpp op:copy
|
||||
|
||||
|
@ -21,7 +21,7 @@ void Operation::validate_and_infer_types() {
|
||||
//! [op:validate]
|
||||
|
||||
//! [op:copy]
|
||||
std::shared_ptr<ngraph::Node> Operation::copy_with_new_args(const ngraph::NodeVector &new_args) const {
|
||||
std::shared_ptr<ngraph::Node> Operation::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph::ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ public:
|
||||
Operation() = default;
|
||||
Operation(const ngraph::Output<ngraph::Node>& arg, int64_t add);
|
||||
void validate_and_infer_types() override;
|
||||
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override;
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override;
|
||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
|
||||
int64_t getAddAttr() { return add; }
|
||||
|
||||
|
@ -107,7 +107,7 @@ public:
|
||||
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph::ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -57,11 +57,6 @@ std::vector<InferenceEngine::IShapeInferExtensionPtr> ngraph::op::GenericIE::get
|
||||
return extensions;
|
||||
}
|
||||
|
||||
ngraph::op::GenericIE::GenericIE(const ngraph::NodeVector& inputs,
|
||||
const std::map<std::string, InferenceEngine::Parameter>& params,
|
||||
const std::string type, const std::vector<PortIE>& outputs)
|
||||
: GenericIE(as_output_vector(inputs), params, type, outputs) {}
|
||||
|
||||
ngraph::op::GenericIE::GenericIE(const ngraph::OutputVector& inputs,
|
||||
const std::map<std::string, InferenceEngine::Parameter>& params_,
|
||||
const std::string type_, const std::vector<PortIE>& outputs_)
|
||||
@ -69,7 +64,7 @@ ngraph::op::GenericIE::GenericIE(const ngraph::OutputVector& inputs,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> ngraph::op::GenericIE::copy_with_new_args(const ngraph::NodeVector& new_args) const {
|
||||
std::shared_ptr<ngraph::Node> ngraph::op::GenericIE::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
|
||||
auto genNode = std::make_shared<GenericIE>(new_args, params, type, outputs);
|
||||
genNode->extensions = extensions;
|
||||
genNode->reshape = reshape;
|
||||
|
@ -87,11 +87,6 @@ public:
|
||||
* @param type string with original layer type
|
||||
* @param outputs information about output ports from IR
|
||||
*/
|
||||
GenericIE(const NodeVector& inputs,
|
||||
const std::map<std::string, InferenceEngine::Parameter>& params,
|
||||
const std::string type,
|
||||
const std::vector<PortIE>& outputs);
|
||||
|
||||
GenericIE(const OutputVector& inputs,
|
||||
const std::map<std::string, InferenceEngine::Parameter>& params,
|
||||
const std::string type,
|
||||
@ -99,7 +94,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
static void addExtension(std::shared_ptr<const ngraph::Lambda> func, const InferenceEngine::IShapeInferExtensionPtr& ext);
|
||||
static std::vector<InferenceEngine::IShapeInferExtensionPtr> getExtensions(std::shared_ptr<const ngraph::Function> func);
|
||||
|
@ -26,7 +26,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
std::vector<int64_t> axes, dim, offset;
|
||||
};
|
||||
|
@ -26,7 +26,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
ELTWISE_TYPE eltwise_type;
|
||||
};
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
size_t get_out_size() { return m_output_size; }
|
||||
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
|
||||
GRUCellIE() = delete;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::size_t get_hidden_size() { return m_hidden_size; }
|
||||
|
@ -25,7 +25,7 @@ public:
|
||||
float alpha,
|
||||
float beta);
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
float get_alpha() const { return m_alpha; }
|
||||
|
@ -37,7 +37,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
InterpolateIEAttrs get_attrs() { return m_attrs; }
|
||||
|
||||
@ -65,7 +65,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
ResampleIEAttrs get_attrs() { return m_attrs; }
|
||||
private:
|
||||
|
@ -28,7 +28,7 @@ public:
|
||||
size_t size,
|
||||
std::string region);
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
double get_alpha() const { return m_alpha; }
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
|
||||
LSTMCellIE() = delete;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::size_t get_hidden_size() { return m_hidden_size; }
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
protected:
|
||||
float m_eps;
|
||||
|
@ -26,7 +26,7 @@ public:
|
||||
size_t get_version() const override { return 1; }
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
int get_axis() { return m_axis; }
|
||||
int get_depth() { return m_depth; }
|
||||
|
@ -26,7 +26,7 @@ public:
|
||||
size_t get_version() const override { return 1; }
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
PadMode get_pad_mode() { return m_pad_mode; }
|
||||
CoordinateDiff get_pads_begin() { return m_pads_begin; }
|
||||
|
@ -23,7 +23,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
float scale, power, shift;
|
||||
};
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node>
|
||||
copy_with_new_args(const NodeVector& new_args) const override;
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
const ProposalAttrs& get_attrs() const { return m_attrs; }
|
||||
|
||||
|
@ -22,7 +22,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
float get_slope() { return m_negative_slope; }
|
||||
|
||||
|
@ -32,7 +32,7 @@ public:
|
||||
|
||||
RNNCellIE() = delete;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::size_t get_hidden_size() { return m_hidden_size; }
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
float gamma, alpha;
|
||||
};
|
||||
|
@ -24,7 +24,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
int64_t axis, tiles;
|
||||
};
|
||||
|
@ -22,7 +22,7 @@ op::CropIE::CropIE(const Output<Node>& data, std::vector<int64_t> axes, std::vec
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::CropIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::CropIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ op::Eltwise::Eltwise(const Output<Node>& data1, const Output<Node>& data2, const
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::Eltwise::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::Eltwise::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
if (new_args.size() != 2) {
|
||||
throw ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ op::FullyConnected::FullyConnected(const Output<Node>& A, const Output<Node>& B,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::FullyConnected::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::FullyConnected::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<FullyConnected>(new_args.at(0), new_args.at(1), new_args.at(2), m_output_shape);
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ op::GatherTreeIE::GatherTreeIE(const Output<Node>& step_ids,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::GatherTreeIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::GatherTreeIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<GatherTreeIE>(
|
||||
new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3));
|
||||
|
@ -40,7 +40,7 @@ void op::GRUCellIE::validate_and_infer_types() {
|
||||
set_output_type(0, arg_type, output_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::GRUCellIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::GRUCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::GRUCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
|
||||
m_hidden_size, m_activations, m_activations_alpha, m_activations_beta, m_clip,
|
||||
|
@ -31,7 +31,7 @@ void op::HardSigmoid_IE::validate_and_infer_types() {
|
||||
set_output_type(0, arg_type, arg_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::HardSigmoid_IE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::HardSigmoid_IE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::HardSigmoid_IE>(new_args.at(0), m_alpha, m_beta);
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ void op::Interp::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::Interp::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::Interp::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<Interp>(new_args.at(0), m_attrs);
|
||||
}
|
||||
@ -101,7 +101,7 @@ void op::ResampleV2::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::ResampleV2::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::ResampleV2::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ResampleV2>(new_args.at(0), new_args.at(1), m_attrs);
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ void op::LRN_IE::validate_and_infer_types() {
|
||||
set_output_type(0, arg_type, arg_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::LRN_IE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::LRN_IE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::LRN_IE>(new_args.at(0), m_alpha, m_beta, m_bias, m_size, m_region);
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ void op::LSTMCellIE::validate_and_infer_types() {
|
||||
set_output_type(1, arg_type, output_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::LSTMCellIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::LSTMCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::LSTMCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4),
|
||||
m_hidden_size, m_activations, m_activations_alpha, m_activations_beta, m_clip);
|
||||
|
@ -32,7 +32,7 @@ void op::NormalizeIE::validate_and_infer_types() {
|
||||
"Argument must have rank >= 2 and <= 4 (argument shape: ", input_shape, ").");
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::NormalizeIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::NormalizeIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::NormalizeIE>(new_args.at(0), new_args.at(1), m_eps, m_across_spatial, m_channel_shared);
|
||||
}
|
||||
|
@ -31,7 +31,7 @@ void op::OneHotIE::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::OneHotIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::OneHotIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::OneHotIE>(new_args.at(0), m_axis, m_depth, m_on_value, m_off_value, m_type);
|
||||
}
|
||||
|
@ -41,6 +41,6 @@ void op::PadIE::validate_and_infer_types() {
|
||||
set_output_type(0, get_input_element_type(0), m_output_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::PadIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::PadIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ op::PowerIE::PowerIE(const Output<ngraph::Node>& data_batch, const float power,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::PowerIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::PowerIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -53,7 +53,7 @@ void op::ProposalIE::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::ProposalIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::ProposalIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ProposalIE>(new_args.at(0), new_args.at(1), new_args.at(2), m_attrs);
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ op::ReLUIE::ReLUIE(const Output<Node>& data, const float& negative_slope)
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::ReLUIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::ReLUIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<ReLUIE>(new_args.at(0), m_negative_slope);
|
||||
}
|
||||
|
@ -37,7 +37,7 @@ void op::RNNCellIE::validate_and_infer_types() {
|
||||
set_output_type(0, arg_type, output_shape);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::RNNCellIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
shared_ptr<Node> op::RNNCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::RNNCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
|
||||
m_hidden_size, m_activations, m_activations_alpha, m_activations_beta, m_clip);
|
||||
@ -50,4 +50,4 @@ bool op::RNNCellIE::visit_attributes(AttributeVisitor &visitor) {
|
||||
visitor.on_attribute("activations_beta", m_activations_beta);
|
||||
visitor.on_attribute("clip", m_clip);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ op::ScaleShiftIE::ScaleShiftIE(const Output<Node>& data_batch, const Output<Node
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::ScaleShiftIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::ScaleShiftIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
if (new_args.size() != 3) {
|
||||
throw ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -22,7 +22,7 @@ op::SeluIE::SeluIE(const Output<Node> & input,
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::SeluIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::SeluIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<SeluIE>(new_args.at(0), alpha, gamma);
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ op::TileIE::TileIE(const Output<ngraph::Node>& data1, const int64_t axis, const
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> op::TileIE::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> op::TileIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
|
||||
|
||||
|
@ -22,7 +22,7 @@ public:
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override;
|
||||
|
||||
|
@ -18,7 +18,7 @@ DynamicShapeResolver::DynamicShapeResolver(
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> DynamicShapeResolver::copy_with_new_args(const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> DynamicShapeResolver::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<DynamicShapeResolver>(new_args.at(0), new_args.at(1), m_mode);
|
||||
}
|
||||
|
@ -40,8 +40,8 @@ void StaticShapeNonZero::validate_and_infer_types() {
|
||||
set_output_type(1, m_output_type, {Dimension(2)});
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> StaticShapeNonZero::copy_with_new_args(
|
||||
const NodeVector& new_args) const {
|
||||
std::shared_ptr<Node> StaticShapeNonZero::clone_with_new_inputs(
|
||||
const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<StaticShapeNonZero>(new_args.at(0), m_output_type);
|
||||
}
|
||||
|
@ -33,7 +33,7 @@ public:
|
||||
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph::ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ public:
|
||||
void validate_and_infer_types() override {
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
}
|
||||
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
|
||||
return std::make_shared<FakeAbs>(new_args.at(0));
|
||||
}
|
||||
bool visit_attributes(ngraph::AttributeVisitor& visitor) override {
|
||||
|
@ -175,7 +175,7 @@ public:
|
||||
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph::ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ public:
|
||||
set_output_type(0, get_input_element_type(0), ngraph::PartialShape(output_shape));
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Node> copy_with_new_args(const ngraph::NodeVector& new_args) const override {
|
||||
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector& new_args) const override {
|
||||
if (new_args.size() != 1) {
|
||||
throw ngraph::ngraph_error("Incorrect number of new arguments");
|
||||
}
|
||||
|
@ -43,7 +43,6 @@
|
||||
* `Parameters` is now `ParameterVector`
|
||||
* `NodeVector`, `ParameterVector`, `AxisVector`, `AxisSet`, `Shape`, `Stride`, `Coordinate`, and `CoordinateDiff` are now classes, not type aliases.
|
||||
* `PrimaryTensorView` is now `TensorView` (and will merge into `Tensor`)
|
||||
* `copy_with_new_args` is protected; use `copy_with_new_inputs` which takes an `OutputVector` as an argument and preserves control dependencies.
|
||||
|
||||
## Changes to ops
|
||||
|
||||
|
@ -25,7 +25,8 @@ namespace ngraph
|
||||
{
|
||||
constexpr NodeTypeInfo NullNode::type_info;
|
||||
|
||||
std::shared_ptr<Node> NullNode::copy_with_new_args(const NodeVector& /* new_args */) const
|
||||
std::shared_ptr<Node>
|
||||
NullNode::clone_with_new_inputs(const OutputVector& /* new_args */) const
|
||||
{
|
||||
return std::make_shared<NullNode>();
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ namespace ngraph
|
||||
NullNode() = default;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
copy_with_new_args(const NodeVector& new_args) const override;
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
};
|
||||
} // namespace onnx_import
|
||||
} // namespace ngraph
|
||||
|
@ -124,31 +124,6 @@ std::shared_ptr<Node>
|
||||
return clone;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> Node::copy_with_new_args(const NodeVector& args) const
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this, false, "Internal error: copy_with_new_args not replaced by clone_with_new_inputs");
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> Node::clone_with_new_inputs(const OutputVector& inputs) const
|
||||
{
|
||||
NodeVector args;
|
||||
for (const Output<Node>& input : inputs)
|
||||
{
|
||||
args.push_back(get_output_element(input));
|
||||
}
|
||||
std::shared_ptr<Node> clone = copy_with_new_args(args);
|
||||
// Remove the inserted GOEs
|
||||
for (size_t i = 0; i < inputs.size(); ++i)
|
||||
{
|
||||
if (clone->input_value(i) != inputs.at(i))
|
||||
{
|
||||
clone->set_argument(i, inputs.at(i));
|
||||
}
|
||||
}
|
||||
return clone;
|
||||
}
|
||||
|
||||
void Node::safe_delete(NodeVector& nodes, bool recurse)
|
||||
{
|
||||
for (auto& input : m_inputs)
|
||||
|
@ -397,15 +397,8 @@ namespace ngraph
|
||||
std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
|
||||
Output<Node> get_input_source_output(size_t i) const;
|
||||
|
||||
protected:
|
||||
// Will be replaced with clone_with_new_inputs
|
||||
virtual std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const
|
||||
NGRAPH_DEPRECATED("use copy_with_new_inputs instead");
|
||||
|
||||
public:
|
||||
// TODO: When all copy_with_new_args have been replaced with copy_with_new_inputs, make
|
||||
// this pure and remove copy_with_new_args
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const;
|
||||
virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const = 0;
|
||||
|
||||
std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;
|
||||
|
||||
@ -625,7 +618,7 @@ namespace ngraph
|
||||
{
|
||||
NODE_VALIDATION_CHECK(node,
|
||||
new_args.size() == node->input_values().size(),
|
||||
"copy_with_new_args() expected ",
|
||||
"clone_with_new_inputs() expected ",
|
||||
node->input_values().size(),
|
||||
" argument",
|
||||
(node->input_values().size() == 1 ? "" : "s"),
|
||||
|
@ -160,18 +160,6 @@ NodeVector op::FakeQuantize::decompose_op() const
|
||||
return {dequantized_data + output_low};
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::FakeQuantize::copy_with_new_args(const NodeVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<FakeQuantize>(new_args.at(0), // X
|
||||
new_args.at(1), // input_low
|
||||
new_args.at(2), // input_high
|
||||
new_args.at(3), // output_low
|
||||
new_args.at(4), // output_high
|
||||
m_levels,
|
||||
m_auto_broadcast);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::FakeQuantize::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
|
@ -70,12 +70,6 @@ namespace ngraph
|
||||
virtual NodeVector decompose_op() const override;
|
||||
virtual void validate_and_infer_types() override;
|
||||
|
||||
// This is a hack to work around dldt directly calling copy_with_new_args
|
||||
// When that code is replace with clone_with_new_inputs then remove this
|
||||
// method.
|
||||
virtual std::shared_ptr<Node>
|
||||
copy_with_new_args(const NodeVector& new_args) const override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
|
@ -78,12 +78,6 @@ shared_ptr<Node> op::v1::Transpose::clone_with_new_inputs(const OutputVector& ne
|
||||
return make_shared<v1::Transpose>(new_args[0], new_args[1]);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::Transpose::copy_with_new_args(const NodeVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<v1::Transpose>(new_args[0], new_args[1]);
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
template <element::Type_t ET>
|
||||
|
@ -50,9 +50,6 @@ namespace ngraph
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
copy_with_new_args(const NodeVector& new_args) const override;
|
||||
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) override;
|
||||
};
|
||||
|
@ -78,7 +78,7 @@ namespace ngraph
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
copy_with_new_args(const NodeVector& /* new_args */) const override
|
||||
clone_with_new_inputs(const OutputVector& /* new_args */) const override
|
||||
{
|
||||
throw ngraph_error("Uncopyable");
|
||||
}
|
||||
|
@ -114,21 +114,3 @@ TEST(op, variant)
|
||||
Ship& node_ship = as_type_ptr<VariantWrapper<Ship>>(node_var_ship)->get();
|
||||
EXPECT_EQ(&node_ship, &ship);
|
||||
}
|
||||
|
||||
// TODO: Need to mock Node, Op etc to be able to unit test functions like replace_node().
|
||||
// Mocking them directly isn't possible because google test requires methods to be
|
||||
// non-virtual. For non-virtual methods we will need to templatize these classes and call using
|
||||
// different template argument between testing and production.
|
||||
/*
|
||||
TEST(op, provenance_replace_node)
|
||||
{
|
||||
class MockOp: public op::Op
|
||||
{
|
||||
MOCK_CONST_METHOD1(copy_with_new_args, std::shared_ptr<Node>(const NodeVector& new_args));
|
||||
MOCK_CONST_METHOD1(get_users, NodeVector (bool check_is_used)); // This can't be mocked as
|
||||
// it's non-virtual
|
||||
};
|
||||
|
||||
::testing::NiceMock<MockOp> mock_op;
|
||||
}
|
||||
*/
|
||||
|
Loading…
Reference in New Issue
Block a user