add default parameter value to runtime::reference::fake_quantize

This commit is contained in:
Patryk Elszkowski 2021-07-13 08:00:03 +02:00
parent 2e9c6fd752
commit 9d2c00d967
2 changed files with 77 additions and 111 deletions

View File

@ -195,10 +195,10 @@ namespace ngraph
}
} // namespace fake_quantize_details
namespace v1
{
template <typename T>
void fake_quantize(const T* const arg,
void
fake_quantize(const T* const arg,
const T* const in_low,
const T* const in_high,
const T* const out_low,
@ -210,7 +210,7 @@ namespace ngraph
const Shape& out_low_shape,
const Shape& out_high_shape,
size_t levels,
const op::AutoBroadcastSpec& broadcast)
const op::AutoBroadcastSpec& broadcast = op::AutoBroadcastType::NUMPY)
{
using namespace fake_quantize_details;
@ -256,46 +256,12 @@ namespace ngraph
const T out_low_val = out_low_bound.get_value(current_dim, index);
const T out_high_val = out_high_bound.get_value(current_dim, index);
out[index] = quantize(arg[index],
in_low_val,
in_high_val,
out_low_val,
out_high_val,
levels);
out[index] = quantize(
arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels);
increment_current_dim(current_dim, arg_shape);
}
}
}
} // namespace v1
template <typename T>
void fake_quantize(const T* const arg,
const T* const in_low,
const T* const in_high,
const T* const out_low,
const T* const out_high,
T* const out,
const Shape& arg_shape,
const Shape& in_low_shape,
const Shape& in_high_shape,
const Shape& out_low_shape,
const Shape& out_high_shape,
size_t levels)
{
v1::fake_quantize(arg,
in_low,
in_high,
out_low,
out_high,
out,
arg_shape,
in_low_shape,
in_high_shape,
out_low_shape,
out_high_shape,
levels,
op::AutoBroadcastType::NUMPY);
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -2383,7 +2383,7 @@ namespace
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::v1::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
runtime::reference::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
inputs[1]->get_data_ptr<const T>(),
inputs[2]->get_data_ptr<const T>(),
inputs[3]->get_data_ptr<const T>(),