Avoid Constant data copy inside Reshape constant folding (#6410)
* Avoid Constant data copy inside Reshape constant folding * Fix Codestyle * Updated Squeeze, Unsqueeze cf * Deprecate set_data_shape method * Fix Pruning
This commit is contained in:
parent
ab8d046642
commit
0a1cad52ab
@ -419,13 +419,12 @@ public:
|
|||||||
auto fq_node = std::dynamic_pointer_cast<op::FakeQuantize>(m_output.get_node_shared_ptr());
|
auto fq_node = std::dynamic_pointer_cast<op::FakeQuantize>(m_output.get_node_shared_ptr());
|
||||||
size_t idx = 0;
|
size_t idx = 0;
|
||||||
if (fq_node->get_auto_broadcast() != ngraph::op::AutoBroadcastType::NONE) {
|
if (fq_node->get_auto_broadcast() != ngraph::op::AutoBroadcastType::NONE) {
|
||||||
for (auto const_node : fq_params_nodes) {
|
for (auto node : fq_params_nodes) {
|
||||||
|
auto const_node = std::dynamic_pointer_cast<op::Constant>(node);
|
||||||
|
if (!const_node) throw ngraph_error("Unexpected operation type.");
|
||||||
auto new_shape = broadcast_shape_to_rank(const_node->get_shape(),
|
auto new_shape = broadcast_shape_to_rank(const_node->get_shape(),
|
||||||
m_input.get_partial_shape().rank().get_length());
|
m_input.get_partial_shape().rank().get_length());
|
||||||
auto const_copy = const_node->clone_with_new_inputs(const_node->input_values());
|
auto new_const = std::make_shared<op::Constant>(*const_node, new_shape);
|
||||||
auto new_const = std::dynamic_pointer_cast<op::Constant>(const_copy);
|
|
||||||
new_const->set_data_shape(new_shape);
|
|
||||||
new_const->validate_and_infer_types();
|
|
||||||
new_const->set_friendly_name(const_node->get_friendly_name());
|
new_const->set_friendly_name(const_node->get_friendly_name());
|
||||||
ngraph::copy_runtime_info(const_node, new_const);
|
ngraph::copy_runtime_info(const_node, new_const);
|
||||||
ngraph::replace_node(const_node, new_const);
|
ngraph::replace_node(const_node, new_const);
|
||||||
|
@ -155,6 +155,7 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
Constant(const Constant& other);
|
Constant(const Constant& other);
|
||||||
|
Constant(const Constant& other, const Shape& new_shape);
|
||||||
Constant& operator=(const Constant&) = delete;
|
Constant& operator=(const Constant&) = delete;
|
||||||
|
|
||||||
virtual ~Constant() override;
|
virtual ~Constant() override;
|
||||||
@ -213,6 +214,7 @@ namespace ngraph
|
|||||||
/// count
|
/// count
|
||||||
///
|
///
|
||||||
/// \param shape The shape of the tensor constant.
|
/// \param shape The shape of the tensor constant.
|
||||||
|
NGRAPH_DEPRECATED("Use Constant c-tor with shape argument instead")
|
||||||
void set_data_shape(const Shape& shape);
|
void set_data_shape(const Shape& shape);
|
||||||
|
|
||||||
/// \brief Wrapper around constructing a shared_ptr of a Constant
|
/// \brief Wrapper around constructing a shared_ptr of a Constant
|
||||||
|
@ -162,6 +162,18 @@ op::Constant::Constant(const Constant& other)
|
|||||||
constructor_validate_and_infer_types();
|
constructor_validate_and_infer_types();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
op::Constant::Constant(const Constant& other, const Shape& new_shape)
|
||||||
|
{
|
||||||
|
NGRAPH_CHECK(shape_size(other.m_shape) == shape_size(new_shape),
|
||||||
|
"Shape size " + std::to_string(shape_size(new_shape)) + " is not equal to " +
|
||||||
|
std::to_string(shape_size(other.m_shape)));
|
||||||
|
m_element_type = other.m_element_type;
|
||||||
|
m_shape = new_shape;
|
||||||
|
m_data = other.m_data;
|
||||||
|
m_all_elements_bitwise_identical = other.m_all_elements_bitwise_identical;
|
||||||
|
constructor_validate_and_infer_types();
|
||||||
|
}
|
||||||
|
|
||||||
op::Constant::~Constant() {}
|
op::Constant::~Constant() {}
|
||||||
|
|
||||||
string op::Constant::convert_value_to_string(size_t index) const
|
string op::Constant::convert_value_to_string(size_t index) const
|
||||||
|
@ -241,19 +241,7 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
|
|||||||
if (auto data_const =
|
if (auto data_const =
|
||||||
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
||||||
{
|
{
|
||||||
// In case if data constant has single consumer we can change it shape without making a copy
|
output_values[0] = std::make_shared<op::Constant>(*data_const, shape);
|
||||||
// Otherwise we create Constant copy with shape from reshape node
|
|
||||||
if (data_const->output(0).get_target_inputs().size() == 1)
|
|
||||||
{
|
|
||||||
data_const->set_data_shape(shape);
|
|
||||||
data_const->validate_and_infer_types();
|
|
||||||
output_values[0] = data_const;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
output_values[0] = std::make_shared<op::Constant>(
|
|
||||||
data_const->get_element_type(), shape, data_const->get_data_ptr());
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -327,19 +327,7 @@ bool op::v0::Squeeze::constant_fold(OutputVector& output_values, const OutputVec
|
|||||||
if (auto data_const =
|
if (auto data_const =
|
||||||
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
||||||
{
|
{
|
||||||
// In case if data constant has single consumer we can change it shape without making a copy
|
output_values[0] = std::make_shared<op::Constant>(*data_const, shape);
|
||||||
// Otherwise we create Constant copy with shape from squeeze node
|
|
||||||
if (data_const->output(0).get_target_inputs().size() == 1)
|
|
||||||
{
|
|
||||||
data_const->set_data_shape(shape);
|
|
||||||
data_const->validate_and_infer_types();
|
|
||||||
output_values[0] = data_const;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
output_values[0] = std::make_shared<op::Constant>(
|
|
||||||
data_const->get_element_type(), shape, data_const->get_data_ptr());
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -190,19 +190,7 @@ bool op::v0::Unsqueeze::constant_fold(OutputVector& output_values,
|
|||||||
if (auto data_const =
|
if (auto data_const =
|
||||||
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
std::dynamic_pointer_cast<op::Constant>(inputs_values[0].get_node_shared_ptr()))
|
||||||
{
|
{
|
||||||
// In case if data constant has single consumer we can change it shape without making a copy
|
output_values[0] = std::make_shared<op::Constant>(*data_const, shape);
|
||||||
// Otherwise we create Constant copy with shape from unsqueeze node
|
|
||||||
if (data_const->output(0).get_target_inputs().size() == 1)
|
|
||||||
{
|
|
||||||
data_const->set_data_shape(shape);
|
|
||||||
data_const->validate_and_infer_types();
|
|
||||||
output_values[0] = data_const;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
output_values[0] = std::make_shared<op::Constant>(
|
|
||||||
data_const->get_element_type(), shape, data_const->get_data_ptr());
|
|
||||||
}
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
@ -2274,6 +2274,75 @@ TEST(constant_folding, constant_dyn_reshape_shape_not_originally_constant)
|
|||||||
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
|
ASSERT_TRUE(test::all_close_f(values_in, values_out, MIN_FLOAT_TOLERANCE_BITS));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(constant_folding, const_reshape_no_data_copy)
|
||||||
|
{
|
||||||
|
auto const_data = op::Constant::create(element::f32, Shape{1, 64}, {1});
|
||||||
|
auto const_reshape = op::Constant::create(element::i64, Shape{2}, {2, 32});
|
||||||
|
auto reshape = std::make_shared<op::v1::Reshape>(const_data, const_reshape, false);
|
||||||
|
auto consumer1 = std::make_shared<op::Relu>(reshape);
|
||||||
|
auto consumer2 = std::make_shared<op::Relu>(reshape);
|
||||||
|
|
||||||
|
auto f = std::make_shared<Function>(NodeVector{consumer1, consumer2}, ParameterVector{});
|
||||||
|
|
||||||
|
pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<pass::ConstantFolding>();
|
||||||
|
pass_manager.run_passes(f);
|
||||||
|
|
||||||
|
auto const1 = std::dynamic_pointer_cast<op::Constant>(consumer1->input_value(0).get_node_shared_ptr());
|
||||||
|
auto const2 = std::dynamic_pointer_cast<op::Constant>(consumer2->input_value(0).get_node_shared_ptr());
|
||||||
|
|
||||||
|
ASSERT_TRUE(const1);
|
||||||
|
ASSERT_TRUE(const2);
|
||||||
|
ASSERT_EQ(const1, const2);
|
||||||
|
ASSERT_EQ(const1->get_data_ptr(), const2->get_data_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(constant_folding, const_squeeze_no_data_copy)
|
||||||
|
{
|
||||||
|
auto const_data = op::Constant::create(element::f32, Shape{1, 64}, {1});
|
||||||
|
auto const_reshape = op::Constant::create(element::i64, Shape{1}, {0});
|
||||||
|
auto reshape = std::make_shared<op::v0::Squeeze>(const_data, const_reshape);
|
||||||
|
auto consumer1 = std::make_shared<op::Relu>(reshape);
|
||||||
|
auto consumer2 = std::make_shared<op::Relu>(reshape);
|
||||||
|
|
||||||
|
auto f = std::make_shared<Function>(NodeVector{consumer1, consumer2}, ParameterVector{});
|
||||||
|
|
||||||
|
pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<pass::ConstantFolding>();
|
||||||
|
pass_manager.run_passes(f);
|
||||||
|
|
||||||
|
auto const1 = std::dynamic_pointer_cast<op::Constant>(consumer1->input_value(0).get_node_shared_ptr());
|
||||||
|
auto const2 = std::dynamic_pointer_cast<op::Constant>(consumer2->input_value(0).get_node_shared_ptr());
|
||||||
|
|
||||||
|
ASSERT_TRUE(const1);
|
||||||
|
ASSERT_TRUE(const2);
|
||||||
|
ASSERT_EQ(const1, const2);
|
||||||
|
ASSERT_EQ(const1->get_data_ptr(), const2->get_data_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(constant_folding, const_unsqueeze_no_data_copy)
|
||||||
|
{
|
||||||
|
auto const_data = op::Constant::create(element::f32, Shape{1, 64}, {1});
|
||||||
|
auto const_reshape = op::Constant::create(element::i64, Shape{1}, {0});
|
||||||
|
auto reshape = std::make_shared<op::v0::Unsqueeze>(const_data, const_reshape);
|
||||||
|
auto consumer1 = std::make_shared<op::Relu>(reshape);
|
||||||
|
auto consumer2 = std::make_shared<op::Relu>(reshape);
|
||||||
|
|
||||||
|
auto f = std::make_shared<Function>(NodeVector{consumer1, consumer2}, ParameterVector{});
|
||||||
|
|
||||||
|
pass::Manager pass_manager;
|
||||||
|
pass_manager.register_pass<pass::ConstantFolding>();
|
||||||
|
pass_manager.run_passes(f);
|
||||||
|
|
||||||
|
auto const1 = std::dynamic_pointer_cast<op::Constant>(consumer1->input_value(0).get_node_shared_ptr());
|
||||||
|
auto const2 = std::dynamic_pointer_cast<op::Constant>(consumer2->input_value(0).get_node_shared_ptr());
|
||||||
|
|
||||||
|
ASSERT_TRUE(const1);
|
||||||
|
ASSERT_TRUE(const2);
|
||||||
|
ASSERT_EQ(const1, const2);
|
||||||
|
ASSERT_EQ(const1->get_data_ptr(), const2->get_data_ptr());
|
||||||
|
}
|
||||||
|
|
||||||
TEST(constant_folding, constant_transpose)
|
TEST(constant_folding, constant_transpose)
|
||||||
{
|
{
|
||||||
Shape shape_in{2, 4};
|
Shape shape_in{2, 4};
|
||||||
|
Loading…
Reference in New Issue
Block a user