[GNA] Support in GNA plugin for power layer with non-1 exponents (#997)

* added support for power layer with non-1 exponents to GNA plugin

* reverted a change caused by merge issue

* fixes for review comments (typo fix - lrelu instead of leru, unnamed structure instead of of named one in union with arguments of activation function, name fix - input instead of inputs),

scale-shift implementation based on affine layer instead of PWL,

* fixed code style

* fixes for coding style in scale_factor_calc.hpp

* added domain for power function

* fixed review comment - power function specific methods

* added check if dynamic casting was successful

* removed I16 as it is not supported by ngraph

* fixed initialization per review comment
This commit is contained in:
Bartosz Sochacki 2020-07-10 12:39:29 +02:00 committed by GitHub
parent d9706da8d0
commit 8da662b2b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 658 additions and 92 deletions

View File

@ -460,6 +460,10 @@ void GNAPluginNS::backend::AMIntelDNN::WriteGraphWizModel(const char *filename)
#define IS_DIAG(k)\
(components[k].operation == kDnnDiagonalOp)
#define IS_POW(k)\
(components[k].operation == kDnnPiecewiselinearOp &&\
components[k].op.pwl.func_id == kActPow)
#define OUTPUTS(idx)\
components[idx].ptr_outputs, components[idx].num_rows_out*components[idx].num_columns_out * components[idx].num_bytes_per_output
@ -531,7 +535,12 @@ void GNAPluginNS::backend::AMIntelDNN::WriteGraphWizModel(const char *filename)
graph << " <TR><TD> badr</TD><TD>" << components[k].op.affine.ptr_biases<< "</TD></TR>\n";
}
if (IS_RELU(k)) {
graph << " <TR><TD> negative_slope</TD><TD>" << components[k].op.pwl.func_id.negative_slope<< "</TD></TR>\n";
graph << " <TR><TD> negative_slope</TD><TD>" << components[k].op.pwl.func_id.args.lrelu.negative_slope<< "</TD></TR>\n";
}
if (IS_POW(k)) {
graph << " <TR><TD> exponent</TD><TD>" << components[k].op.pwl.func_id.args.pow.exponent << "</TD></TR>\n";
graph << " <TR><TD> scale</TD><TD>" << components[k].op.pwl.func_id.args.pow.scale << "</TD></TR>\n";
graph << " <TR><TD> offset</TD><TD>" << components[k].op.pwl.func_id.args.pow.offset << "</TD></TR>\n";
}
if (IS_CONV(k)) {
auto &conv = components[k].op.conv1D;

View File

@ -26,20 +26,33 @@ enum DnnActivationType : uint8_t {
kActNegLog,
kActNegHalfLog,
kActSoftSign,
kActPow,
kActNumType
};
struct DnnActivation {
// for prelu
DnnActivationType type;
float negative_slope;
union {
struct {
float negative_slope;
} lrelu;
struct {
float exponent;
float scale;
float offset;
} pow;
struct {
float reserved[3];
};
} args;
operator DnnActivationType () const noexcept {
return type;
}
static DnnActivation fromType(DnnActivationType type) {
DnnActivation activation;
activation.type = type;
activation.negative_slope = 0.0f;
activation.args = {};
return activation;
}
};
@ -61,7 +74,8 @@ static const char *intel_dnn_activation_name[kActNumType] = {
"kActNegLog",
"kActNegHalfLog",
"kActCustom",
"kActSoftSign"
"kActSoftSign",
"kActPow"
};
typedef enum DnnSoftmaxType {

View File

@ -224,8 +224,8 @@ void make_gna_pwl(const DnnActivation fun,
int16_t y_lower = INT16_MIN;
if (x_lower < y_lower * in_scale / out_scale) x_lower = FLOAT_TO_INT32(y_lower * in_scale / out_scale);
if (y_lower < x_lower * out_scale / in_scale) y_lower = FLOAT_TO_INT16(x_lower * out_scale / in_scale);
gna_pwl[0].yBase = y_lower * fun.negative_slope;
s = gna_slope(fun.negative_slope, in_scale, out_scale);
gna_pwl[0].yBase = y_lower * fun.args.lrelu.negative_slope;
s = gna_slope(fun.args.lrelu.negative_slope, in_scale, out_scale);
gna_pwl[0].xBase = (x_lower & XBASEMASK) | s.slope_scale_index; // zero out the 2 lsb
gna_pwl[0].slope = FLOAT_TO_INT16(s.slope * s.slope_scale);
@ -383,6 +383,162 @@ void make_gna_pwl(const DnnActivation fun,
<< "\n";
break;
}
case kActPow: {
float pow_exponent = fun.args.pow.exponent;
if (pow_exponent == 0.0f || pow_exponent == 1.0f) {
float pow_scale = fun.args.pow.scale;
float pow_offset = fun.args.pow.offset;
int32_t x_lower = INT32_MIN;
int32_t x_upper = INT32_MAX;
int16_t y_lower = INT16_MIN;
int16_t y_upper = INT16_MAX;
auto n_segments = 2;
if (pow_exponent == 0.0f) {
y_lower = y_upper = FLOAT_TO_INT16(1 * out_scale);
} else if (pow_exponent == 1.0f) {
if (x_lower < y_lower * in_scale / out_scale)
x_lower = FLOAT_TO_INT32(y_lower * in_scale / out_scale);
if (x_upper > y_upper * in_scale / out_scale)
x_upper = FLOAT_TO_INT32(y_upper * in_scale / out_scale);
if (y_lower < x_lower * out_scale / in_scale)
y_lower = FLOAT_TO_INT16(x_lower * out_scale / in_scale);
if (y_upper > x_upper * out_scale / in_scale)
y_upper = FLOAT_TO_INT16(x_upper * out_scale / in_scale);
if (pow_scale < 1) {
int16_t tmp = y_lower;
y_lower = y_upper;
y_upper = tmp;
}
int64_t x_lower_new = FLOAT_TO_INT32((x_lower / in_scale) / abs(pow_scale) * in_scale);
int64_t x_upper_new = FLOAT_TO_INT32((x_upper / in_scale) / abs(pow_scale) * in_scale);
x_lower = static_cast<int32_t>(x_lower_new);
x_upper = static_cast<int32_t>(x_upper_new);
if (x_lower_new < INT32_MIN) {
int16_t offset_lower = abs(x_lower_new - INT32_MIN) / in_scale * out_scale;
x_lower = INT32_MIN;
y_lower = y_lower + offset_lower;
}
if (x_upper_new > INT32_MAX) {
int16_t offset_upper = (x_upper_new - INT32_MAX) / in_scale * out_scale;
x_upper = INT32_MAX;
y_upper = y_upper - offset_upper;
}
int32_t y_lower_new = FLOAT_TO_INT32((y_lower / out_scale + pow_offset) * out_scale);
int32_t y_upper_new = FLOAT_TO_INT32((y_upper / out_scale + pow_offset) * out_scale);
y_lower = static_cast<int16_t>(y_lower_new);
y_upper = static_cast<int16_t>(y_upper_new);
if (y_lower_new < INT16_MIN) {
int32_t offset_lower = abs(y_lower_new - INT16_MIN) / out_scale * in_scale;
y_lower = INT16_MIN;
x_lower = x_lower + offset_lower;
}
if (y_lower_new > INT16_MAX) {
int32_t offset_lower = (y_lower_new - INT16_MAX) / out_scale * in_scale;
y_lower = INT16_MAX;
x_upper = x_upper + offset_lower;
}
if (y_upper_new > INT16_MAX) {
int32_t offset_upper = (y_upper_new - INT16_MAX) / out_scale * in_scale;
y_upper = INT16_MAX;
x_upper = x_upper - offset_upper;
}
if (y_upper_new < INT16_MIN) {
int32_t offset_upper = abs(y_upper_new - INT16_MAX) / out_scale * in_scale;
y_upper = INT16_MIN;
x_lower = x_lower - offset_upper;
}
}
gna_pwl.resize(n_segments);
gna_pwl[0].xBase = INT32_MIN & XBASEMASK; // zero out the 2 lsb
gna_pwl[0].yBase = y_lower;
gna_pwl[0].slope = 0;
gnalog() << gna_pwl[0].xBase / in_scale
<< " " << gna_pwl[0].yBase / out_scale
<< " " << 0
<< "\n";
gna_pwl[1].xBase = x_lower & XBASEMASK; // zero out the 2 lsb
gna_pwl[1].yBase = y_lower;
double slope = (static_cast<double>(y_upper - y_lower) / out_scale) / (static_cast<double>(x_upper - x_lower) / in_scale);
s = gna_slope(slope, in_scale, out_scale);
gna_pwl[1].slope = FLOAT_TO_INT16(s.slope * s.slope_scale);
gna_pwl[1].xBase = gna_pwl[1].xBase | s.slope_scale_index;
gnalog() << (int32_t)(gna_pwl[1].xBase & XBASEMASK) / in_scale
<< " " << gna_pwl[1].yBase / out_scale
<< " " << 1.0
<< "\n";
if (INT32_MAX > x_upper) { // need a right segment
gna_pwl.push_back({
static_cast<int32_t>(x_upper & XBASEMASK), // zero out the 2 lsb
y_upper,
0 });
gnalog() << (x_upper & XBASEMASK) / in_scale
<< " " << gna_pwl[2].yBase / out_scale
<< " " << 0
<< "\n";
}
} else {
auto n_segments = static_cast<int32_t> (pwl_size) + 1;
gna_pwl.resize(n_segments);
// insert extra segment for x values < l_bound
gna_pwl[0].xBase = static_cast<int32_t> (INT32_MIN & XBASEMASK); // zero out the 2 lsb
gnalog() << "=========================== Exp Segments ===========================\n";
gna_pwl[0].yBase = gna_pwl[1].yBase = 0;
gna_pwl[1].xBase = (static_cast<int32_t> (in_scale * (-pwl[0].b / pwl[0].m))) & XBASEMASK;
gna_pwl[0].slope = 0;
gnalog() << (gna_pwl[0].xBase) / in_scale
<< " " << (gna_pwl[0].yBase) / out_scale
<< " " << 0.0
<< "\n";
s = gna_slope(pwl[0].m, in_scale, out_scale);
gna_pwl[1].slope = FLOAT_TO_INT16(s.slope * s.slope_scale);
gna_pwl[1].xBase = gna_pwl[1].xBase | s.slope_scale_index;
gnalog() << ((int32_t)(gna_pwl[1].xBase & XBASEMASK) / in_scale)
<< " " << (gna_pwl[1].yBase) / out_scale
<< " " << pwl[0].m
<< "\n";
for (uint32_t i = 1; i < pwl_size - 1; ++i) {
s = gna_slope(pwl[i].m, in_scale, out_scale);
gna_pwl[i + 1].xBase = (static_cast<int32_t> (in_scale * pwl[i].alpha)) & XBASEMASK;
gna_pwl[i + 1].yBase = FLOAT_TO_INT16(pwl[i].beta * out_scale);
gna_pwl[i + 1].slope = FLOAT_TO_INT16(s.slope * s.slope_scale);
gna_pwl[i + 1].xBase = gna_pwl[i + 1].xBase | s.slope_scale_index;
gnalog() << (pwl[i].alpha)
<< " " << pwl[i].beta
<< " " << pwl[i].m
<< "\n";
}
// insert extra segment for xvalues > u_bound
gna_pwl[n_segments - 1].xBase =
((uint32_t)(in_scale * (INT16_MAX / out_scale - pwl[pwl_size - 2].b) / pwl[pwl_size - 2].m)) & XBASEMASK;
gna_pwl[n_segments - 1].yBase = INT16_MAX;
gna_pwl[n_segments - 1].slope = 0;
gnalog() << (gna_pwl[n_segments - 1].xBase / in_scale)
<< " " << 1.0
<< " " << 0.0
<< "\n";
break;
}
break;
}
default:
gnalog() << "Unexpected function activation!\n";
std::cerr << "Unexpected function activation!\n";

View File

@ -56,13 +56,16 @@ class ScaleFactorPerLayer<InferenceEngine::CNNLayer *> {
const float identity_scale_factor = 2049.0f;
const float k = 5;
const float k_identity = 6;
const double pow_domain = 16;
protected :
static bool fp32eq(float p1, float p2) {
return (std::abs(p1 - p2) <= 0.00001f * std::min(std::abs(p1), std::abs(p2)));
}
float getActivationScale(GNAPluginNS::LayerInfo const& layer, QuantizedLayerParams const* quantizedParams) {
float getActivationScale(InferenceEngine::CNNLayer const* cnnLayer,
GNAPluginNS::LayerInfo const& layer,
QuantizedLayerParams const* quantizedParams) {
// todo: calculate proper scale factor where we need to expand it a bit to be safe to stay in int16 weights
// set the initial value
float result = activation_scale_factor;
@ -105,6 +108,31 @@ class ScaleFactorPerLayer<InferenceEngine::CNNLayer *> {
> std::numeric_limits<int32_t>::max()-1) {
// if activation is one from relu family, we need to apply heuristic to avoid activation output overflow
result = (activation_scale_factor * 0.5);
} else if (layer.isPower()) {
auto powerLayer = dynamic_cast<InferenceEngine::PowerLayer const*>(cnnLayer);
if (!powerLayer) {
THROW_IE_EXCEPTION << "Incorrect Power Layer pointer \n";
}
auto input_min_value = static_cast<double>(std::numeric_limits<int32_t>::min());
auto input_max_value = static_cast<double>(std::numeric_limits<int32_t>::max());
auto output_max_value = static_cast<double>(std::numeric_limits<int16_t>::max());
auto x_min = fp32eq(fmod(powerLayer->power, 1.0), 0) ? input_min_value / quantizedParams->_src_quant.scale : 0.0;
x_min = std::max(x_min, -pow_domain);
auto x_max = input_max_value / quantizedParams->_src_quant.scale;
x_max = std::min(x_max, pow_domain);
auto val1 = pow(x_min * powerLayer->scale + powerLayer->offset, powerLayer->power);
auto val2 = pow(x_max * powerLayer->scale + powerLayer->offset, powerLayer->power);
auto abs_val = std::max(std::abs(val1), std::abs(val2));
auto scale_val = output_max_value / abs_val;
if (!std::isinf(scale_val)) {
result = scale_val;
}
}
return result;
}
@ -251,7 +279,7 @@ class ScaleFactorPerLayer<InferenceEngine::CNNLayer *> {
if (layerInfo.isActivation()) {
// todo: calculate proper scale factor where we need to expand it a bit to be safe to stay in int16 weights
// set the initial value
quant->_dst_quant.scale = getActivationScale(layerInfo, quant);
quant->_dst_quant.scale = getActivationScale(cnnLayer, layerInfo, quant);
}
return true;
}

View File

@ -338,8 +338,8 @@ void GNAGraphCompiler::PowerPrimitive(InferenceEngine::CNNLayerPtr layer) {
auto& power = dynamic_cast<PowerLayer&>(*layer.get());
auto quantized = InferenceEngine::getInjectedData<QuantizedLayerParams>(layer);
if (power.power != 1.0) {
THROW_IE_EXCEPTION << "[GNA plugin] unsupported power factor, expected 1 but was " << power.power;
if (power.power < 0.0f || power.power > 2.8f) {
THROW_IE_EXCEPTION << "[GNA plugin] unsupported power factor, expected be in <0, 2.8> range but was " << power.power;
}
auto input = layer->insData[0].lock();
@ -351,49 +351,117 @@ void GNAGraphCompiler::PowerPrimitive(InferenceEngine::CNNLayerPtr layer) {
uint32_t num_rows_out = num_rows_in;
uint32_t num_padding = ALIGN(num_rows_in, 8) - num_rows_in;
void* ptr_inputs = nullptr;
void* ptr_outputs = nullptr;
void* ptr_weights = nullptr;
void* ptr_biases = nullptr;
auto& currentComponent = dnnComponents.addComponent(layer->name, "power");
dnn->InitAffineComponent(currentComponent,
num_rows_in + num_padding,
num_columns_in,
num_rows_out + num_padding,
input->getPrecision().size(),
outputs->getPrecision().size(),
// TODO: only fp32 and Int16 tested
quantized == nullptr ? input->getPrecision().size() : 2,
quantized == nullptr ? input->getPrecision().size() : 4,
quantized == nullptr ? 1 : quantized->_weights_quant.scale,
quantized == nullptr ? 1 : quantized->_dst_quant.scale,
ptr_inputs,
ptr_outputs,
ptr_weights,
ptr_biases,
true);
size_t num_data_bytes_out = InferenceEngine::details::product(begin(outputs->getDims()), end(outputs->getDims()))
* outputs->getPrecision().size();
size_t num_data_bytes_in = InferenceEngine::details::product(begin(input->getDims()), end(input->getDims()))
* input->getPrecision().size();
connectOutput(layer, ptr_outputs, num_data_bytes_out);
connectInput(layer, ptr_inputs, num_data_bytes_in, 0, 0);
if (power.power == 1.0f) {
void* ptr_inputs = nullptr;
void* ptr_outputs = nullptr;
void* ptr_weights = nullptr;
void* ptr_biases = nullptr;
if (gnaFlags->sw_fp32) {
gnamem->readonly().push_value(ptr_weights, power.scale, num_rows_out, 64);
gnamem->readonly().push_value(ptr_biases, power.offset, num_rows_out, 64);
auto& currentComponent = dnnComponents.addComponent(layer->name, "power");
dnn->InitAffineComponent(currentComponent,
num_rows_in + num_padding,
num_columns_in,
num_rows_out + num_padding,
input->getPrecision().size(),
outputs->getPrecision().size(),
// TODO: only fp32 and Int16 tested
quantized == nullptr ? input->getPrecision().size() : 2,
quantized == nullptr ? input->getPrecision().size() : 4,
quantized == nullptr ? 1 : quantized->_weights_quant.scale,
quantized == nullptr ? 1 : quantized->_dst_quant.scale,
ptr_inputs,
ptr_outputs,
ptr_weights,
ptr_biases,
true);
connectOutput(layer, ptr_outputs, num_data_bytes_out);
connectInput(layer, ptr_inputs, num_data_bytes_in, 0, 0);
if (gnaFlags->sw_fp32) {
gnamem->readonly().push_value(ptr_weights, power.scale, num_rows_out, 64);
gnamem->readonly().push_value(ptr_biases, power.offset, num_rows_out, 64);
} else {
auto quantizedScale = FLOAT_TO_INT16(std::min(quantized->_weights_quant.scale * power.scale,
static_cast<float>(INT16_MAX)));
auto quantizedOffset = FLOAT_TO_INT32(std::min(quantized->_dst_quant.scale * power.offset,
static_cast<float>(INT32_MAX)));
gnamem->readonly().push_value<int16_t>(ptr_weights, quantizedScale, num_rows_out, 64);
gnamem->readonly().push_value<int32_t>(ptr_biases, quantizedOffset, num_rows_out, 64);
}
} else {
auto quantizedScale = FLOAT_TO_INT16(std::min(quantized->_weights_quant.scale * power.scale,
static_cast<float>(INT16_MAX)));
auto quantizedOffset = FLOAT_TO_INT32(std::min(quantized->_dst_quant.scale * power.offset,
static_cast<float>(INT32_MAX)));
gnamem->readonly().push_value<int16_t>(ptr_weights, quantizedScale, num_rows_out, 64);
gnamem->readonly().push_value<int32_t>(ptr_biases, quantizedOffset, num_rows_out, 64);
//use PWL to calculate power
std::vector<intel_pwl_segment_t> ptr_pwl_segments;
auto orientation = kDnnInterleavedOrientation;
auto activation_type = DnnActivation::fromType(kActPow);
activation_type.args.pow.exponent = power.power;
activation_type.args.pow.scale = power.scale;
activation_type.args.pow.offset = power.offset;
auto& pwlComponent = dnnComponents.addComponent(layer->name, "power");
intel_pwl_segment_t* ptr_pwl_segments_target = nullptr;
float output_pwl_scale_factor = quantized != nullptr ? quantized->_dst_quant.scale : 1.0f;
float input_pwl_scale_factor = quantized != nullptr ? quantized->_src_quant.scale : 1.0f;
if (!gnaFlags->sw_fp32) {
if (gnaFlags->uniformPwlDesign) {
uint32_t num_segments = POW_NUM_SEGMENTS;
if (activation_type.args.pow.exponent == 0.0f || activation_type.args.pow.exponent == 1.0f) {
num_segments = 3;
}
ptr_pwl_segments.resize(num_segments);
PwlDesign16(activation_type,
&*ptr_pwl_segments.begin(),
static_cast<uint32_t>(ptr_pwl_segments.size()),
input_pwl_scale_factor,
output_pwl_scale_factor);
} else {
PwlDesignOpt16(activation_type,
ptr_pwl_segments,
input_pwl_scale_factor,
output_pwl_scale_factor);
}
}
ptr_pwl_segments_target = reinterpret_cast<intel_pwl_segment_t*>(&ptr_pwl_segments_target);
void* ptr_pwl_input = nullptr;
void* ptr_pwl_outputs = nullptr;
dnn->InitPiecewiseLinearComponent(pwlComponent,
activation_type,
orientation,
num_rows_in + num_padding,
num_columns_in,
input->getPrecision().size(),
outputs->getPrecision().size(),
ptr_pwl_segments.size(),
output_pwl_scale_factor,
output_pwl_scale_factor,
ptr_pwl_input,
ptr_pwl_outputs,
ptr_pwl_segments_target);
connectOutput(layer, ptr_pwl_outputs, num_data_bytes_out);
connectInput(layer, ptr_pwl_input, num_data_bytes_in, 0, 0);
if (ptr_pwl_segments_target != nullptr) {
gnamem->readonly().push_local_ptr(ptr_pwl_segments_target,
&ptr_pwl_segments.front(),
ptr_pwl_segments.size() * sizeof(intel_pwl_segment_t),
64);
}
}
}
@ -1300,9 +1368,9 @@ void GNAGraphCompiler::PWLPrimitive(InferenceEngine::CNNLayerPtr layer) {
auto activation_type = DnnActivation::fromType(it->second);
if (it->second == kActRelu) {
auto reluLayer = dynamic_cast<ReLULayer*>(layer.get());
activation_type.negative_slope = reluLayer != nullptr ? reluLayer->negative_slope : 0.0f;
activation_type.args.lrelu.negative_slope = reluLayer != nullptr ? reluLayer->negative_slope : 0.0f;
} else {
activation_type.negative_slope = 0.0f;
activation_type.args.lrelu.negative_slope = 0.0f;
}
string actName = "unknown";

View File

@ -96,7 +96,14 @@ class LayerInfo {
"abs",
"neglog",
"neghalflog",
"softsign"};
"softsign",
"power"};
if (isPower()) {
auto powerLayer = as<const InferenceEngine::PowerLayer*>();
return powerLayer != nullptr && powerLayer->power != 1.0f;
}
return activations.find(layer->type) != activations.end();
}

View File

@ -8,6 +8,7 @@
#include <iostream>
#include <limits>
#include <cstdint>
#include <algorithm>
#ifdef _NO_MKL_
#include <cmath>
@ -44,13 +45,24 @@ double relu(const double x) { if (x < 0) { return(0.0); } else { return(x); } }
double leaky_relu(const double x) { if (x < 0.0) { return(LEAKYRELU_SLOPE*x); } else { return(x); } }
double clipping(const double x, const double lbound, const double ubound) { return((x < lbound)?lbound:((x > ubound)?ubound:x)); }
double pivot_search(std::vector<pwl_t>& result, double(*f)(const double),
double(*first_deriv_f)(const double),
const uint32_t N,
const double alpha_0,
const double alpha_N,
const double threshold,
const bool negative) {
double first_deriv_power(const double x, const std::tuple<double, double, double>& args) {
//scale * exponent * (offset + scale * x)^(exponent - 1)
return (std::get<1>(args) * std::get<0>(args) * pow(std::get<2>(args) + std::get<1>(args) * x, std::get<0>(args) - 1));
}
double power(const double x, const std::tuple<double, double, double>& args) {
return (pow(std::get<2>(args) + std::get<1>(args) * x, std::get<0>(args)));
}
template <typename T1, typename T2>
double pivot_search(std::vector<pwl_t>& result,
T1 f,
T2 first_deriv_f,
const uint32_t N,
const double alpha_0,
const double alpha_N,
const double threshold,
const bool negative) {
std::vector<std::vector<double>> t(N + 1);
std::vector<std::vector<double>> alpha(N + 1);
std::vector<std::vector<double>> epsilon(N + 1);
@ -61,12 +73,11 @@ double pivot_search(std::vector<pwl_t>& result, double(*f)(const double),
double max_epsilon = 0.0;
double max_epsilon_prev;
double min_epsilon;
double min_epsilon2;
double sgn = (negative) ? -1.0 : 1.0;
int j;
if ( f == nullptr ||
first_deriv_f == nullptr ||
threshold < 0) {
if (threshold < 0) {
return epsilon_final;
}
// Figure 4: Box #1
@ -163,7 +174,27 @@ double pivot_search(std::vector<pwl_t>& result, double(*f)(const double),
}
}
double calculate_error_pct(const DnnActivationType fun,
double pivot_search(std::vector<pwl_t>& result, double(*f)(const double),
double(*first_deriv_f)(const double),
const uint32_t N,
const double alpha_0,
const double alpha_N,
const double threshold,
const bool negative) {
double epsilon_final = 0.0;
if (f == nullptr ||
first_deriv_f == nullptr ||
threshold < 0) {
return epsilon_final;
}
auto fun = [&f](double x) -> double { return f(x); };
auto first_deriv = [&first_deriv_f](double x) -> double { return first_deriv_f(x); };
return pivot_search(result, fun, first_deriv, N, alpha_0, alpha_N, threshold, negative);
}
double calculate_error_pct(const DnnActivation& activation_type,
const double l_bound,
const double u_bound,
const double offset,
@ -176,7 +207,7 @@ double calculate_error_pct(const DnnActivationType fun,
return 0.0;
}
switch (fun) {
switch (activation_type) {
case kActSigmoid:
min_val = max_val = sigmoid(l_bound);
break;
@ -198,6 +229,9 @@ double calculate_error_pct(const DnnActivationType fun,
case kActSoftSign:
min_val = max_val = softsign(l_bound);
break;
case kActPow:
min_val = max_val = pow(activation_type.args.pow.offset + activation_type.args.pow.scale * l_bound, activation_type.args.pow.exponent);
break;
default:
break;
}
@ -205,7 +239,7 @@ double calculate_error_pct(const DnnActivationType fun,
for (int i = 0; i < samples; i++) {
double arg = l_bound + i * delta;
double val = 0.0;
switch (fun) {
switch (activation_type) {
case kActSigmoid:
val = sigmoid(arg);
break;
@ -227,6 +261,9 @@ double calculate_error_pct(const DnnActivationType fun,
case kActNegHalfLog:
val = neghalflog(arg);
break;
case kActPow:
val = pow(activation_type.args.pow.offset + activation_type.args.pow.scale * arg, activation_type.args.pow.exponent);
break;
default:
break;
}
@ -237,22 +274,37 @@ double calculate_error_pct(const DnnActivationType fun,
return(100.0 * fabs(offset) / (max_val - min_val));
}
bool split_search(const DnnActivationType fun,
double get_break_bound(const DnnActivation& activation_type) {
double break_bound = 0.0;
switch (activation_type) {
case kActExp:
break_bound = EXP_BREAK;
break;
case kActPow:
break_bound = POW_BREAK;
break;
default:
break;
}
return break_bound;
}
bool split_search(const DnnActivation& activation_type,
const double l_bound,
const double u_bound) {
bool is_split = false;
if (l_bound > u_bound) {
return is_split;
}
double break_bound = get_break_bound(activation_type);
switch (fun) {
switch (activation_type) {
case kActSigmoid:
case kActTanh:
case kActSoftSign:
is_split = ((l_bound < 0.0) && (u_bound > 0.0));
break;
case kActExp:
is_split = ((l_bound < EXP_BREAK) && (u_bound > EXP_BREAK));
case kActPow:
is_split = ((l_bound < break_bound) && (u_bound > break_bound));
break;
default:
is_split = false;
@ -272,7 +324,7 @@ inline std::vector<pwl_t> negative_pwl(const std::vector<pwl_t>& pwl) {
return(new_pwl);
}
std::vector<pwl_t> pwl_search(const DnnActivationType fun,
std::vector<pwl_t> pwl_search(const DnnActivation& activation_type,
const double l_bound,
const double u_bound,
const double threshold,
@ -288,25 +340,24 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
return pwl;
}
if (split_search(fun, l_bound, u_bound)) {
if (split_search(activation_type, l_bound, u_bound)) {
std::vector<pwl_t> pwl2;
double err_pct1 = 0.0, err_pct2 = 0.0;
const double break_bound = (fun == kActExp ? EXP_BREAK : 0.0);
double break_bound = get_break_bound(activation_type);
pwl = pwl_search(fun, l_bound, break_bound, threshold, allowed_err_pct, samples, err_pct1);
pwl = pwl_search(activation_type, l_bound, break_bound, threshold, allowed_err_pct, samples, err_pct1);
pwl = negative_pwl(pwl);
pwl2 = pwl_search(fun, break_bound, u_bound, threshold, allowed_err_pct, samples, err_pct2);
pwl2 = pwl_search(activation_type, break_bound, u_bound, threshold, allowed_err_pct, samples, err_pct2);
if (fun == kActExp) {
if (activation_type == kActExp || activation_type == kActPow) {
pwl2 = negative_pwl(pwl2);
}
// merge
pwl.pop_back(); // remove final alpha and beta from first half
pwl.insert(pwl.end(), pwl2.begin(), pwl2.end()); // concatenate the two halves
err_pct = (err_pct1 + err_pct2) / 2; // this is not quite correct but should give an indication
} else {
if (fun == kActIdentity) {
if (activation_type == kActIdentity) {
pwl.resize(2);
pwl[0].alpha = pwl[0].t = pwl[0].beta = -std::numeric_limits<float>::infinity();
pwl[0].m = 1.0;
@ -314,7 +365,7 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
pwl[1].alpha = std::numeric_limits<float>::infinity();
pwl[1].beta = std::numeric_limits<float>::infinity();
} else if (fun == kActKaldiLstmClipping) {
} else if (activation_type == kActKaldiLstmClipping) {
pwl.resize(4);
pwl[0].alpha = pwl[0].t = pwl[0].beta = -std::numeric_limits<float>::infinity();
pwl[0].m = 0.0;
@ -327,7 +378,7 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
pwl[2].b = KALDI_LSTM_CLIP_UPPER;
pwl[3].alpha = pwl[3].beta = std::numeric_limits<float>::infinity();
} else if (fun == kActSign) {
} else if (activation_type == kActSign) {
pwl.resize(4);
pwl[0].alpha = pwl[0].t = -std::numeric_limits<float>::infinity();
pwl[0].m = 0.0;
@ -343,7 +394,7 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
pwl[2].m = 0.0;
pwl[3].alpha = pwl[3].beta = std::numeric_limits<float>::infinity();
} else if (fun == kActAbs) {
} else if (activation_type == kActAbs) {
pwl.resize(2);
pwl[0].alpha = pwl[0].t = pwl[0].beta = -std::numeric_limits<float>::infinity();
pwl[0].m = -1.0;
@ -351,11 +402,10 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
pwl[1].alpha = pwl[1].t = pwl[1].beta = std::numeric_limits<float>::infinity();
pwl[1].m = 1.0;
pwl[1].b = 0.0;
} else {
bool negative = false;
switch (fun) {
switch (activation_type) {
case kActSigmoid:
if (u_bound == 0) negative = true; // make left half convex
err = pivot_search(pwl, sigmoid, first_deriv_sigmoid, n_segments, l_bound, u_bound, threshold, negative);
@ -383,14 +433,24 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
negative = true; // make function convex
err = pivot_search(pwl, neghalflog, first_deriv_neghalflog, n_segments, l_bound, u_bound, threshold, negative);
break;
case kActPow: {
negative = (fmod(activation_type.args.pow.exponent, 1.0) == 0) ? true : false;
auto args = std::tuple<double, double, double>{ activation_type.args.pow.exponent,
activation_type.args.pow.scale,
activation_type.args.pow.offset };
auto fun = [&args](double x) -> double { return power(x, args); };
auto first_deriv = [&args](double x) -> double { return first_deriv_power(x, args); };
err = pivot_search(pwl, fun, first_deriv, n_segments, l_bound, u_bound, threshold, negative);
break;
}
default:
break;
}
err_pct = calculate_error_pct(fun, l_bound, u_bound, err, samples);
err_pct = calculate_error_pct(activation_type, l_bound, u_bound, err, samples);
while ((n_segments < PWL_MAX_ITERATIONS) && (allowed_err_pct < err_pct)) {
n_segments += 1;
switch (fun) {
switch (activation_type) {
case kActSigmoid:
err = pivot_search(pwl, sigmoid, first_deriv_sigmoid, n_segments, l_bound, u_bound, threshold, negative);
break;
@ -412,10 +472,19 @@ std::vector<pwl_t> pwl_search(const DnnActivationType fun,
case kActNegHalfLog:
err = pivot_search(pwl, neghalflog, first_deriv_neghalflog, n_segments, l_bound, u_bound, threshold, negative);
break;
case kActPow: {
auto args = std::tuple<double, double, double>{ activation_type.args.pow.exponent,
activation_type.args.pow.scale,
activation_type.args.pow.offset };
auto fun = [&args](double x) { return power(x, args); };
auto first_deriv = [&args](double x) { return first_deriv_power(x, args); };
err = pivot_search(pwl, fun, first_deriv, n_segments, l_bound, u_bound, threshold, negative);
break;
}
default:
break;
}
err_pct = calculate_error_pct(fun, l_bound, u_bound, err, samples);
err_pct = calculate_error_pct(activation_type, l_bound, u_bound, err, samples);
}
if (n_segments >= PWL_MAX_ITERATIONS) {
@ -435,15 +504,15 @@ void PwlDesignOpt16(const DnnActivation activation_type,
double err_pct = 0.0;
switch (activation_type) {
case kActSigmoid:
pwl = pwl_search(kActSigmoid, -SIGMOID_DOMAIN, SIGMOID_DOMAIN, PWL_DESIGN_THRESHOLD, PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, -SIGMOID_DOMAIN, SIGMOID_DOMAIN, PWL_DESIGN_THRESHOLD, PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, -SIGMOID_DOMAIN, SIGMOID_DOMAIN, scale_in, scale_out, ptr_segment);
break;
case kActTanh:
pwl = pwl_search(kActTanh, -TANH_DOMAIN, TANH_DOMAIN, PWL_DESIGN_THRESHOLD, PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, -TANH_DOMAIN, TANH_DOMAIN, PWL_DESIGN_THRESHOLD, PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, -TANH_DOMAIN, TANH_DOMAIN, scale_in, scale_out, ptr_segment);
break;
case kActSoftSign:
pwl = pwl_search(kActSoftSign, -SOFTSIGN_DOMAIN, SOFTSIGN_DOMAIN, PWL_DESIGN_THRESHOLD, PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, -SOFTSIGN_DOMAIN, SOFTSIGN_DOMAIN, PWL_DESIGN_THRESHOLD, PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, -SOFTSIGN_DOMAIN, SOFTSIGN_DOMAIN, scale_in, scale_out, ptr_segment);
break;
case kActRelu:
@ -461,28 +530,28 @@ void PwlDesignOpt16(const DnnActivation activation_type,
case kActLog: {
double x_min = (1 + ~XBASEMASK) / scale_in;
double x_max = ((INT32_MAX / scale_in) < LOG_DOMAIN) ? (INT32_MAX / scale_in) : LOG_DOMAIN;
pwl = pwl_search(kActLog, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.066*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.066*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, x_min, x_max, scale_in, scale_out, ptr_segment);
break;
}
case kActNegLog: {
double x_min = (1 + ~XBASEMASK) / scale_in;
double x_max = ((INT32_MAX / scale_in) < LOG_DOMAIN) ? (INT32_MAX / scale_in) : LOG_DOMAIN;
pwl = pwl_search(kActNegLog, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.066*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.066*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, x_min, x_max, scale_in, scale_out, ptr_segment);
break;
}
case kActNegHalfLog: {
double x_min = (1 + ~XBASEMASK) / scale_in;
double x_max = ((INT32_MAX / scale_in) < LOG_DOMAIN) ? (INT32_MAX / scale_in) : LOG_DOMAIN;
pwl = pwl_search(kActNegHalfLog, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.066*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.066*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, x_min, x_max, scale_in, scale_out, ptr_segment);
break;
}
case kActExp: {
double x_min = -log(scale_out);
double x_max = x_min + log(INT16_MAX);
pwl = pwl_search(kActExp, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.5*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
pwl = pwl_search(activation_type, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.5*PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
make_gna_pwl(activation_type, pwl, x_min, x_max, scale_in, scale_out, ptr_segment);
break;
}
@ -492,6 +561,27 @@ void PwlDesignOpt16(const DnnActivation activation_type,
case kActAbs:
make_gna_pwl(activation_type, pwl, -1.0, 1.0, scale_in, scale_out, ptr_segment);
break;
case kActPow: {
auto fp32eq = [](float p1, float p2) -> bool {
return (std::abs(p1 - p2) <= 0.00001f * std::min(std::abs(p1), std::abs(p2)));
};
auto input_min_value = static_cast<double>(std::numeric_limits<int32_t>::min());
auto input_max_value = static_cast<double>(std::numeric_limits<int32_t>::max());
auto x_min = fp32eq(fmod(activation_type.args.pow.exponent, 1.0), 0.0f) ? input_min_value / scale_in: 0;
x_min = std::max(x_min, -POW_DOMAIN);
auto x_max = input_max_value / scale_in;
x_max = std::min(x_max, POW_DOMAIN);
if (activation_type.args.pow.exponent != 0.0f && activation_type.args.pow.exponent != 1.0f) {
pwl = pwl_search(activation_type, x_min, x_max, PWL_DESIGN_THRESHOLD, 0.015 * PWL_MAX_ERR_PERCENT, PWL_DESIGN_SAMPLES, err_pct);
}
make_gna_pwl(activation_type, pwl, x_min, x_max, scale_in, scale_out, ptr_segment);
break;
}
default:
break;
}
@ -698,6 +788,65 @@ void PwlDesign16(const DnnActivation activation_type,
ptr_segment[2].slope = 0;
}
break;
case kActPow:
{
gnalog() << "=========================== Pow Segments===========================\n";
uint32_t num_segment_size = 0;
auto fp32eq = [](float p1, float p2) -> bool {
return (std::abs(p1 - p2) <= 0.00001f * std::min(std::abs(p1), std::abs(p2)));
};
auto args = std::tuple<double, double, double>{ activation_type.args.pow.exponent,
activation_type.args.pow.scale,
activation_type.args.pow.offset };
auto input_min_value = static_cast<double>(std::numeric_limits<int32_t>::min());
auto input_max_value = static_cast<double>(std::numeric_limits<int32_t>::max());
double x_min = fp32eq(fmod(activation_type.args.pow.exponent, 1.0), 0.0f)? input_min_value / scale_in: 0.0;
x_min = std::max(x_min, -POW_DOMAIN);
double x_max = input_max_value / scale_in;
x_max = std::min(x_max, POW_DOMAIN);
double pow_domain = x_max - x_min;
ptr_segment[0].xBase = static_cast<int32_t>(INT32_MIN & XBASEMASK); // zero out the 2 lsb
num_segment_size = static_cast<int32_t>(pow_domain * scale_in / (num_segments - 2) + 0.5);
int32_t x_min_scaled = x_min * scale_in + 0.5;
int32_t offset = x_min_scaled;
for (uint32_t i = 1; i < num_segments; i++) {
ptr_segment[i].xBase = static_cast<int32_t>(offset & XBASEMASK); // zero out the 2 lsb
offset += num_segment_size;
}
for (uint32_t i = 0; i < num_segments; i++) {
int32_t xbase = static_cast<int32_t>(ptr_segment[i].xBase & XBASEMASK);
int32_t xbasenext = (i < num_segments - 1) ? static_cast<int32_t>(ptr_segment[i + 1].xBase & XBASEMASK) : INT32_MAX;
double arg = xbase / scale_in;
arg = arg < x_min ? x_min : arg;
double argnext = xbasenext / scale_in;
argnext = argnext < x_min ? x_min : argnext;
double val = power(arg, args);
double valnext = power(argnext, args);
double slope = (valnext - val) / (static_cast<double>(xbasenext - xbase) / scale_in);
auto s = gna_slope(slope, scale_in, scale_out);
ptr_segment[i].slope = FLOAT_TO_INT16(s.slope * s.slope_scale);
ptr_segment[i].xBase = ptr_segment[i].xBase | s.slope_scale_index;
ptr_segment[i].yBase = FLOAT_TO_INT16(val * scale_out);
gnalog() << (static_cast<int32_t>((ptr_segment[i].xBase & XBASEMASK)) / scale_out)
<< " "
<< (static_cast<float>((ptr_segment[i].yBase)) / scale_out)
<< " "
<< (s.slope / scale_out)
<< "\n";
}
}
break;
default:
fprintf(stderr, "Activation function design for %s not yet implemented!\n", intel_dnn_activation_name[activation_type]);
throw -1;
@ -816,7 +965,9 @@ void PwlApply32(intel_dnn_component_t *component,
for (uint32_t i = num_row_start; i <= num_row_end; i++) {
for (uint32_t j = num_col_start; j <= num_col_end; j++) {
ptr_out[i * num_columns + j] =
(ptr_in[i * num_columns + j] < 0.0f) ? ptr_in[i * num_columns + j] * transform->func_id.negative_slope : ptr_in[i * num_columns + j];
(ptr_in[i * num_columns + j] < 0.0f) ?
ptr_in[i * num_columns + j] * transform->func_id.args.lrelu.negative_slope :
ptr_in[i * num_columns + j];
}
}
break;
@ -883,6 +1034,17 @@ void PwlApply32(intel_dnn_component_t *component,
}
}
break;
case kActPow: {
float exponent = transform->func_id.args.pow.exponent;
float scale = transform->func_id.args.pow.scale;
float offset = transform->func_id.args.pow.offset;
for (uint32_t i = num_row_start; i <= num_row_end; i++) {
for (uint32_t j = num_col_start; j <= num_col_end; j++) {
ptr_out[i * num_columns + j] = pow(offset + scale * ptr_in[i * num_columns + j], exponent);
}
}
}
break;
case kActCustom:
// break;
default:fprintf(stderr, "Unknown piecewise linear function type!\n");

View File

@ -32,6 +32,9 @@
#define LOG_DOMAIN (2981.0)
#define EXP_DOMAIN (8.0)
#define EXP_BREAK (0.045)
#define POW_NUM_SEGMENTS 65
#define POW_BREAK 0
#define POW_DOMAIN (16.0)
typedef struct {
double t;
@ -61,7 +64,7 @@ double pivot_search(std::vector<pwl_t>& result, double(*f)(const double),
inline std::vector<pwl_t> negative_pwl(const std::vector<pwl_t>& pwl);
std::vector<pwl_t> pwl_search(const DnnActivationType fun,
std::vector<pwl_t> pwl_search(const DnnActivation& activation_type,
const double l_bound,
const double u_bound,
const double threshold,
@ -73,7 +76,7 @@ bool split_search(const DnnActivationType fun,
const double l_bound,
const double u_bound);
double calculate_error_pct(const DnnActivationType fun,
double calculate_error_pct(const DnnActivation& activation_type,
const double l_bound,
const double u_bound,
const double offset,

View File

@ -0,0 +1,44 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/power.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
std::vector<std::vector<std::vector<size_t>>> inShapes = {
{{1, 8}},
{{2, 16}},
{{3, 32}},
{{4, 64}},
{{5, 128}},
{{6, 256}},
{{7, 512}},
{{8, 1024}}
};
std::vector<std::vector<float >> Power = {
{0.0f},
{0.5f},
{1.0f},
{1.1f},
{1.5f},
{2.0f},
};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
};
INSTANTIATE_TEST_CASE_P(power, PowerLayerTest,
::testing::Combine(
::testing::ValuesIn(inShapes),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_GNA),
::testing::ValuesIn(Power)),
PowerLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,30 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "common_test_utils/test_constants.hpp"
namespace LayerTestsDefinitions {
using PowerParamsTuple = typename std::tuple<
std::vector<std::vector<size_t>>, //input shapes
InferenceEngine::Precision, //Network precision
std::string, //Device name
std::vector<float>>; //power
class PowerLayerTest:
public testing::WithParamInterface<PowerParamsTuple>,
public LayerTestsUtils::LayerTestsCommon{
public:
static std::string getTestCaseName(const testing::TestParamInfo<PowerParamsTuple> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,45 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <debug.h>
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include "single_layer_tests/power.hpp"
namespace LayerTestsDefinitions {
std::string PowerLayerTest::getTestCaseName(const testing::TestParamInfo<PowerParamsTuple> &obj) {
std::vector<std::vector<size_t>> inputShapes;
InferenceEngine::Precision netPrecision;
std::string targetName;
std::vector<float> power;
std::tie(inputShapes, netPrecision, targetName, power) = obj.param;
std::ostringstream results;
results << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
results << "Power=" << CommonTestUtils::vec2str(power) << "_";
results << "netPRC=" << netPrecision.name() << "_";
results << "targetDevice=" << targetName << "_";
return results.str();
}
void PowerLayerTest::SetUp() {
threshold = 0.04f;
std::vector<std::vector<size_t>> inputShapes;
InferenceEngine::Precision netPrecision;
std::vector<float> power;
std::tie(inputShapes, netPrecision, targetDevice, power) = this->GetParam();
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto paramsIn = ngraph::builder::makeParams(ngPrc, {inputShapes[0]});
auto power_const = std::make_shared<ngraph::op::Constant>(ngPrc, ngraph::Shape{ 1 }, power);
auto pow = std::make_shared<ngraph::opset1::Power>(paramsIn[0], power_const);
function = std::make_shared<ngraph::Function>(pow, paramsIn, "power");
}
TEST_P(PowerLayerTest, CompareWithRefs){
Run();
};
} // namespace LayerTestsDefinitions