Avoid excess tensor copy for Reshape/Squeeze/Unsqueeze folding (#2834)
* Updated Reshape ConstantFolding to avoid excess tensor copies * Updated Squeeze/Unsqueeze CF to avoid excess tensor copies * Fixed typo
This commit is contained in:
parent
314ec2df72
commit
d36bd8c87b
@ -285,6 +285,12 @@ namespace ngraph
|
||||
/// Repeated values are allowed.
|
||||
AxisSet get_axis_set_val() const;
|
||||
|
||||
/// \brief Update Constant shape. New shape size must equal to the data elements
|
||||
/// count
|
||||
///
|
||||
/// \param shape The shape of the tensor constant.
|
||||
void set_data_shape(const Shape& shape);
|
||||
|
||||
/// \brief Wrapper around constructing a shared_ptr of a Constant
|
||||
///
|
||||
/// \param type The element type of the tensor constant.
|
||||
|
@ -153,6 +153,8 @@ namespace ngraph
|
||||
void set_special_zero(bool special_zero) { m_special_zero = special_zero; }
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
protected:
|
||||
bool m_special_zero;
|
||||
|
@ -44,6 +44,8 @@ namespace ngraph
|
||||
virtual void pre_validate_and_infer_types() override;
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
@ -45,6 +45,8 @@ namespace ngraph
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
bool evaluate(const HostTensorVector& outputs,
|
||||
const HostTensorVector& inputs) const override;
|
||||
bool constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values) override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
@ -540,6 +540,12 @@ AxisSet op::Constant::get_axis_set_val() const
|
||||
return output_axis_set;
|
||||
}
|
||||
|
||||
void op::Constant::set_data_shape(const Shape& shape)
|
||||
{
|
||||
NGRAPH_CHECK(shape_size(shape) == shape_size(m_shape));
|
||||
m_shape = shape;
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::Constant::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
|
@ -453,3 +453,33 @@ bool op::v1::Reshape::evaluate(const HostTensorVector& outputs,
|
||||
const AxisVector order = get_default_order(inputs[0]->get_shape());
|
||||
return evaluate_reshape(inputs[0], outputs[0], order);
|
||||
}
|
||||
|
||||
bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVector& inputs_values)
|
||||
{
|
||||
if (get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& shape = get_output_shape(0);
|
||||
|
||||
if (auto data_const =
|
||||
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
||||
{
|
||||
// In case if data constant has single consumer we can change it shape without making a copy
|
||||
// Otherwise we create Constant copy with shape from reshape node
|
||||
if (data_const->output(0).get_target_inputs().size() == 1)
|
||||
{
|
||||
data_const->set_data_shape(shape);
|
||||
data_const->validate_and_infer_types();
|
||||
output_values[0] = data_const;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_values[0] = std::make_shared<op::Constant>(
|
||||
data_const->get_element_type(), shape, data_const->get_data_ptr());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
@ -212,3 +212,33 @@ bool op::v0::Squeeze::evaluate(const HostTensorVector& outputs,
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Squeeze::evaluate");
|
||||
return squeeze::evaluate_squeeze(inputs[0], inputs[1], outputs[0]);
|
||||
}
|
||||
|
||||
bool op::v0::Squeeze::constant_fold(OutputVector& output_values, const OutputVector& inputs_values)
|
||||
{
|
||||
if (get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& shape = get_output_shape(0);
|
||||
|
||||
if (auto data_const =
|
||||
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
||||
{
|
||||
// In case if data constant has single consumer we can change it shape without making a copy
|
||||
// Otherwise we create Constant copy with shape from squeeze node
|
||||
if (data_const->output(0).get_target_inputs().size() == 1)
|
||||
{
|
||||
data_const->set_data_shape(shape);
|
||||
data_const->validate_and_infer_types();
|
||||
output_values[0] = data_const;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_values[0] = std::make_shared<op::Constant>(
|
||||
data_const->get_element_type(), shape, data_const->get_data_ptr());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
@ -173,3 +173,34 @@ bool op::v0::Unsqueeze::evaluate(const HostTensorVector& outputs,
|
||||
OV_ITT_SCOPED_TASK(itt::domains::nGraphOp, "op::v0::Unsqueeze::evaluate");
|
||||
return unsqueeze::evaluate_unsqueeze(inputs[0], inputs[1], outputs[0]);
|
||||
}
|
||||
|
||||
bool op::v0::Unsqueeze::constant_fold(OutputVector& output_values,
|
||||
const OutputVector& inputs_values)
|
||||
{
|
||||
if (get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& shape = get_output_shape(0);
|
||||
|
||||
if (auto data_const =
|
||||
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
||||
{
|
||||
// In case if data constant has single consumer we can change it shape without making a copy
|
||||
// Otherwise we create Constant copy with shape from unsqueeze node
|
||||
if (data_const->output(0).get_target_inputs().size() == 1)
|
||||
{
|
||||
data_const->set_data_shape(shape);
|
||||
data_const->validate_and_infer_types();
|
||||
output_values[0] = data_const;
|
||||
}
|
||||
else
|
||||
{
|
||||
output_values[0] = std::make_shared<op::Constant>(
|
||||
data_const->get_element_type(), shape, data_const->get_data_ptr());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
Loading…
Reference in New Issue
Block a user