remove v0 namespace from reference implementation of fake quantize. (#6977)

* remove `v0` namespace from reference implementation of fake quantize.

* fix ngraph check message
This commit is contained in:
Patryk Elszkowski 2021-08-11 12:06:25 +02:00 committed by GitHub
parent 289df8db27
commit 51d511c8ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 80 deletions

View File

@ -194,76 +194,70 @@ namespace ngraph
}
} // namespace fake_quantize_details
inline namespace v0
template <typename T>
void fake_quantize(const T* const arg,
const T* const in_low,
const T* const in_high,
const T* const out_low,
const T* const out_high,
T* const out,
const Shape& arg_shape,
const Shape& in_low_shape,
const Shape& in_high_shape,
const Shape& out_low_shape,
const Shape& out_high_shape,
size_t levels,
const op::AutoBroadcastSpec& broadcast)
{
template <typename T>
void fake_quantize(const T* const arg,
const T* const in_low,
const T* const in_high,
const T* const out_low,
const T* const out_high,
T* const out,
const Shape& arg_shape,
const Shape& in_low_shape,
const Shape& in_high_shape,
const Shape& out_low_shape,
const Shape& out_high_shape,
size_t levels,
const op::AutoBroadcastSpec& broadcast)
using namespace fake_quantize_details;
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
{
using namespace fake_quantize_details;
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
const size_t arg_size = shape_size(arg_shape);
const auto q = [=](const T& a) {
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
};
for (size_t i = 0; i < arg_size; ++i)
{
const size_t arg_size = shape_size(arg_shape);
const auto q = [=](const T& a) {
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
};
for (size_t i = 0; i < arg_size; ++i)
{
out[i] = q(arg[i]);
}
}
else
{
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
in_high_shape.size() <= arg_shape.size() &&
out_low_shape.size() <= arg_shape.size() &&
out_high_shape.size() <= arg_shape.size(),
"Tensors with inout\\output ranges should have rank less or "
"equal to data tensor rank equal to ",
arg_shape.size());
const QuantizationBound<T> in_low_bound(
in_low, in_low_shape, arg_shape, broadcast);
const QuantizationBound<T> in_high_bound(
in_high, in_high_shape, arg_shape, broadcast);
const QuantizationBound<T> out_low_bound(
out_low, out_low_shape, arg_shape, broadcast);
const QuantizationBound<T> out_high_bound(
out_high, out_high_shape, arg_shape, broadcast);
std::vector<size_t> current_dim(arg_shape.size(), 0);
const auto arg_shape_size = shape_size(arg_shape);
for (size_t index = 0; index < arg_shape_size; ++index)
{
const T in_low_val = in_low_bound.get_value(current_dim, index);
const T in_high_val = in_high_bound.get_value(current_dim, index);
const T out_low_val = out_low_bound.get_value(current_dim, index);
const T out_high_val = out_high_bound.get_value(current_dim, index);
out[index] = quantize(arg[index],
in_low_val,
in_high_val,
out_low_val,
out_high_val,
levels);
increment_current_dim(current_dim, arg_shape);
}
out[i] = q(arg[i]);
}
}
} // namespace v0
} // namespace reference
} // namespace runtime
else
{
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
in_high_shape.size() <= arg_shape.size() &&
out_low_shape.size() <= arg_shape.size() &&
out_high_shape.size() <= arg_shape.size(),
"Tensors with input\\output ranges should have rank less or "
"equal to data tensor rank equal to ",
arg_shape.size());
const QuantizationBound<T> in_low_bound(
in_low, in_low_shape, arg_shape, broadcast);
const QuantizationBound<T> in_high_bound(
in_high, in_high_shape, arg_shape, broadcast);
const QuantizationBound<T> out_low_bound(
out_low, out_low_shape, arg_shape, broadcast);
const QuantizationBound<T> out_high_bound(
out_high, out_high_shape, arg_shape, broadcast);
std::vector<size_t> current_dim(arg_shape.size(), 0);
const auto arg_shape_size = shape_size(arg_shape);
for (size_t index = 0; index < arg_shape_size; ++index)
{
const T in_low_val = in_low_bound.get_value(current_dim, index);
const T in_high_val = in_high_bound.get_value(current_dim, index);
const T out_low_val = out_low_bound.get_value(current_dim, index);
const T out_high_val = out_high_bound.get_value(current_dim, index);
out[index] = quantize(
arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels);
increment_current_dim(current_dim, arg_shape);
}
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph

View File

@ -2439,19 +2439,19 @@ namespace
const HostTensorVector& inputs)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::v0::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
inputs[1]->get_data_ptr<const T>(),
inputs[2]->get_data_ptr<const T>(),
inputs[3]->get_data_ptr<const T>(),
inputs[4]->get_data_ptr<const T>(),
outputs[0]->get_data_ptr<T>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_input_shape(2),
op->get_input_shape(3),
op->get_input_shape(4),
op->get_levels(),
op->get_auto_broadcast());
runtime::reference::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
inputs[1]->get_data_ptr<const T>(),
inputs[2]->get_data_ptr<const T>(),
inputs[3]->get_data_ptr<const T>(),
inputs[4]->get_data_ptr<const T>(),
outputs[0]->get_data_ptr<T>(),
op->get_input_shape(0),
op->get_input_shape(1),
op->get_input_shape(2),
op->get_input_shape(3),
op->get_input_shape(4),
op->get_levels(),
op->get_auto_broadcast());
return true;
}