add default parameter value to runtime::reference::fake_quantize
This commit is contained in:
parent
2e9c6fd752
commit
9d2c00d967
@ -195,106 +195,72 @@ namespace ngraph
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace fake_quantize_details
|
} // namespace fake_quantize_details
|
||||||
namespace v1
|
|
||||||
{
|
|
||||||
template <typename T>
|
|
||||||
void fake_quantize(const T* const arg,
|
|
||||||
const T* const in_low,
|
|
||||||
const T* const in_high,
|
|
||||||
const T* const out_low,
|
|
||||||
const T* const out_high,
|
|
||||||
T* const out,
|
|
||||||
const Shape& arg_shape,
|
|
||||||
const Shape& in_low_shape,
|
|
||||||
const Shape& in_high_shape,
|
|
||||||
const Shape& out_low_shape,
|
|
||||||
const Shape& out_high_shape,
|
|
||||||
size_t levels,
|
|
||||||
const op::AutoBroadcastSpec& broadcast)
|
|
||||||
{
|
|
||||||
using namespace fake_quantize_details;
|
|
||||||
|
|
||||||
const FeRound round_mode(FE_TONEAREST);
|
|
||||||
|
|
||||||
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
|
|
||||||
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
|
|
||||||
{
|
|
||||||
const size_t arg_size = shape_size(arg_shape);
|
|
||||||
const auto q = [=](const T& a) {
|
|
||||||
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
|
|
||||||
};
|
|
||||||
for (size_t i = 0; i < arg_size; ++i)
|
|
||||||
{
|
|
||||||
out[i] = q(arg[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
|
|
||||||
in_high_shape.size() <= arg_shape.size() &&
|
|
||||||
out_low_shape.size() <= arg_shape.size() &&
|
|
||||||
out_high_shape.size() <= arg_shape.size(),
|
|
||||||
"Tensors with inout\\output ranges should have rank less or "
|
|
||||||
"equal to data tensor rank equal to ",
|
|
||||||
arg_shape.size());
|
|
||||||
|
|
||||||
const QuantizationBound<T> in_low_bound(
|
|
||||||
in_low, in_low_shape, arg_shape, broadcast);
|
|
||||||
const QuantizationBound<T> in_high_bound(
|
|
||||||
in_high, in_high_shape, arg_shape, broadcast);
|
|
||||||
const QuantizationBound<T> out_low_bound(
|
|
||||||
out_low, out_low_shape, arg_shape, broadcast);
|
|
||||||
const QuantizationBound<T> out_high_bound(
|
|
||||||
out_high, out_high_shape, arg_shape, broadcast);
|
|
||||||
|
|
||||||
std::vector<size_t> current_dim(arg_shape.size(), 0);
|
|
||||||
const auto arg_shape_size = shape_size(arg_shape);
|
|
||||||
for (size_t index = 0; index < arg_shape_size; ++index)
|
|
||||||
{
|
|
||||||
const T in_low_val = in_low_bound.get_value(current_dim, index);
|
|
||||||
const T in_high_val = in_high_bound.get_value(current_dim, index);
|
|
||||||
const T out_low_val = out_low_bound.get_value(current_dim, index);
|
|
||||||
const T out_high_val = out_high_bound.get_value(current_dim, index);
|
|
||||||
|
|
||||||
out[index] = quantize(arg[index],
|
|
||||||
in_low_val,
|
|
||||||
in_high_val,
|
|
||||||
out_low_val,
|
|
||||||
out_high_val,
|
|
||||||
levels);
|
|
||||||
increment_current_dim(current_dim, arg_shape);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace v1
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void fake_quantize(const T* const arg,
|
void
|
||||||
const T* const in_low,
|
fake_quantize(const T* const arg,
|
||||||
const T* const in_high,
|
const T* const in_low,
|
||||||
const T* const out_low,
|
const T* const in_high,
|
||||||
const T* const out_high,
|
const T* const out_low,
|
||||||
T* const out,
|
const T* const out_high,
|
||||||
const Shape& arg_shape,
|
T* const out,
|
||||||
const Shape& in_low_shape,
|
const Shape& arg_shape,
|
||||||
const Shape& in_high_shape,
|
const Shape& in_low_shape,
|
||||||
const Shape& out_low_shape,
|
const Shape& in_high_shape,
|
||||||
const Shape& out_high_shape,
|
const Shape& out_low_shape,
|
||||||
size_t levels)
|
const Shape& out_high_shape,
|
||||||
|
size_t levels,
|
||||||
|
const op::AutoBroadcastSpec& broadcast = op::AutoBroadcastType::NUMPY)
|
||||||
{
|
{
|
||||||
v1::fake_quantize(arg,
|
using namespace fake_quantize_details;
|
||||||
in_low,
|
|
||||||
in_high,
|
const FeRound round_mode(FE_TONEAREST);
|
||||||
out_low,
|
|
||||||
out_high,
|
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
|
||||||
out,
|
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
|
||||||
arg_shape,
|
{
|
||||||
in_low_shape,
|
const size_t arg_size = shape_size(arg_shape);
|
||||||
in_high_shape,
|
const auto q = [=](const T& a) {
|
||||||
out_low_shape,
|
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
|
||||||
out_high_shape,
|
};
|
||||||
levels,
|
for (size_t i = 0; i < arg_size; ++i)
|
||||||
op::AutoBroadcastType::NUMPY);
|
{
|
||||||
|
out[i] = q(arg[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
|
||||||
|
in_high_shape.size() <= arg_shape.size() &&
|
||||||
|
out_low_shape.size() <= arg_shape.size() &&
|
||||||
|
out_high_shape.size() <= arg_shape.size(),
|
||||||
|
"Tensors with inout\\output ranges should have rank less or "
|
||||||
|
"equal to data tensor rank equal to ",
|
||||||
|
arg_shape.size());
|
||||||
|
|
||||||
|
const QuantizationBound<T> in_low_bound(
|
||||||
|
in_low, in_low_shape, arg_shape, broadcast);
|
||||||
|
const QuantizationBound<T> in_high_bound(
|
||||||
|
in_high, in_high_shape, arg_shape, broadcast);
|
||||||
|
const QuantizationBound<T> out_low_bound(
|
||||||
|
out_low, out_low_shape, arg_shape, broadcast);
|
||||||
|
const QuantizationBound<T> out_high_bound(
|
||||||
|
out_high, out_high_shape, arg_shape, broadcast);
|
||||||
|
|
||||||
|
std::vector<size_t> current_dim(arg_shape.size(), 0);
|
||||||
|
const auto arg_shape_size = shape_size(arg_shape);
|
||||||
|
for (size_t index = 0; index < arg_shape_size; ++index)
|
||||||
|
{
|
||||||
|
const T in_low_val = in_low_bound.get_value(current_dim, index);
|
||||||
|
const T in_high_val = in_high_bound.get_value(current_dim, index);
|
||||||
|
const T out_low_val = out_low_bound.get_value(current_dim, index);
|
||||||
|
const T out_high_val = out_high_bound.get_value(current_dim, index);
|
||||||
|
|
||||||
|
out[index] = quantize(
|
||||||
|
arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels);
|
||||||
|
increment_current_dim(current_dim, arg_shape);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // namespace reference
|
} // namespace reference
|
||||||
} // namespace runtime
|
} // namespace runtime
|
||||||
|
@ -1173,7 +1173,7 @@ namespace
|
|||||||
info.selected_outputs_shape,
|
info.selected_outputs_shape,
|
||||||
selected_indices.data(),
|
selected_indices.data(),
|
||||||
info.selected_indices_shape,
|
info.selected_indices_shape,
|
||||||
valid_outputs.data());
|
valid_outputs.data());
|
||||||
|
|
||||||
void* pscores = nullptr;
|
void* pscores = nullptr;
|
||||||
void* pselected_num = nullptr;
|
void* pselected_num = nullptr;
|
||||||
@ -2383,19 +2383,19 @@ namespace
|
|||||||
const HostTensorVector& inputs)
|
const HostTensorVector& inputs)
|
||||||
{
|
{
|
||||||
using T = typename element_type_traits<ET>::value_type;
|
using T = typename element_type_traits<ET>::value_type;
|
||||||
runtime::reference::v1::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
|
runtime::reference::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
|
||||||
inputs[1]->get_data_ptr<const T>(),
|
inputs[1]->get_data_ptr<const T>(),
|
||||||
inputs[2]->get_data_ptr<const T>(),
|
inputs[2]->get_data_ptr<const T>(),
|
||||||
inputs[3]->get_data_ptr<const T>(),
|
inputs[3]->get_data_ptr<const T>(),
|
||||||
inputs[4]->get_data_ptr<const T>(),
|
inputs[4]->get_data_ptr<const T>(),
|
||||||
outputs[0]->get_data_ptr<T>(),
|
outputs[0]->get_data_ptr<T>(),
|
||||||
op->get_input_shape(0),
|
op->get_input_shape(0),
|
||||||
op->get_input_shape(1),
|
op->get_input_shape(1),
|
||||||
op->get_input_shape(2),
|
op->get_input_shape(2),
|
||||||
op->get_input_shape(3),
|
op->get_input_shape(3),
|
||||||
op->get_input_shape(4),
|
op->get_input_shape(4),
|
||||||
op->get_levels(),
|
op->get_levels(),
|
||||||
op->get_auto_broadcast());
|
op->get_auto_broadcast());
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user