Removed v0 operations from AlgebraicSimplufication pass (#1481)
* Removed v0 operations from AlgebraicSimplufication pass * Fixed tests
This commit is contained in:
parent
ffcb7fab2d
commit
af3a0900b0
@ -48,314 +48,6 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
template <typename T>
|
||||
static shared_ptr<pattern::Matcher>
|
||||
create_binary_matcher(shared_ptr<pattern::op::Label> label,
|
||||
shared_ptr<pattern::op::Label> const_label)
|
||||
{
|
||||
auto bcst =
|
||||
make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::v0::Broadcast>());
|
||||
auto bcst_label = make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
|
||||
auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label));
|
||||
return matcher;
|
||||
}
|
||||
|
||||
//`simplify_concat` identifies slices-concat sequences
|
||||
// that cancel each other. Namely it replaces subgraphs
|
||||
// similar to the one below with `arg`
|
||||
//
|
||||
// +----------+
|
||||
// +----+slice(n/2..n)---+
|
||||
// +-------+ | +----------+ | +-----------+
|
||||
// | arg +--+ +--+ concat |
|
||||
// +-------+ | +----------+ | +-----------+
|
||||
// +----+slice(0..n/2)---+
|
||||
// +----------+
|
||||
static bool simplify_concat(shared_ptr<Node> n)
|
||||
{
|
||||
NGRAPH_DEBUG << "In simplify_concat for " << n->get_name();
|
||||
if (n->get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
NGRAPH_DEBUG << n << " has dynamic shape";
|
||||
return false;
|
||||
}
|
||||
|
||||
Output<Node> branch_tip;
|
||||
|
||||
auto ltip = make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
|
||||
|
||||
auto pslice =
|
||||
make_shared<op::v0::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
|
||||
|
||||
auto lslice = make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice});
|
||||
|
||||
auto skip_reshape =
|
||||
make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::v0::Reshape>());
|
||||
|
||||
auto matcher = make_shared<pattern::Matcher>(skip_reshape);
|
||||
|
||||
Coordinate prev_lower_bounds;
|
||||
Shape prev_slice_shape;
|
||||
|
||||
for (auto carg : n->input_values())
|
||||
{
|
||||
if (!matcher->match(carg))
|
||||
{
|
||||
NGRAPH_DEBUG << carg << " doesn't match";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto& pattern_value_map = matcher->get_pattern_value_map();
|
||||
auto slice = as_type_ptr<op::v0::Slice>(pattern_value_map[lslice].get_node_shared_ptr());
|
||||
if (branch_tip != Output<Node>())
|
||||
{
|
||||
if (branch_tip != pattern_value_map[ltip])
|
||||
{
|
||||
NGRAPH_DEBUG << branch_tip << " doesn't match " << pattern_value_map[ltip];
|
||||
return false;
|
||||
}
|
||||
|
||||
// slice chunks should be slice in the same order as slice nodes in concat's argument
|
||||
// list
|
||||
auto cur_lower_bounds = slice->get_lower_bounds();
|
||||
if (cur_lower_bounds < prev_lower_bounds)
|
||||
{
|
||||
NGRAPH_DEBUG << slice << " is in the wrong order";
|
||||
return false;
|
||||
}
|
||||
prev_lower_bounds.assign(cur_lower_bounds.begin(), cur_lower_bounds.end());
|
||||
|
||||
// slice shapes need to match
|
||||
if (slice->get_shape() != prev_slice_shape)
|
||||
{
|
||||
NGRAPH_DEBUG << slice << " doesn't match the shape of the previous slice";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
branch_tip = pattern_value_map[ltip];
|
||||
prev_lower_bounds.assign(slice->get_lower_bounds().begin(),
|
||||
slice->get_lower_bounds().end());
|
||||
prev_slice_shape.assign(slice->get_shape().begin(), slice->get_shape().end());
|
||||
NGRAPH_DEBUG << "setting branch_tip to " << branch_tip;
|
||||
}
|
||||
|
||||
if (slice->get_users(true).size() > 1)
|
||||
{
|
||||
NGRAPH_DEBUG << slice << " has more than one user";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (shape_size(slice->get_strides()) != 1)
|
||||
{
|
||||
NGRAPH_DEBUG << slice << " is strided";
|
||||
return false;
|
||||
}
|
||||
|
||||
// check that no other node uses slices and reshapes
|
||||
if (auto rcarg = as_type_ptr<op::v0::Reshape>(carg.get_node_shared_ptr()))
|
||||
{
|
||||
auto default_shape = get_default_order(rcarg->input_value(0).get_shape());
|
||||
if (default_shape != rcarg->get_input_order())
|
||||
{
|
||||
NGRAPH_DEBUG << carg << " reshape also does transposes";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (rcarg->get_users(true).size() > 1)
|
||||
{
|
||||
NGRAPH_DEBUG << rcarg << " has more than one user";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto concat = static_pointer_cast<op::v0::Concat>(n);
|
||||
auto concat_axis = concat->get_concatenation_axis();
|
||||
|
||||
auto slice_shape = branch_tip.get_node_shared_ptr()->get_users(true).at(0)->get_shape();
|
||||
size_t slice_axis = numeric_limits<size_t>::max();
|
||||
|
||||
auto btip_shape = branch_tip.get_shape();
|
||||
|
||||
// slices should cover all elements
|
||||
if (shape_size(btip_shape) != shape_size(n->get_shape()))
|
||||
{
|
||||
NGRAPH_DEBUG << "The number of elements in Concat (" << shape_size(n->get_shape())
|
||||
<< ") and the total of elements in slices (" << shape_size(btip_shape)
|
||||
<< ") don't match";
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < btip_shape.size(); i++)
|
||||
{
|
||||
if (btip_shape[i] != slice_shape[i])
|
||||
{
|
||||
if (slice_axis != numeric_limits<size_t>::max())
|
||||
{
|
||||
// multi-axis slice + concat do not cancel
|
||||
return false;
|
||||
}
|
||||
slice_axis = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (slice_axis == numeric_limits<size_t>::max())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
auto replacement = branch_tip;
|
||||
if (btip_shape != n->get_shape())
|
||||
{
|
||||
auto default_order = get_default_order(btip_shape);
|
||||
if (concat_axis == slice_axis)
|
||||
{
|
||||
// logical reshape only
|
||||
replacement =
|
||||
make_shared<op::v0::Reshape>(branch_tip, default_order, concat->get_shape());
|
||||
}
|
||||
else
|
||||
{
|
||||
// axis reordering required
|
||||
auto transposed_shape = n->get_shape();
|
||||
|
||||
if (btip_shape.size() >= transposed_shape.size())
|
||||
{
|
||||
AxisVector order = get_default_order(btip_shape);
|
||||
auto ax = order[slice_axis];
|
||||
order[slice_axis] = order[concat_axis];
|
||||
order[concat_axis] = ax;
|
||||
replacement = make_shared<op::v0::Reshape>(branch_tip, order, transposed_shape);
|
||||
}
|
||||
else if (btip_shape.size() < transposed_shape.size())
|
||||
{
|
||||
// intermediate logical reshape
|
||||
AxisVector order = get_default_order(transposed_shape);
|
||||
auto ax = order[slice_axis];
|
||||
order[slice_axis] = order[concat_axis];
|
||||
order[concat_axis] = ax;
|
||||
auto output_shape = apply_permutation(transposed_shape, order);
|
||||
auto logical_reshape =
|
||||
make_shared<op::v0::Reshape>(branch_tip, default_order, output_shape);
|
||||
// transpose to final concatenated shape
|
||||
replacement =
|
||||
make_shared<op::v0::Reshape>(logical_reshape, order, transposed_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
n->output(0).replace(replacement);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool is_uniform_constant(const op::Constant* constant, int value)
|
||||
{
|
||||
bool rc = false;
|
||||
if (constant && constant->get_all_data_elements_bitwise_identical())
|
||||
{
|
||||
switch (constant->get_element_type())
|
||||
{
|
||||
case ngraph::element::Type_t::undefined:
|
||||
{
|
||||
throw runtime_error("is_value type not supported");
|
||||
}
|
||||
case ngraph::element::Type_t::dynamic: { throw runtime_error("is_value type not supported");
|
||||
}
|
||||
case ngraph::element::Type_t::boolean: break;
|
||||
case ngraph::element::Type_t::bf16:
|
||||
rc = *static_cast<const bfloat16*>(constant->get_data_ptr()) ==
|
||||
bfloat16(static_cast<float>(value));
|
||||
break;
|
||||
case ngraph::element::Type_t::f16:
|
||||
rc = *static_cast<const float16*>(constant->get_data_ptr()) ==
|
||||
float16(static_cast<float>(value));
|
||||
break;
|
||||
case ngraph::element::Type_t::f32:
|
||||
rc = *static_cast<const float*>(constant->get_data_ptr()) == static_cast<float>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::f64:
|
||||
rc =
|
||||
*static_cast<const double*>(constant->get_data_ptr()) == static_cast<double>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::i8:
|
||||
rc =
|
||||
*static_cast<const int8_t*>(constant->get_data_ptr()) == static_cast<int8_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::i16:
|
||||
rc = *static_cast<const int16_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<int16_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::i32:
|
||||
rc = *static_cast<const int32_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<int32_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::i64:
|
||||
rc = *static_cast<const int64_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<int64_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::u1: throw runtime_error("is_value type not supported");
|
||||
case ngraph::element::Type_t::u8:
|
||||
rc = *static_cast<const uint8_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<uint8_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::u16:
|
||||
rc = *static_cast<const uint16_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<uint16_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::u32:
|
||||
rc = *static_cast<const uint32_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<uint32_t>(value);
|
||||
break;
|
||||
case ngraph::element::Type_t::u64:
|
||||
rc = *static_cast<const uint64_t*>(constant->get_data_ptr()) ==
|
||||
static_cast<uint64_t>(value);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
static shared_ptr<op::Constant> get_constant(shared_ptr<Node> op)
|
||||
{
|
||||
set<Node::type_info_t> nomath = {op::v0::Broadcast::type_info,
|
||||
op::v0::Reshape::type_info,
|
||||
op::v1::Broadcast::type_info,
|
||||
opset3::Broadcast::type_info,
|
||||
opset3::Reshape::type_info};
|
||||
;
|
||||
while (nomath.find(op->get_type_info()) != nomath.end())
|
||||
{
|
||||
op = op->get_input_node_shared_ptr(0);
|
||||
}
|
||||
return as_type_ptr<op::Constant>(op);
|
||||
}
|
||||
|
||||
static bool is_input_uniform_constant(shared_ptr<Node> op,
|
||||
int constant_value,
|
||||
shared_ptr<Node>& constant,
|
||||
Output<Node>& value)
|
||||
{
|
||||
bool rc = false;
|
||||
auto c = get_constant(op->get_input_node_shared_ptr(0));
|
||||
if (is_uniform_constant(c.get(), constant_value))
|
||||
{
|
||||
constant = op->get_input_node_shared_ptr(0);
|
||||
value = op->input_value(1);
|
||||
rc = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
c = get_constant(op->get_input_node_shared_ptr(1));
|
||||
if (is_uniform_constant(c.get(), constant_value))
|
||||
{
|
||||
constant = op->get_input_node_shared_ptr(1);
|
||||
value = op->input_value(0);
|
||||
rc = true;
|
||||
}
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
//`simplify_gather`, optimizes gather if Gather is gathering the
|
||||
// whole input tensor
|
||||
static bool simplify_gather(std::shared_ptr<Node> node)
|
||||
@ -422,78 +114,6 @@ static bool simplify_gather(std::shared_ptr<Node> node)
|
||||
return false;
|
||||
}
|
||||
|
||||
//`simplify_multiply` optimizes the following 4 *base* cases
|
||||
//(8 cases in total including variants due to commutativity)
|
||||
//
|
||||
// a * 0 -> 0
|
||||
// a * broadcast(0) -> broadcast(0)
|
||||
// a * 1 -> a
|
||||
// a * broadcast(1) -> a
|
||||
static bool simplify_multiply(shared_ptr<Node> multiply)
|
||||
{
|
||||
bool rc = false;
|
||||
if (multiply)
|
||||
{
|
||||
shared_ptr<Node> constant;
|
||||
Output<Node> value;
|
||||
if (is_input_uniform_constant(multiply, 0, constant, value))
|
||||
{
|
||||
replace_output_update_name(multiply->output(0), constant->output(0));
|
||||
rc = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (is_input_uniform_constant(multiply, 1, constant, value))
|
||||
{
|
||||
replace_output_update_name(multiply->output(0), value);
|
||||
rc = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
//`simplify_add` optimizes the following 2 *base* cases
|
||||
//(4 cases in total including variants due to commutativity)
|
||||
//
|
||||
// a + 0 -> a
|
||||
// a + broadcast(0) -> a
|
||||
static bool simplify_add(shared_ptr<Node> add)
|
||||
{
|
||||
bool rc = false;
|
||||
if (add)
|
||||
{
|
||||
shared_ptr<Node> constant;
|
||||
Output<Node> value;
|
||||
if (is_input_uniform_constant(add, 0, constant, value))
|
||||
{
|
||||
replace_output_update_name(add->output(0), value);
|
||||
rc = true;
|
||||
}
|
||||
}
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
|
||||
static bool simplify_log(shared_ptr<Node> n)
|
||||
{
|
||||
if (auto div = as_type_ptr<op::v0::Divide>(n->input_value(0).get_node_shared_ptr()))
|
||||
{
|
||||
if (auto exp = as_type_ptr<op::v0::Exp>(div->input_value(0).get_node_shared_ptr()))
|
||||
{
|
||||
auto denom = div->get_input_source_output(1);
|
||||
auto diff = make_shared<op::v0::Subtract>(exp->get_input_source_output(0),
|
||||
make_shared<op::v0::Log>(denom));
|
||||
replace_node(n, diff);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// optimizes `gather->shapeof` into `shapeof->gather` for 0D indices
|
||||
// other cases into Concat of shapeof/gather(data) + shapeof(indices)
|
||||
static bool simplify_gather_shapeof(shared_ptr<Node> node)
|
||||
@ -562,137 +182,6 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node)
|
||||
return true;
|
||||
}
|
||||
|
||||
static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
|
||||
{
|
||||
size_t prod = 1;
|
||||
for (auto axis : axes)
|
||||
{
|
||||
prod *= shape.at(axis);
|
||||
}
|
||||
|
||||
return prod;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static shared_ptr<Node>
|
||||
multiply_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
|
||||
{
|
||||
T sum_cnst = static_cast<T>(cnst->get_data_ptr<T>()[0] * multiplier);
|
||||
return op::Constant::create<T>(type, Shape{}, {sum_cnst});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static shared_ptr<Node> pow_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
|
||||
{
|
||||
T prod = static_cast<T>(1);
|
||||
T val = cnst->get_data_ptr<T>()[0];
|
||||
for (size_t i = 0; i < multiplier; i++)
|
||||
{
|
||||
prod *= val;
|
||||
}
|
||||
return op::Constant::create<T>(type, Shape{}, {prod});
|
||||
}
|
||||
|
||||
static shared_ptr<Node> get_sum_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
|
||||
{
|
||||
if (cnst->get_element_type() == element::i32)
|
||||
{
|
||||
return multiply_by<int>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
else if (cnst->get_element_type() == element::i8)
|
||||
{
|
||||
return multiply_by<signed char>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
else if (cnst->get_element_type() == element::f32)
|
||||
{
|
||||
return multiply_by<float>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
else if (cnst->get_element_type() == element::f64)
|
||||
{
|
||||
return multiply_by<double>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static shared_ptr<Node> get_prod_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
|
||||
{
|
||||
if (cnst->get_element_type() == element::i32)
|
||||
{
|
||||
return pow_by<int>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
else if (cnst->get_element_type() == element::i8)
|
||||
{
|
||||
return pow_by<signed char>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
else if (cnst->get_element_type() == element::f32)
|
||||
{
|
||||
return pow_by<float>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
else if (cnst->get_element_type() == element::f64)
|
||||
{
|
||||
return pow_by<double>(cnst->get_element_type(), multiplier, cnst);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//`simplify_reduction` optimizes the following case:
|
||||
// sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
|
||||
// where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
|
||||
// product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
|
||||
// where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
|
||||
template <typename T, shared_ptr<Node> (*F)(shared_ptr<op::Constant> cnst, size_t multiplier)>
|
||||
static bool simplify_reduction(shared_ptr<Node> n)
|
||||
{
|
||||
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
|
||||
if (n->get_output_partial_shape(0).is_dynamic())
|
||||
{
|
||||
NGRAPH_DEBUG << n << " has dynamic shape";
|
||||
return false;
|
||||
}
|
||||
auto reduction = static_pointer_cast<T>(n);
|
||||
|
||||
auto broadcast = as_type_ptr<op::v0::Broadcast>(n->input_value(0).get_node_shared_ptr());
|
||||
if (!broadcast)
|
||||
{
|
||||
NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto cnst = as_type_ptr<op::Constant>(broadcast->input_value(0).get_node_shared_ptr());
|
||||
if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
|
||||
{
|
||||
NGRAPH_DEBUG << broadcast->input_value(0).get_node_shared_ptr()->get_name()
|
||||
<< " isn't a scalar constant";
|
||||
return false;
|
||||
}
|
||||
|
||||
auto multiplier = reduction_shape_size(reduction->get_reduction_axes(), broadcast->get_shape());
|
||||
auto reduction_cnst = F(cnst, multiplier);
|
||||
|
||||
// Unsupported type
|
||||
if (!reduction_cnst)
|
||||
{
|
||||
NGRAPH_DEBUG << "unsupported type";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (reduction->get_shape().size() > 0)
|
||||
{
|
||||
AxisSet axes{};
|
||||
for (size_t i = 0; i < reduction->get_shape().size(); i++)
|
||||
{
|
||||
axes.insert(i);
|
||||
}
|
||||
reduction_cnst =
|
||||
make_shared<op::v0::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
|
||||
}
|
||||
|
||||
replace_node(n, reduction_cnst);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool replace_transpose_with_reshape(shared_ptr<Node> transpose)
|
||||
{
|
||||
auto data = transpose->input_value(0);
|
||||
@ -797,17 +286,9 @@ static bool replace_transpose_with_reshape(shared_ptr<Node> transpose)
|
||||
static unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> initialize_ops_to_simplifiers()
|
||||
{
|
||||
return unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>>(
|
||||
{{op::v0::Add::type_info, simplify_add},
|
||||
{op::v0::Multiply::type_info, simplify_multiply},
|
||||
{opset3::Gather::type_info, simplify_gather},
|
||||
{op::v0::Concat::type_info, simplify_concat},
|
||||
{{opset3::Gather::type_info, simplify_gather},
|
||||
{opset2::ShapeOf::type_info, simplify_gather_shapeof},
|
||||
{opset3::ShapeOf::type_info, simplify_gather_shapeof},
|
||||
{op::v0::Sum::type_info,
|
||||
function<bool(shared_ptr<Node>)>{simplify_reduction<op::v0::Sum, get_sum_constant>}},
|
||||
{op::v0::Product::type_info,
|
||||
function<bool(shared_ptr<Node>)>{simplify_reduction<op::v0::Product, get_prod_constant>}},
|
||||
{op::v0::Log::type_info, simplify_log},
|
||||
{opset3::Transpose::type_info, replace_transpose_with_reshape}});
|
||||
}
|
||||
|
||||
|
@ -56,448 +56,6 @@
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
TEST(algebraic_simplification, add_types_shapes)
|
||||
{
|
||||
Shape shapes[] = {Shape{}, Shape{2, 2}, Shape{3, 3, 3}};
|
||||
element::Type types[] = {element::i32, element::f32, element::f64};
|
||||
for (auto type : types)
|
||||
{
|
||||
for (auto shape : shapes)
|
||||
{
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(a + iconst0);
|
||||
auto add_a_0_0 = add_a_0 + iconst0;
|
||||
auto add_b_0 = make_shared<op::Abs>(b + iconst0);
|
||||
auto add_b_0_0 = add_b_0 + iconst0;
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, a, c, b};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(
|
||||
expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_add_v1_types_shapes)
|
||||
{
|
||||
Shape shapes[] = {Shape{}, Shape{2, 2}, Shape{3, 3, 3}};
|
||||
element::Type types[] = {element::i32, element::f32, element::f64};
|
||||
for (auto type : types)
|
||||
{
|
||||
for (auto shape : shapes)
|
||||
{
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::Validate>();
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(a, iconst0));
|
||||
auto add_a_0_0 = make_shared<op::v1::Add>(add_a_0, iconst0);
|
||||
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(b, iconst0));
|
||||
auto add_b_0_0 = make_shared<op::v1::Add>(add_b_0, iconst0);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Add>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, a, c, b};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(
|
||||
expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, add_broadcast)
|
||||
{
|
||||
Shape shape{2, 2};
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::Validate>();
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto c = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
|
||||
auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
|
||||
auto add_a_0 = make_shared<op::Abs>(a + const_broadcast);
|
||||
auto add_a_0_0 = add_a_0 + const_broadcast;
|
||||
auto add_b_0 = make_shared<op::Abs>(b + const_broadcast);
|
||||
auto add_b_0_0 = add_b_0 + const_broadcast;
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, a, c, b};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_add_v1_broadcast_v1)
|
||||
{
|
||||
Shape shape{2, 2};
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::Validate>();
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto c = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
|
||||
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 2});
|
||||
auto const_broadcast = make_shared<op::v1::Broadcast>(iconst0, target_shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(a, const_broadcast));
|
||||
auto add_a_0_0 = make_shared<op::v1::Add>(add_a_0, const_broadcast);
|
||||
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(b, const_broadcast));
|
||||
auto add_b_0_0 = make_shared<op::v1::Add>(add_b_0, const_broadcast);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Add>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, a, c, b};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_broadcast_0)
|
||||
{
|
||||
Shape shape{2, 2};
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::Validate>();
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto c = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
|
||||
auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
|
||||
auto mul_a_0 = make_shared<op::Abs>(a * const_broadcast);
|
||||
auto mul_a_0_0 = make_shared<op::Abs>(mul_a_0 * const_broadcast);
|
||||
auto mul_b_0 = make_shared<op::Abs>(b * const_broadcast);
|
||||
auto mul_b_0_0 = make_shared<op::Abs>(mul_b_0 * const_broadcast);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_multiply_v1_broadcast_v1_0)
|
||||
{
|
||||
Shape shape{2, 2};
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto c = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
|
||||
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 2});
|
||||
auto const_broadcast = make_shared<op::v1::Broadcast>(iconst0, target_shape);
|
||||
auto mul_a_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(a, const_broadcast));
|
||||
auto mul_a_0_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(mul_a_0, const_broadcast));
|
||||
auto mul_b_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(b, const_broadcast));
|
||||
auto mul_b_0_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(mul_b_0, const_broadcast));
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Multiply>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_broadcast_1)
|
||||
{
|
||||
Shape shape{2, 2};
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto c = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto const_broadcast = ngraph::builder::make_constant<int32_t>(element::i32, shape, 1);
|
||||
auto mul_a_0 = make_shared<op::Abs>(a * const_broadcast);
|
||||
auto mul_a_0_0 = mul_a_0 * const_broadcast;
|
||||
auto mul_b_0 = make_shared<op::Abs>(b * const_broadcast);
|
||||
auto mul_b_0_0 = mul_b_0 * const_broadcast;
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, a, c, b};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_multiply_v1_broadcast_v1_1)
|
||||
{
|
||||
Shape shape{2, 2};
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto b = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto c = make_shared<op::Parameter>(element::i32, shape);
|
||||
auto const_broadcast = ngraph::builder::make_constant<int32_t>(element::i32, shape, 1);
|
||||
auto mul_a_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(a, const_broadcast));
|
||||
auto mul_a_0_0 = make_shared<op::v1::Multiply>(mul_a_0, const_broadcast);
|
||||
auto mul_b_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(b, const_broadcast));
|
||||
auto mul_b_0_0 = make_shared<op::v1::Multiply>(mul_b_0, const_broadcast);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_EQ(count_ops_of_type<op::v1::Multiply>(f), 0);
|
||||
auto expected = ngraph::NodeVector{a, b, a, c, b};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i),
|
||||
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
|
||||
? results.at(i)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
: results.at(i)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, zero_plus_zero_commutativity)
|
||||
{
|
||||
Shape shape{};
|
||||
auto type = element::f32;
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(iconst0 + iconst0);
|
||||
auto add_a_0_0 = make_shared<op::Abs>(iconst0 + iconst0);
|
||||
auto add_b_0 = make_shared<op::Abs>(iconst0 + b);
|
||||
auto add_b_0_0 = make_shared<op::Abs>(iconst0 + b);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_TRUE(ngraph::is_zero(f->get_results()
|
||||
.at(2)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()));
|
||||
ASSERT_EQ(f->get_results()
|
||||
.at(4)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr(),
|
||||
b);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_zero_plus_zero_commutativity_v1)
|
||||
{
|
||||
Shape shape{};
|
||||
auto type = element::f32;
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, iconst0));
|
||||
auto add_a_0_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, iconst0));
|
||||
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, b));
|
||||
auto add_b_0_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, b));
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_TRUE(ngraph::is_zero(f->get_results()
|
||||
.at(2)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()));
|
||||
ASSERT_EQ(f->get_results()
|
||||
.at(4)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr(),
|
||||
b);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, zero_multiply_zero_one)
|
||||
{
|
||||
Shape shape{};
|
||||
auto type = element::f32;
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
|
||||
auto iconst1 = ngraph::make_constant_from_string("1", type, shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(iconst0 * iconst0);
|
||||
auto add_b_0 = make_shared<op::Abs>(iconst1 * iconst0);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0, c, add_b_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_TRUE(ngraph::is_zero(f->get_results()
|
||||
.at(2)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()));
|
||||
ASSERT_TRUE(ngraph::is_zero(f->get_results()
|
||||
.at(4)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()));
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_zero_multiply_zero_one_v1)
|
||||
{
|
||||
Shape shape{};
|
||||
auto type = element::f32;
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
|
||||
auto iconst1 = ngraph::make_constant_from_string("1", type, shape);
|
||||
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(iconst0, iconst0));
|
||||
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(iconst1, iconst0));
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0, c, add_b_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
ASSERT_TRUE(ngraph::is_zero(f->get_results()
|
||||
.at(2)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()));
|
||||
ASSERT_TRUE(ngraph::is_zero(f->get_results()
|
||||
.at(4)
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()
|
||||
->input_value(0)
|
||||
.get_node_shared_ptr()));
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, add_negative_tests)
|
||||
{
|
||||
Shape shape{};
|
||||
@ -527,64 +85,6 @@ TEST(algebraic_simplification, add_negative_tests)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_add_negative_tests_v1)
|
||||
{
|
||||
Shape shape{};
|
||||
auto type = element::f32;
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto abs_a = make_shared<op::Abs>(a);
|
||||
auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
|
||||
auto add_a_0 = make_shared<op::v1::Add>(a, iconst2);
|
||||
auto add_a_0_0 = make_shared<op::v1::Add>(add_a_0, iconst2);
|
||||
auto add_b_0 = make_shared<op::v1::Add>(b, abs_a);
|
||||
auto add_b_0_0 = make_shared<op::v1::Add>(add_b_0, abs_a);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, DISABLED_multiply_negative_tests_v1)
|
||||
{
|
||||
Shape shape{};
|
||||
auto type = element::f32;
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto a = make_shared<op::Parameter>(type, shape);
|
||||
auto b = make_shared<op::Parameter>(type, shape);
|
||||
auto c = make_shared<op::Parameter>(type, shape);
|
||||
auto abs_a = make_shared<op::Abs>(a);
|
||||
auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
|
||||
auto add_a_0 = make_shared<op::v1::Multiply>(a, iconst2);
|
||||
auto add_a_0_0 = make_shared<op::v1::Multiply>(add_a_0, iconst2);
|
||||
auto add_b_0 = make_shared<op::v1::Multiply>(b, abs_a);
|
||||
auto add_b_0_0 = make_shared<op::v1::Multiply>(add_b_0, abs_a);
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
|
||||
ParameterVector{a, b, c});
|
||||
pass_manager.run_passes(f);
|
||||
|
||||
auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
|
||||
auto results = f->get_results();
|
||||
for (size_t i = 0; i < results.size(); i++)
|
||||
{
|
||||
ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_negative_tests)
|
||||
{
|
||||
Shape shape{};
|
||||
@ -614,45 +114,6 @@ TEST(algebraic_simplification, multiply_negative_tests)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_prod_vector_one)
|
||||
{
|
||||
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
|
||||
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
|
||||
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{1});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
|
||||
pass_manager.run_passes(f);
|
||||
auto new_broadcast =
|
||||
as_type_ptr<op::Broadcast>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(new_broadcast);
|
||||
auto new_const = as_type_ptr<op::Constant>(new_broadcast->input_value(0).get_node_shared_ptr());
|
||||
auto values = new_const->get_vector<double>();
|
||||
ASSERT_EQ(values.size(), 1);
|
||||
ASSERT_EQ(values.at(0), 32);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_prod_scalar_one)
|
||||
{
|
||||
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
|
||||
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
|
||||
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
|
||||
pass_manager.run_passes(f);
|
||||
auto new_const =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(new_const);
|
||||
auto values = new_const->get_vector<double>();
|
||||
ASSERT_EQ(values.size(), 1);
|
||||
ASSERT_EQ(values.at(0), 32768);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_prod_negative)
|
||||
{
|
||||
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
|
||||
@ -668,45 +129,6 @@ TEST(algebraic_simplification, multiply_prod_negative)
|
||||
ASSERT_EQ(f_prod, prod_fconst1);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_sum_scalar_one)
|
||||
{
|
||||
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
|
||||
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
|
||||
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
|
||||
pass_manager.run_passes(f);
|
||||
auto new_const =
|
||||
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(new_const);
|
||||
auto values = new_const->get_vector<double>();
|
||||
ASSERT_EQ(values.size(), 1);
|
||||
ASSERT_EQ(values.at(0), 15);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_sum_vector_one)
|
||||
{
|
||||
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
|
||||
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
|
||||
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{1});
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
|
||||
pass_manager.run_passes(f);
|
||||
auto new_broadcast =
|
||||
as_type_ptr<op::Broadcast>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(new_broadcast);
|
||||
auto new_const = as_type_ptr<op::Constant>(new_broadcast->input_value(0).get_node_shared_ptr());
|
||||
auto values = new_const->get_vector<double>();
|
||||
ASSERT_EQ(values.size(), 1);
|
||||
ASSERT_EQ(values.at(0), 5);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, multiply_sum_negative)
|
||||
{
|
||||
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
|
||||
@ -722,64 +144,6 @@ TEST(algebraic_simplification, multiply_sum_negative)
|
||||
ASSERT_EQ(f_sum, sum_fconst1);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, concat_reshape_slice)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
|
||||
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
|
||||
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
|
||||
|
||||
auto reshape1 = make_shared<op::Reshape>(slice1, AxisVector{0, 1}, Shape{32, 1, 100});
|
||||
auto reshape2 = make_shared<op::Reshape>(slice2, AxisVector{0, 1}, Shape{32, 1, 100});
|
||||
auto reshape3 = make_shared<op::Reshape>(slice3, AxisVector{0, 1}, Shape{32, 1, 100});
|
||||
|
||||
size_t concat_axis = 1;
|
||||
auto concat = make_shared<op::Concat>(NodeVector{reshape1, reshape2, reshape3}, concat_axis);
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
|
||||
pass_manager.run_passes(f);
|
||||
ASSERT_TRUE(is_type<op::Reshape>(f->get_results().at(0)->input_value(0).get_node_shared_ptr()));
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, concat_slice)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
|
||||
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
|
||||
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
|
||||
|
||||
size_t concat_axis = 0;
|
||||
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
|
||||
pass_manager.run_passes(f);
|
||||
ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), a);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, concat_parameter_slice)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
|
||||
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
|
||||
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
|
||||
|
||||
size_t concat_axis = 0;
|
||||
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
|
||||
pass_manager.run_passes(f);
|
||||
ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), a);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, concat_parameter_slices_reversed)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
@ -858,32 +222,6 @@ TEST(algebraic_simplification, concat_different_inputs)
|
||||
ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, log_neg_neg)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
auto exp_a = make_shared<op::Exp>(a);
|
||||
auto div = exp_a / b;
|
||||
auto log_div = make_shared<op::Log>(div);
|
||||
|
||||
auto neg_inner = make_shared<op::Negative>(log_div);
|
||||
auto neg2 = make_shared<op::Negative>(neg_inner);
|
||||
auto neg3 = make_shared<op::Negative>(neg2);
|
||||
auto neg4 = make_shared<op::Negative>(neg3);
|
||||
|
||||
pass::Manager pass_manager;
|
||||
pass_manager.register_pass<pass::AlgebraicSimplification>();
|
||||
|
||||
auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
|
||||
pass_manager.run_passes(f);
|
||||
auto sub = as_type_ptr<op::Subtract>(neg_inner->input_value(0).get_node_shared_ptr());
|
||||
ASSERT_TRUE(sub != nullptr);
|
||||
ASSERT_EQ(sub->input_value(0).get_node_shared_ptr(), a);
|
||||
auto new_log = as_type_ptr<op::Log>(sub->input_value(1).get_node_shared_ptr());
|
||||
ASSERT_TRUE(new_log != nullptr);
|
||||
ASSERT_EQ(new_log->input_value(0).get_node_shared_ptr(), b);
|
||||
}
|
||||
|
||||
TEST(algebraic_simplification, log_no_exp)
|
||||
{
|
||||
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
|
||||
|
Loading…
Reference in New Issue
Block a user