Removed v0 operations from AlgebraicSimplufication pass (#1481)

* Removed v0 operations from AlgebraicSimplufication pass

* Fixed tests
This commit is contained in:
Ilya Churaev 2020-07-28 05:48:12 +03:00 committed by GitHub
parent ffcb7fab2d
commit af3a0900b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 1 additions and 1182 deletions

View File

@ -48,314 +48,6 @@
using namespace std;
using namespace ngraph;
template <typename T>
static shared_ptr<pattern::Matcher>
create_binary_matcher(shared_ptr<pattern::op::Label> label,
shared_ptr<pattern::op::Label> const_label)
{
auto bcst =
make_shared<pattern::op::Skip>(const_label, pattern::has_class<op::v0::Broadcast>());
auto bcst_label = make_shared<pattern::op::Label>(bcst, nullptr, NodeVector{bcst});
auto matcher = make_shared<pattern::Matcher>(make_shared<T>(label, bcst_label));
return matcher;
}
//`simplify_concat` identifies slices-concat sequences
// that cancel each other. Namely it replaces subgraphs
// similar to the one below with `arg`
//
// +----------+
// +----+slice(n/2..n)---+
// +-------+ | +----------+ | +-----------+
// | arg +--+ +--+ concat |
// +-------+ | +----------+ | +-----------+
// +----+slice(0..n/2)---+
// +----------+
static bool simplify_concat(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_concat for " << n->get_name();
if (n->get_output_partial_shape(0).is_dynamic())
{
NGRAPH_DEBUG << n << " has dynamic shape";
return false;
}
Output<Node> branch_tip;
auto ltip = make_shared<pattern::op::Label>(element::i32, Shape{2, 1});
auto pslice =
make_shared<op::v0::Slice>(ltip, Coordinate{0, 0}, Coordinate{2, 1}, Strides{1, 1});
auto lslice = make_shared<pattern::op::Label>(pslice, nullptr, NodeVector{pslice});
auto skip_reshape =
make_shared<pattern::op::Skip>(lslice, pattern::has_class<op::v0::Reshape>());
auto matcher = make_shared<pattern::Matcher>(skip_reshape);
Coordinate prev_lower_bounds;
Shape prev_slice_shape;
for (auto carg : n->input_values())
{
if (!matcher->match(carg))
{
NGRAPH_DEBUG << carg << " doesn't match";
return false;
}
auto& pattern_value_map = matcher->get_pattern_value_map();
auto slice = as_type_ptr<op::v0::Slice>(pattern_value_map[lslice].get_node_shared_ptr());
if (branch_tip != Output<Node>())
{
if (branch_tip != pattern_value_map[ltip])
{
NGRAPH_DEBUG << branch_tip << " doesn't match " << pattern_value_map[ltip];
return false;
}
// slice chunks should be slice in the same order as slice nodes in concat's argument
// list
auto cur_lower_bounds = slice->get_lower_bounds();
if (cur_lower_bounds < prev_lower_bounds)
{
NGRAPH_DEBUG << slice << " is in the wrong order";
return false;
}
prev_lower_bounds.assign(cur_lower_bounds.begin(), cur_lower_bounds.end());
// slice shapes need to match
if (slice->get_shape() != prev_slice_shape)
{
NGRAPH_DEBUG << slice << " doesn't match the shape of the previous slice";
return false;
}
}
else
{
branch_tip = pattern_value_map[ltip];
prev_lower_bounds.assign(slice->get_lower_bounds().begin(),
slice->get_lower_bounds().end());
prev_slice_shape.assign(slice->get_shape().begin(), slice->get_shape().end());
NGRAPH_DEBUG << "setting branch_tip to " << branch_tip;
}
if (slice->get_users(true).size() > 1)
{
NGRAPH_DEBUG << slice << " has more than one user";
return false;
}
if (shape_size(slice->get_strides()) != 1)
{
NGRAPH_DEBUG << slice << " is strided";
return false;
}
// check that no other node uses slices and reshapes
if (auto rcarg = as_type_ptr<op::v0::Reshape>(carg.get_node_shared_ptr()))
{
auto default_shape = get_default_order(rcarg->input_value(0).get_shape());
if (default_shape != rcarg->get_input_order())
{
NGRAPH_DEBUG << carg << " reshape also does transposes";
return false;
}
if (rcarg->get_users(true).size() > 1)
{
NGRAPH_DEBUG << rcarg << " has more than one user";
return false;
}
}
}
auto concat = static_pointer_cast<op::v0::Concat>(n);
auto concat_axis = concat->get_concatenation_axis();
auto slice_shape = branch_tip.get_node_shared_ptr()->get_users(true).at(0)->get_shape();
size_t slice_axis = numeric_limits<size_t>::max();
auto btip_shape = branch_tip.get_shape();
// slices should cover all elements
if (shape_size(btip_shape) != shape_size(n->get_shape()))
{
NGRAPH_DEBUG << "The number of elements in Concat (" << shape_size(n->get_shape())
<< ") and the total of elements in slices (" << shape_size(btip_shape)
<< ") don't match";
return false;
}
for (size_t i = 0; i < btip_shape.size(); i++)
{
if (btip_shape[i] != slice_shape[i])
{
if (slice_axis != numeric_limits<size_t>::max())
{
// multi-axis slice + concat do not cancel
return false;
}
slice_axis = i;
}
}
if (slice_axis == numeric_limits<size_t>::max())
{
return false;
}
auto replacement = branch_tip;
if (btip_shape != n->get_shape())
{
auto default_order = get_default_order(btip_shape);
if (concat_axis == slice_axis)
{
// logical reshape only
replacement =
make_shared<op::v0::Reshape>(branch_tip, default_order, concat->get_shape());
}
else
{
// axis reordering required
auto transposed_shape = n->get_shape();
if (btip_shape.size() >= transposed_shape.size())
{
AxisVector order = get_default_order(btip_shape);
auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis];
order[concat_axis] = ax;
replacement = make_shared<op::v0::Reshape>(branch_tip, order, transposed_shape);
}
else if (btip_shape.size() < transposed_shape.size())
{
// intermediate logical reshape
AxisVector order = get_default_order(transposed_shape);
auto ax = order[slice_axis];
order[slice_axis] = order[concat_axis];
order[concat_axis] = ax;
auto output_shape = apply_permutation(transposed_shape, order);
auto logical_reshape =
make_shared<op::v0::Reshape>(branch_tip, default_order, output_shape);
// transpose to final concatenated shape
replacement =
make_shared<op::v0::Reshape>(logical_reshape, order, transposed_shape);
}
}
}
n->output(0).replace(replacement);
return true;
}
static bool is_uniform_constant(const op::Constant* constant, int value)
{
bool rc = false;
if (constant && constant->get_all_data_elements_bitwise_identical())
{
switch (constant->get_element_type())
{
case ngraph::element::Type_t::undefined:
{
throw runtime_error("is_value type not supported");
}
case ngraph::element::Type_t::dynamic: { throw runtime_error("is_value type not supported");
}
case ngraph::element::Type_t::boolean: break;
case ngraph::element::Type_t::bf16:
rc = *static_cast<const bfloat16*>(constant->get_data_ptr()) ==
bfloat16(static_cast<float>(value));
break;
case ngraph::element::Type_t::f16:
rc = *static_cast<const float16*>(constant->get_data_ptr()) ==
float16(static_cast<float>(value));
break;
case ngraph::element::Type_t::f32:
rc = *static_cast<const float*>(constant->get_data_ptr()) == static_cast<float>(value);
break;
case ngraph::element::Type_t::f64:
rc =
*static_cast<const double*>(constant->get_data_ptr()) == static_cast<double>(value);
break;
case ngraph::element::Type_t::i8:
rc =
*static_cast<const int8_t*>(constant->get_data_ptr()) == static_cast<int8_t>(value);
break;
case ngraph::element::Type_t::i16:
rc = *static_cast<const int16_t*>(constant->get_data_ptr()) ==
static_cast<int16_t>(value);
break;
case ngraph::element::Type_t::i32:
rc = *static_cast<const int32_t*>(constant->get_data_ptr()) ==
static_cast<int32_t>(value);
break;
case ngraph::element::Type_t::i64:
rc = *static_cast<const int64_t*>(constant->get_data_ptr()) ==
static_cast<int64_t>(value);
break;
case ngraph::element::Type_t::u1: throw runtime_error("is_value type not supported");
case ngraph::element::Type_t::u8:
rc = *static_cast<const uint8_t*>(constant->get_data_ptr()) ==
static_cast<uint8_t>(value);
break;
case ngraph::element::Type_t::u16:
rc = *static_cast<const uint16_t*>(constant->get_data_ptr()) ==
static_cast<uint16_t>(value);
break;
case ngraph::element::Type_t::u32:
rc = *static_cast<const uint32_t*>(constant->get_data_ptr()) ==
static_cast<uint32_t>(value);
break;
case ngraph::element::Type_t::u64:
rc = *static_cast<const uint64_t*>(constant->get_data_ptr()) ==
static_cast<uint64_t>(value);
break;
}
}
return rc;
}
static shared_ptr<op::Constant> get_constant(shared_ptr<Node> op)
{
set<Node::type_info_t> nomath = {op::v0::Broadcast::type_info,
op::v0::Reshape::type_info,
op::v1::Broadcast::type_info,
opset3::Broadcast::type_info,
opset3::Reshape::type_info};
;
while (nomath.find(op->get_type_info()) != nomath.end())
{
op = op->get_input_node_shared_ptr(0);
}
return as_type_ptr<op::Constant>(op);
}
static bool is_input_uniform_constant(shared_ptr<Node> op,
int constant_value,
shared_ptr<Node>& constant,
Output<Node>& value)
{
bool rc = false;
auto c = get_constant(op->get_input_node_shared_ptr(0));
if (is_uniform_constant(c.get(), constant_value))
{
constant = op->get_input_node_shared_ptr(0);
value = op->input_value(1);
rc = true;
}
else
{
c = get_constant(op->get_input_node_shared_ptr(1));
if (is_uniform_constant(c.get(), constant_value))
{
constant = op->get_input_node_shared_ptr(1);
value = op->input_value(0);
rc = true;
}
}
return rc;
}
//`simplify_gather`, optimizes gather if Gather is gathering the
// whole input tensor
static bool simplify_gather(std::shared_ptr<Node> node)
@ -422,78 +114,6 @@ static bool simplify_gather(std::shared_ptr<Node> node)
return false;
}
//`simplify_multiply` optimizes the following 4 *base* cases
//(8 cases in total including variants due to commutativity)
//
// a * 0 -> 0
// a * broadcast(0) -> broadcast(0)
// a * 1 -> a
// a * broadcast(1) -> a
static bool simplify_multiply(shared_ptr<Node> multiply)
{
bool rc = false;
if (multiply)
{
shared_ptr<Node> constant;
Output<Node> value;
if (is_input_uniform_constant(multiply, 0, constant, value))
{
replace_output_update_name(multiply->output(0), constant->output(0));
rc = true;
}
else
{
if (is_input_uniform_constant(multiply, 1, constant, value))
{
replace_output_update_name(multiply->output(0), value);
rc = true;
}
}
}
return rc;
}
//`simplify_add` optimizes the following 2 *base* cases
//(4 cases in total including variants due to commutativity)
//
// a + 0 -> a
// a + broadcast(0) -> a
static bool simplify_add(shared_ptr<Node> add)
{
bool rc = false;
if (add)
{
shared_ptr<Node> constant;
Output<Node> value;
if (is_input_uniform_constant(add, 0, constant, value))
{
replace_output_update_name(add->output(0), value);
rc = true;
}
}
return rc;
}
//`simplify_log` optimizes `log(exp(x)/y)` into `x - log(y)`
static bool simplify_log(shared_ptr<Node> n)
{
if (auto div = as_type_ptr<op::v0::Divide>(n->input_value(0).get_node_shared_ptr()))
{
if (auto exp = as_type_ptr<op::v0::Exp>(div->input_value(0).get_node_shared_ptr()))
{
auto denom = div->get_input_source_output(1);
auto diff = make_shared<op::v0::Subtract>(exp->get_input_source_output(0),
make_shared<op::v0::Log>(denom));
replace_node(n, diff);
return true;
}
}
return false;
}
// optimizes `gather->shapeof` into `shapeof->gather` for 0D indices
// other cases into Concat of shapeof/gather(data) + shapeof(indices)
static bool simplify_gather_shapeof(shared_ptr<Node> node)
@ -562,137 +182,6 @@ static bool simplify_gather_shapeof(shared_ptr<Node> node)
return true;
}
static size_t reduction_shape_size(const AxisSet& axes, const Shape& shape)
{
size_t prod = 1;
for (auto axis : axes)
{
prod *= shape.at(axis);
}
return prod;
}
template <typename T>
static shared_ptr<Node>
multiply_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
{
T sum_cnst = static_cast<T>(cnst->get_data_ptr<T>()[0] * multiplier);
return op::Constant::create<T>(type, Shape{}, {sum_cnst});
}
template <typename T>
static shared_ptr<Node> pow_by(element::Type type, size_t multiplier, shared_ptr<op::Constant> cnst)
{
T prod = static_cast<T>(1);
T val = cnst->get_data_ptr<T>()[0];
for (size_t i = 0; i < multiplier; i++)
{
prod *= val;
}
return op::Constant::create<T>(type, Shape{}, {prod});
}
static shared_ptr<Node> get_sum_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
return multiply_by<int>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::i8)
{
return multiply_by<signed char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
return multiply_by<float>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f64)
{
return multiply_by<double>(cnst->get_element_type(), multiplier, cnst);
}
return nullptr;
}
static shared_ptr<Node> get_prod_constant(shared_ptr<op::Constant> cnst, size_t multiplier)
{
if (cnst->get_element_type() == element::i32)
{
return pow_by<int>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::i8)
{
return pow_by<signed char>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f32)
{
return pow_by<float>(cnst->get_element_type(), multiplier, cnst);
}
else if (cnst->get_element_type() == element::f64)
{
return pow_by<double>(cnst->get_element_type(), multiplier, cnst);
}
return nullptr;
}
//`simplify_reduction` optimizes the following case:
// sum(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
// where constant2's values are equal to scalar_constant * shape_size(reduction_axes)
// product(broadcast(scalar_constant), reduction_axes = ...) -> constant2 (or scalar constant)
// where constant2's values are equal to scalar_constant ^ shape_size(reduction_axes)
template <typename T, shared_ptr<Node> (*F)(shared_ptr<op::Constant> cnst, size_t multiplier)>
static bool simplify_reduction(shared_ptr<Node> n)
{
NGRAPH_DEBUG << "In simplify_reduction for " << n->get_name();
if (n->get_output_partial_shape(0).is_dynamic())
{
NGRAPH_DEBUG << n << " has dynamic shape";
return false;
}
auto reduction = static_pointer_cast<T>(n);
auto broadcast = as_type_ptr<op::v0::Broadcast>(n->input_value(0).get_node_shared_ptr());
if (!broadcast)
{
NGRAPH_DEBUG << n->get_name() << " isn't Broadcast";
return false;
}
auto cnst = as_type_ptr<op::Constant>(broadcast->input_value(0).get_node_shared_ptr());
if (!cnst || cnst->get_shape().size() > 0 /*not a scalar*/)
{
NGRAPH_DEBUG << broadcast->input_value(0).get_node_shared_ptr()->get_name()
<< " isn't a scalar constant";
return false;
}
auto multiplier = reduction_shape_size(reduction->get_reduction_axes(), broadcast->get_shape());
auto reduction_cnst = F(cnst, multiplier);
// Unsupported type
if (!reduction_cnst)
{
NGRAPH_DEBUG << "unsupported type";
return false;
}
if (reduction->get_shape().size() > 0)
{
AxisSet axes{};
for (size_t i = 0; i < reduction->get_shape().size(); i++)
{
axes.insert(i);
}
reduction_cnst =
make_shared<op::v0::Broadcast>(reduction_cnst, reduction->get_shape(), axes);
}
replace_node(n, reduction_cnst);
return true;
}
static bool replace_transpose_with_reshape(shared_ptr<Node> transpose)
{
auto data = transpose->input_value(0);
@ -797,17 +286,9 @@ static bool replace_transpose_with_reshape(shared_ptr<Node> transpose)
static unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>> initialize_ops_to_simplifiers()
{
return unordered_map<NodeTypeInfo, function<bool(shared_ptr<Node>)>>(
{{op::v0::Add::type_info, simplify_add},
{op::v0::Multiply::type_info, simplify_multiply},
{opset3::Gather::type_info, simplify_gather},
{op::v0::Concat::type_info, simplify_concat},
{{opset3::Gather::type_info, simplify_gather},
{opset2::ShapeOf::type_info, simplify_gather_shapeof},
{opset3::ShapeOf::type_info, simplify_gather_shapeof},
{op::v0::Sum::type_info,
function<bool(shared_ptr<Node>)>{simplify_reduction<op::v0::Sum, get_sum_constant>}},
{op::v0::Product::type_info,
function<bool(shared_ptr<Node>)>{simplify_reduction<op::v0::Product, get_prod_constant>}},
{op::v0::Log::type_info, simplify_log},
{opset3::Transpose::type_info, replace_transpose_with_reshape}});
}

View File

@ -56,448 +56,6 @@
using namespace ngraph;
using namespace std;
TEST(algebraic_simplification, add_types_shapes)
{
Shape shapes[] = {Shape{}, Shape{2, 2}, Shape{3, 3, 3}};
element::Type types[] = {element::i32, element::f32, element::f64};
for (auto type : types)
{
for (auto shape : shapes)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto add_a_0 = make_shared<op::Abs>(a + iconst0);
auto add_a_0_0 = add_a_0 + iconst0;
auto add_b_0 = make_shared<op::Abs>(b + iconst0);
auto add_b_0_0 = add_b_0 + iconst0;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(
expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
}
}
TEST(algebraic_simplification, DISABLED_add_v1_types_shapes)
{
Shape shapes[] = {Shape{}, Shape{2, 2}, Shape{3, 3, 3}};
element::Type types[] = {element::i32, element::f32, element::f64};
for (auto type : types)
{
for (auto shape : shapes)
{
pass::Manager pass_manager;
pass_manager.register_pass<pass::Validate>();
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(a, iconst0));
auto add_a_0_0 = make_shared<op::v1::Add>(add_a_0, iconst0);
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(b, iconst0));
auto add_b_0_0 = make_shared<op::v1::Add>(add_b_0, iconst0);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(
expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
}
}
TEST(algebraic_simplification, add_broadcast)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::Validate>();
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
auto add_a_0 = make_shared<op::Abs>(a + const_broadcast);
auto add_a_0_0 = add_a_0 + const_broadcast;
auto add_b_0 = make_shared<op::Abs>(b + const_broadcast);
auto add_b_0_0 = add_b_0 + const_broadcast;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
TEST(algebraic_simplification, DISABLED_add_v1_broadcast_v1)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::Validate>();
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 2});
auto const_broadcast = make_shared<op::v1::Broadcast>(iconst0, target_shape);
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(a, const_broadcast));
auto add_a_0_0 = make_shared<op::v1::Add>(add_a_0, const_broadcast);
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(b, const_broadcast));
auto add_b_0_0 = make_shared<op::v1::Add>(add_b_0, const_broadcast);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Add>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
TEST(algebraic_simplification, multiply_broadcast_0)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::Validate>();
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
auto const_broadcast = make_shared<op::Broadcast>(iconst0, shape, AxisSet{0, 1});
auto mul_a_0 = make_shared<op::Abs>(a * const_broadcast);
auto mul_a_0_0 = make_shared<op::Abs>(mul_a_0 * const_broadcast);
auto mul_b_0 = make_shared<op::Abs>(b * const_broadcast);
auto mul_b_0_0 = make_shared<op::Abs>(mul_b_0 * const_broadcast);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
TEST(algebraic_simplification, DISABLED_multiply_v1_broadcast_v1_0)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto iconst0 = ngraph::make_zero(element::i32, Shape{});
auto target_shape = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 2});
auto const_broadcast = make_shared<op::v1::Broadcast>(iconst0, target_shape);
auto mul_a_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(a, const_broadcast));
auto mul_a_0_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(mul_a_0, const_broadcast));
auto mul_b_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(b, const_broadcast));
auto mul_b_0_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(mul_b_0, const_broadcast));
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, const_broadcast, c, const_broadcast};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
TEST(algebraic_simplification, multiply_broadcast_1)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto const_broadcast = ngraph::builder::make_constant<int32_t>(element::i32, shape, 1);
auto mul_a_0 = make_shared<op::Abs>(a * const_broadcast);
auto mul_a_0_0 = mul_a_0 * const_broadcast;
auto mul_b_0 = make_shared<op::Abs>(b * const_broadcast);
auto mul_b_0_0 = mul_b_0 * const_broadcast;
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
TEST(algebraic_simplification, DISABLED_multiply_v1_broadcast_v1_1)
{
Shape shape{2, 2};
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(element::i32, shape);
auto b = make_shared<op::Parameter>(element::i32, shape);
auto c = make_shared<op::Parameter>(element::i32, shape);
auto const_broadcast = ngraph::builder::make_constant<int32_t>(element::i32, shape, 1);
auto mul_a_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(a, const_broadcast));
auto mul_a_0_0 = make_shared<op::v1::Multiply>(mul_a_0, const_broadcast);
auto mul_b_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(b, const_broadcast));
auto mul_b_0_0 = make_shared<op::v1::Multiply>(mul_b_0, const_broadcast);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, mul_a_0_0, c, mul_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_EQ(count_ops_of_type<op::v1::Multiply>(f), 0);
auto expected = ngraph::NodeVector{a, b, a, c, b};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i),
(results.at(i)->input_value(0).get_node_shared_ptr()->input_values().size()
? results.at(i)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()
: results.at(i)->input_value(0).get_node_shared_ptr()));
}
}
TEST(algebraic_simplification, zero_plus_zero_commutativity)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto add_a_0 = make_shared<op::Abs>(iconst0 + iconst0);
auto add_a_0_0 = make_shared<op::Abs>(iconst0 + iconst0);
auto add_b_0 = make_shared<op::Abs>(iconst0 + b);
auto add_b_0_0 = make_shared<op::Abs>(iconst0 + b);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(ngraph::is_zero(f->get_results()
.at(2)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()));
ASSERT_EQ(f->get_results()
.at(4)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr(),
b);
}
TEST(algebraic_simplification, DISABLED_zero_plus_zero_commutativity_v1)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, iconst0));
auto add_a_0_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, iconst0));
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, b));
auto add_b_0_0 = make_shared<op::Abs>(make_shared<op::v1::Add>(iconst0, b));
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(ngraph::is_zero(f->get_results()
.at(2)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()));
ASSERT_EQ(f->get_results()
.at(4)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr(),
b);
}
TEST(algebraic_simplification, zero_multiply_zero_one)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto iconst1 = ngraph::make_constant_from_string("1", type, shape);
auto add_a_0 = make_shared<op::Abs>(iconst0 * iconst0);
auto add_b_0 = make_shared<op::Abs>(iconst1 * iconst0);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0, c, add_b_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(ngraph::is_zero(f->get_results()
.at(2)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()));
ASSERT_TRUE(ngraph::is_zero(f->get_results()
.at(4)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()));
}
TEST(algebraic_simplification, DISABLED_zero_multiply_zero_one_v1)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto iconst0 = ngraph::make_constant_from_string("0", type, shape);
auto iconst1 = ngraph::make_constant_from_string("1", type, shape);
auto add_a_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(iconst0, iconst0));
auto add_b_0 = make_shared<op::Abs>(make_shared<op::v1::Multiply>(iconst1, iconst0));
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0, c, add_b_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
ASSERT_TRUE(ngraph::is_zero(f->get_results()
.at(2)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()));
ASSERT_TRUE(ngraph::is_zero(f->get_results()
.at(4)
->input_value(0)
.get_node_shared_ptr()
->input_value(0)
.get_node_shared_ptr()));
}
TEST(algebraic_simplification, add_negative_tests)
{
Shape shape{};
@ -527,64 +85,6 @@ TEST(algebraic_simplification, add_negative_tests)
}
}
TEST(algebraic_simplification, DISABLED_add_negative_tests_v1)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto abs_a = make_shared<op::Abs>(a);
auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
auto add_a_0 = make_shared<op::v1::Add>(a, iconst2);
auto add_a_0_0 = make_shared<op::v1::Add>(add_a_0, iconst2);
auto add_b_0 = make_shared<op::v1::Add>(b, abs_a);
auto add_b_0_0 = make_shared<op::v1::Add>(add_b_0, abs_a);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
}
}
TEST(algebraic_simplification, DISABLED_multiply_negative_tests_v1)
{
Shape shape{};
auto type = element::f32;
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto a = make_shared<op::Parameter>(type, shape);
auto b = make_shared<op::Parameter>(type, shape);
auto c = make_shared<op::Parameter>(type, shape);
auto abs_a = make_shared<op::Abs>(a);
auto iconst2 = ngraph::make_constant_from_string("2", type, shape);
auto add_a_0 = make_shared<op::v1::Multiply>(a, iconst2);
auto add_a_0_0 = make_shared<op::v1::Multiply>(add_a_0, iconst2);
auto add_b_0 = make_shared<op::v1::Multiply>(b, abs_a);
auto add_b_0_0 = make_shared<op::v1::Multiply>(add_b_0, abs_a);
auto f = std::make_shared<Function>(ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0},
ParameterVector{a, b, c});
pass_manager.run_passes(f);
auto expected = ngraph::NodeVector{a, b, add_a_0_0, c, add_b_0_0};
auto results = f->get_results();
for (size_t i = 0; i < results.size(); i++)
{
ASSERT_EQ(expected.at(i), results.at(i)->input_value(0).get_node_shared_ptr());
}
}
TEST(algebraic_simplification, multiply_negative_tests)
{
Shape shape{};
@ -614,45 +114,6 @@ TEST(algebraic_simplification, multiply_negative_tests)
}
}
TEST(algebraic_simplification, multiply_prod_vector_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
pass_manager.run_passes(f);
auto new_broadcast =
as_type_ptr<op::Broadcast>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_broadcast);
auto new_const = as_type_ptr<op::Constant>(new_broadcast->input_value(0).get_node_shared_ptr());
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 32);
}
TEST(algebraic_simplification, multiply_prod_scalar_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {2.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto prod_fconst1 = std::make_shared<op::Product>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{prod_fconst1}, ParameterVector{});
pass_manager.run_passes(f);
auto new_const =
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_const);
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 32768);
}
TEST(algebraic_simplification, multiply_prod_negative)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
@ -668,45 +129,6 @@ TEST(algebraic_simplification, multiply_prod_negative)
ASSERT_EQ(f_prod, prod_fconst1);
}
TEST(algebraic_simplification, multiply_sum_scalar_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{0, 1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
pass_manager.run_passes(f);
auto new_const =
as_type_ptr<op::Constant>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_const);
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 15);
}
TEST(algebraic_simplification, multiply_sum_vector_one)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{}, {1.0});
auto broadcast = std::make_shared<op::Broadcast>(fconst1, Shape{3, 5}, AxisSet{0, 1});
auto sum_fconst1 = std::make_shared<op::Sum>(broadcast, AxisSet{1});
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{sum_fconst1}, ParameterVector{});
pass_manager.run_passes(f);
auto new_broadcast =
as_type_ptr<op::Broadcast>(f->get_results().at(0)->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(new_broadcast);
auto new_const = as_type_ptr<op::Constant>(new_broadcast->input_value(0).get_node_shared_ptr());
auto values = new_const->get_vector<double>();
ASSERT_EQ(values.size(), 1);
ASSERT_EQ(values.at(0), 5);
}
TEST(algebraic_simplification, multiply_sum_negative)
{
auto fconst1 = ngraph::op::Constant::create(element::f64, Shape{2}, {1.0, 1.0});
@ -722,64 +144,6 @@ TEST(algebraic_simplification, multiply_sum_negative)
ASSERT_EQ(f_sum, sum_fconst1);
}
TEST(algebraic_simplification, concat_reshape_slice)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
auto reshape1 = make_shared<op::Reshape>(slice1, AxisVector{0, 1}, Shape{32, 1, 100});
auto reshape2 = make_shared<op::Reshape>(slice2, AxisVector{0, 1}, Shape{32, 1, 100});
auto reshape3 = make_shared<op::Reshape>(slice3, AxisVector{0, 1}, Shape{32, 1, 100});
size_t concat_axis = 1;
auto concat = make_shared<op::Concat>(NodeVector{reshape1, reshape2, reshape3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_TRUE(is_type<op::Reshape>(f->get_results().at(0)->input_value(0).get_node_shared_ptr()));
}
TEST(algebraic_simplification, concat_slice)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), a);
}
TEST(algebraic_simplification, concat_parameter_slice)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto slice1 = make_shared<op::Slice>(a, Coordinate{0, 0}, Coordinate{32, 100}, Strides{1, 1});
auto slice2 = make_shared<op::Slice>(a, Coordinate{32, 0}, Coordinate{64, 100}, Strides{1, 1});
auto slice3 = make_shared<op::Slice>(a, Coordinate{64, 0}, Coordinate{96, 100}, Strides{1, 1});
size_t concat_axis = 0;
auto concat = make_shared<op::Concat>(NodeVector{slice1, slice2, slice3}, concat_axis);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{concat}, ParameterVector{a});
pass_manager.run_passes(f);
ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), a);
}
TEST(algebraic_simplification, concat_parameter_slices_reversed)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
@ -858,32 +222,6 @@ TEST(algebraic_simplification, concat_different_inputs)
ASSERT_EQ(f->get_results().at(0)->input_value(0).get_node_shared_ptr(), concat);
}
TEST(algebraic_simplification, log_neg_neg)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto b = make_shared<op::Parameter>(element::f32, Shape{96, 100});
auto exp_a = make_shared<op::Exp>(a);
auto div = exp_a / b;
auto log_div = make_shared<op::Log>(div);
auto neg_inner = make_shared<op::Negative>(log_div);
auto neg2 = make_shared<op::Negative>(neg_inner);
auto neg3 = make_shared<op::Negative>(neg2);
auto neg4 = make_shared<op::Negative>(neg3);
pass::Manager pass_manager;
pass_manager.register_pass<pass::AlgebraicSimplification>();
auto f = std::make_shared<Function>(ngraph::NodeVector{neg4}, ParameterVector{a, b});
pass_manager.run_passes(f);
auto sub = as_type_ptr<op::Subtract>(neg_inner->input_value(0).get_node_shared_ptr());
ASSERT_TRUE(sub != nullptr);
ASSERT_EQ(sub->input_value(0).get_node_shared_ptr(), a);
auto new_log = as_type_ptr<op::Log>(sub->input_value(1).get_node_shared_ptr());
ASSERT_TRUE(new_log != nullptr);
ASSERT_EQ(new_log->input_value(0).get_node_shared_ptr(), b);
}
TEST(algebraic_simplification, log_no_exp)
{
auto a = make_shared<op::Parameter>(element::f32, Shape{96, 100});