add default parameter value to runtime::reference::fake_quantize
This commit is contained in:
parent
2e9c6fd752
commit
9d2c00d967
@ -195,10 +195,10 @@ namespace ngraph
|
||||
}
|
||||
|
||||
} // namespace fake_quantize_details
|
||||
namespace v1
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void fake_quantize(const T* const arg,
|
||||
void
|
||||
fake_quantize(const T* const arg,
|
||||
const T* const in_low,
|
||||
const T* const in_high,
|
||||
const T* const out_low,
|
||||
@ -210,7 +210,7 @@ namespace ngraph
|
||||
const Shape& out_low_shape,
|
||||
const Shape& out_high_shape,
|
||||
size_t levels,
|
||||
const op::AutoBroadcastSpec& broadcast)
|
||||
const op::AutoBroadcastSpec& broadcast = op::AutoBroadcastType::NUMPY)
|
||||
{
|
||||
using namespace fake_quantize_details;
|
||||
|
||||
@ -256,46 +256,12 @@ namespace ngraph
|
||||
const T out_low_val = out_low_bound.get_value(current_dim, index);
|
||||
const T out_high_val = out_high_bound.get_value(current_dim, index);
|
||||
|
||||
out[index] = quantize(arg[index],
|
||||
in_low_val,
|
||||
in_high_val,
|
||||
out_low_val,
|
||||
out_high_val,
|
||||
levels);
|
||||
out[index] = quantize(
|
||||
arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels);
|
||||
increment_current_dim(current_dim, arg_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace v1
|
||||
|
||||
template <typename T>
|
||||
void fake_quantize(const T* const arg,
|
||||
const T* const in_low,
|
||||
const T* const in_high,
|
||||
const T* const out_low,
|
||||
const T* const out_high,
|
||||
T* const out,
|
||||
const Shape& arg_shape,
|
||||
const Shape& in_low_shape,
|
||||
const Shape& in_high_shape,
|
||||
const Shape& out_low_shape,
|
||||
const Shape& out_high_shape,
|
||||
size_t levels)
|
||||
{
|
||||
v1::fake_quantize(arg,
|
||||
in_low,
|
||||
in_high,
|
||||
out_low,
|
||||
out_high,
|
||||
out,
|
||||
arg_shape,
|
||||
in_low_shape,
|
||||
in_high_shape,
|
||||
out_low_shape,
|
||||
out_high_shape,
|
||||
levels,
|
||||
op::AutoBroadcastType::NUMPY);
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
||||
|
@ -2383,7 +2383,7 @@ namespace
|
||||
const HostTensorVector& inputs)
|
||||
{
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
runtime::reference::v1::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
|
||||
runtime::reference::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
|
||||
inputs[1]->get_data_ptr<const T>(),
|
||||
inputs[2]->get_data_ptr<const T>(),
|
||||
inputs[3]->get_data_ptr<const T>(),
|
||||
|
Loading…
Reference in New Issue
Block a user