remove v0
namespace from reference implementation of fake quantize. (#6977)
* remove `v0` namespace from reference implementation of fake quantize. * fix ngraph check message
This commit is contained in:
parent
289df8db27
commit
51d511c8ac
@ -194,76 +194,70 @@ namespace ngraph
|
||||
}
|
||||
|
||||
} // namespace fake_quantize_details
|
||||
inline namespace v0
|
||||
|
||||
template <typename T>
|
||||
void fake_quantize(const T* const arg,
|
||||
const T* const in_low,
|
||||
const T* const in_high,
|
||||
const T* const out_low,
|
||||
const T* const out_high,
|
||||
T* const out,
|
||||
const Shape& arg_shape,
|
||||
const Shape& in_low_shape,
|
||||
const Shape& in_high_shape,
|
||||
const Shape& out_low_shape,
|
||||
const Shape& out_high_shape,
|
||||
size_t levels,
|
||||
const op::AutoBroadcastSpec& broadcast)
|
||||
{
|
||||
template <typename T>
|
||||
void fake_quantize(const T* const arg,
|
||||
const T* const in_low,
|
||||
const T* const in_high,
|
||||
const T* const out_low,
|
||||
const T* const out_high,
|
||||
T* const out,
|
||||
const Shape& arg_shape,
|
||||
const Shape& in_low_shape,
|
||||
const Shape& in_high_shape,
|
||||
const Shape& out_low_shape,
|
||||
const Shape& out_high_shape,
|
||||
size_t levels,
|
||||
const op::AutoBroadcastSpec& broadcast)
|
||||
using namespace fake_quantize_details;
|
||||
|
||||
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
|
||||
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
|
||||
{
|
||||
using namespace fake_quantize_details;
|
||||
|
||||
if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 &&
|
||||
shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1)
|
||||
const size_t arg_size = shape_size(arg_shape);
|
||||
const auto q = [=](const T& a) {
|
||||
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
|
||||
};
|
||||
for (size_t i = 0; i < arg_size; ++i)
|
||||
{
|
||||
const size_t arg_size = shape_size(arg_shape);
|
||||
const auto q = [=](const T& a) {
|
||||
return quantize(a, *in_low, *in_high, *out_low, *out_high, levels);
|
||||
};
|
||||
for (size_t i = 0; i < arg_size; ++i)
|
||||
{
|
||||
out[i] = q(arg[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
|
||||
in_high_shape.size() <= arg_shape.size() &&
|
||||
out_low_shape.size() <= arg_shape.size() &&
|
||||
out_high_shape.size() <= arg_shape.size(),
|
||||
"Tensors with inout\\output ranges should have rank less or "
|
||||
"equal to data tensor rank equal to ",
|
||||
arg_shape.size());
|
||||
|
||||
const QuantizationBound<T> in_low_bound(
|
||||
in_low, in_low_shape, arg_shape, broadcast);
|
||||
const QuantizationBound<T> in_high_bound(
|
||||
in_high, in_high_shape, arg_shape, broadcast);
|
||||
const QuantizationBound<T> out_low_bound(
|
||||
out_low, out_low_shape, arg_shape, broadcast);
|
||||
const QuantizationBound<T> out_high_bound(
|
||||
out_high, out_high_shape, arg_shape, broadcast);
|
||||
|
||||
std::vector<size_t> current_dim(arg_shape.size(), 0);
|
||||
const auto arg_shape_size = shape_size(arg_shape);
|
||||
for (size_t index = 0; index < arg_shape_size; ++index)
|
||||
{
|
||||
const T in_low_val = in_low_bound.get_value(current_dim, index);
|
||||
const T in_high_val = in_high_bound.get_value(current_dim, index);
|
||||
const T out_low_val = out_low_bound.get_value(current_dim, index);
|
||||
const T out_high_val = out_high_bound.get_value(current_dim, index);
|
||||
|
||||
out[index] = quantize(arg[index],
|
||||
in_low_val,
|
||||
in_high_val,
|
||||
out_low_val,
|
||||
out_high_val,
|
||||
levels);
|
||||
increment_current_dim(current_dim, arg_shape);
|
||||
}
|
||||
out[i] = q(arg[i]);
|
||||
}
|
||||
}
|
||||
} // namespace v0
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() &&
|
||||
in_high_shape.size() <= arg_shape.size() &&
|
||||
out_low_shape.size() <= arg_shape.size() &&
|
||||
out_high_shape.size() <= arg_shape.size(),
|
||||
"Tensors with input\\output ranges should have rank less or "
|
||||
"equal to data tensor rank equal to ",
|
||||
arg_shape.size());
|
||||
|
||||
const QuantizationBound<T> in_low_bound(
|
||||
in_low, in_low_shape, arg_shape, broadcast);
|
||||
const QuantizationBound<T> in_high_bound(
|
||||
in_high, in_high_shape, arg_shape, broadcast);
|
||||
const QuantizationBound<T> out_low_bound(
|
||||
out_low, out_low_shape, arg_shape, broadcast);
|
||||
const QuantizationBound<T> out_high_bound(
|
||||
out_high, out_high_shape, arg_shape, broadcast);
|
||||
|
||||
std::vector<size_t> current_dim(arg_shape.size(), 0);
|
||||
const auto arg_shape_size = shape_size(arg_shape);
|
||||
for (size_t index = 0; index < arg_shape_size; ++index)
|
||||
{
|
||||
const T in_low_val = in_low_bound.get_value(current_dim, index);
|
||||
const T in_high_val = in_high_bound.get_value(current_dim, index);
|
||||
const T out_low_val = out_low_bound.get_value(current_dim, index);
|
||||
const T out_high_val = out_high_bound.get_value(current_dim, index);
|
||||
|
||||
out[index] = quantize(
|
||||
arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels);
|
||||
increment_current_dim(current_dim, arg_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
||||
|
@ -2439,19 +2439,19 @@ namespace
|
||||
const HostTensorVector& inputs)
|
||||
{
|
||||
using T = typename element_type_traits<ET>::value_type;
|
||||
runtime::reference::v0::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
|
||||
inputs[1]->get_data_ptr<const T>(),
|
||||
inputs[2]->get_data_ptr<const T>(),
|
||||
inputs[3]->get_data_ptr<const T>(),
|
||||
inputs[4]->get_data_ptr<const T>(),
|
||||
outputs[0]->get_data_ptr<T>(),
|
||||
op->get_input_shape(0),
|
||||
op->get_input_shape(1),
|
||||
op->get_input_shape(2),
|
||||
op->get_input_shape(3),
|
||||
op->get_input_shape(4),
|
||||
op->get_levels(),
|
||||
op->get_auto_broadcast());
|
||||
runtime::reference::fake_quantize<T>(inputs[0]->get_data_ptr<const T>(),
|
||||
inputs[1]->get_data_ptr<const T>(),
|
||||
inputs[2]->get_data_ptr<const T>(),
|
||||
inputs[3]->get_data_ptr<const T>(),
|
||||
inputs[4]->get_data_ptr<const T>(),
|
||||
outputs[0]->get_data_ptr<T>(),
|
||||
op->get_input_shape(0),
|
||||
op->get_input_shape(1),
|
||||
op->get_input_shape(2),
|
||||
op->get_input_shape(3),
|
||||
op->get_input_shape(4),
|
||||
op->get_levels(),
|
||||
op->get_auto_broadcast());
|
||||
return true;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user