Implement Bucketize in MO and MKLDNN for opset3 (#583)
This operation is used for Wide and Deep Model Signed-off-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
e025f1464b
commit
958e425775
@ -1774,7 +1774,7 @@ public:
|
|||||||
/**
|
/**
|
||||||
* @brief Indicates whether the intervals include the right or the left bucket edge.
|
* @brief Indicates whether the intervals include the right or the left bucket edge.
|
||||||
*/
|
*/
|
||||||
bool with_right_bound = false;
|
bool with_right_bound = true;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Creates a new BucketizeLayer instance.
|
* @brief Creates a new BucketizeLayer instance.
|
||||||
|
@ -1776,7 +1776,7 @@ void BucketizeValidator::parseParams(CNNLayer* layer) {
|
|||||||
THROW_IE_EXCEPTION << layer->name << " Layer is not instance of Bucketize class";
|
THROW_IE_EXCEPTION << layer->name << " Layer is not instance of Bucketize class";
|
||||||
}
|
}
|
||||||
|
|
||||||
casted->with_right_bound = casted->GetParamAsBool("with_right_bound");
|
casted->with_right_bound = casted->GetParamAsBool("with_right_bound", true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void BucketizeValidator::checkParams(const CNNLayer* layer) {
|
void BucketizeValidator::checkParams(const CNNLayer* layer) {
|
||||||
|
@ -29,25 +29,40 @@ public:
|
|||||||
// check one attribute
|
// check one attribute
|
||||||
with_right = layer->GetParamAsBool("with_right_bound");
|
with_right = layer->GetParamAsBool("with_right_bound");
|
||||||
|
|
||||||
// check precisions for input tensors
|
auto input = layer->insData[INPUT_TENSOR_PORT].lock();
|
||||||
Precision input_tensor_precision = layer->insData[INPUT_TENSOR_PORT].lock()->getTensorDesc().getPrecision();
|
if (!input) {
|
||||||
if (input_tensor_precision != Precision::FP32) {
|
THROW_IE_EXCEPTION << "Missing input for " << layer->name << " layer";
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input precision of the input. Only FP32 is supported!";
|
|
||||||
}
|
}
|
||||||
if (with_bins) {
|
auto boundaries = layer->insData[INPUT_BINS_PORT].lock();
|
||||||
Precision input_bins_precision = layer->insData[INPUT_BINS_PORT].lock()->getTensorDesc().getPrecision();
|
if (!boundaries) {
|
||||||
if (input_bins_precision != Precision::FP32) {
|
THROW_IE_EXCEPTION << "Missing boundaries input for " << layer->name << " layer";
|
||||||
THROW_IE_EXCEPTION << layer->name
|
}
|
||||||
<< " Incorrect input precision of the boundaries tensor. Only FP32 is supported!";
|
|
||||||
}
|
// check precisions for input and output tensors
|
||||||
|
input_precision = input->getTensorDesc().getPrecision();
|
||||||
|
if (input_precision != Precision::FP32 && input_precision != Precision::I32 &&
|
||||||
|
input_precision != Precision::I64) {
|
||||||
|
THROW_IE_EXCEPTION << layer->name
|
||||||
|
<< " Incorrect input precision of the input. Only FP32, I32 and I64 are supported!";
|
||||||
|
}
|
||||||
|
boundaries_precision = boundaries->getTensorDesc().getPrecision();
|
||||||
|
if (boundaries_precision != Precision::FP32 && boundaries_precision != Precision::I32 &&
|
||||||
|
boundaries_precision != Precision::I64) {
|
||||||
|
THROW_IE_EXCEPTION << layer->name
|
||||||
|
<< " Incorrect input precision of the boundaries tensor. Only FP32, I32 and I64 are supported!";
|
||||||
|
}
|
||||||
|
output_precision = layer->outData[OUTPUT_TENSOR_PORT]->getTensorDesc().getPrecision();
|
||||||
|
if (output_precision != Precision::I32 && output_precision != Precision::I64) {
|
||||||
|
THROW_IE_EXCEPTION << layer->name
|
||||||
|
<< " Incorrect precision of the output tensor. Only I32 and I64 are supported!";
|
||||||
}
|
}
|
||||||
|
|
||||||
// check dimensions of input tensors
|
// check dimensions of input tensors
|
||||||
SizeVector input_tensor_dims = layer->insData[INPUT_TENSOR_PORT].lock()->getTensorDesc().getDims();
|
SizeVector input_tensor_dims = input->getTensorDesc().getDims();
|
||||||
if (input_tensor_dims.size() < 1) {
|
if (input_tensor_dims.size() < 1) {
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect dimensions of the input.";
|
THROW_IE_EXCEPTION << layer->name << " Incorrect dimensions of the input.";
|
||||||
}
|
}
|
||||||
SizeVector input_bin_dims = layer->insData[INPUT_BINS_PORT].lock()->getTensorDesc().getDims();
|
SizeVector input_bin_dims = boundaries->getTensorDesc().getDims();
|
||||||
if (input_bin_dims.size() != 1) {
|
if (input_bin_dims.size() != 1) {
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect dimensions of the boundaries tensor.";
|
THROW_IE_EXCEPTION << layer->name << " Incorrect dimensions of the boundaries tensor.";
|
||||||
}
|
}
|
||||||
@ -56,12 +71,8 @@ public:
|
|||||||
}
|
}
|
||||||
num_bin_values = input_bin_dims[0];
|
num_bin_values = input_bin_dims[0];
|
||||||
|
|
||||||
num_values = 1;
|
num_values = std::accumulate(input_tensor_dims.begin(), input_tensor_dims.end(), 1, std::multiplies<size_t>());
|
||||||
for (size_t ind = 0; ind < input_tensor_dims.size(); ind++) {
|
|
||||||
num_values *= input_tensor_dims[ind];
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: check that dense shape value is set
|
|
||||||
addConfig(layer,
|
addConfig(layer,
|
||||||
{ DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
|
{ DataConfigurator(ConfLayout::PLN), DataConfigurator(ConfLayout::PLN) },
|
||||||
{ DataConfigurator(ConfLayout::PLN) });
|
{ DataConfigurator(ConfLayout::PLN) });
|
||||||
@ -72,46 +83,131 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
||||||
const float *input_tensor_ptr = inputs[INPUT_TENSOR_PORT]->cbuffer().as<const float *>() +
|
auto precision_mask = getPrecisionMask(input_precision, boundaries_precision, output_precision);
|
||||||
inputs[INPUT_TENSOR_PORT]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
const float *input_bins_ptr = nullptr;
|
|
||||||
if (with_bins) {
|
|
||||||
input_bins_ptr = inputs[INPUT_BINS_PORT]->cbuffer().as<const float *>() +
|
|
||||||
inputs[INPUT_BINS_PORT]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
}
|
|
||||||
int *output_tensor_ptr = outputs[OUTPUT_TENSOR_PORT]->cbuffer().as<int *>() +
|
|
||||||
inputs[OUTPUT_TENSOR_PORT]->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
|
|
||||||
if (with_bins == false) {
|
switch (precision_mask) {
|
||||||
for (size_t ind = 0; ind < num_values; ind++) {
|
case getPrecisionMask(Precision::FP32, Precision::FP32, Precision::I32):
|
||||||
output_tensor_ptr[ind] = 0;
|
bucketize<PrecisionTrait<Precision::FP32>::value_type,
|
||||||
}
|
PrecisionTrait<Precision::FP32>::value_type,
|
||||||
return OK;
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
}
|
break;
|
||||||
|
case getPrecisionMask(Precision::FP32, Precision::FP32, Precision::I64):
|
||||||
for (size_t ind = 0; ind < num_values; ind++) {
|
bucketize<PrecisionTrait<Precision::FP32>::value_type,
|
||||||
float value = input_tensor_ptr[ind];
|
PrecisionTrait<Precision::FP32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
// find a bin to which value belongs
|
break;
|
||||||
output_tensor_ptr[ind] = -1;
|
case getPrecisionMask(Precision::FP32, Precision::I32, Precision::I32):
|
||||||
for (size_t bin_ind = 0; bin_ind < num_bin_values; bin_ind++) {
|
bucketize<PrecisionTrait<Precision::FP32>::value_type,
|
||||||
if (with_right && value <= input_bins_ptr[bin_ind]) {
|
PrecisionTrait<Precision::I32>::value_type,
|
||||||
output_tensor_ptr[ind] = static_cast<int>(bin_ind);
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
break;
|
break;
|
||||||
} else if (!with_right && value < input_bins_ptr[bin_ind]) {
|
case getPrecisionMask(Precision::FP32, Precision::I32, Precision::I64):
|
||||||
output_tensor_ptr[ind] = static_cast<int>(bin_ind);
|
bucketize<PrecisionTrait<Precision::FP32>::value_type,
|
||||||
break;
|
PrecisionTrait<Precision::I32>::value_type,
|
||||||
}
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
}
|
break;
|
||||||
if (output_tensor_ptr[ind] == -1) {
|
case getPrecisionMask(Precision::FP32, Precision::I64, Precision::I32):
|
||||||
output_tensor_ptr[ind] = static_cast<int>(num_bin_values);
|
bucketize<PrecisionTrait<Precision::FP32>::value_type,
|
||||||
}
|
PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::FP32, Precision::I64, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::FP32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I32, Precision::FP32, Precision::I32):
|
||||||
|
bucketize<PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::FP32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I32, Precision::FP32, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::FP32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I32, Precision::I32, Precision::I32):
|
||||||
|
bucketize<PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I32, Precision::I32, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I32, Precision::I64, Precision::I32):
|
||||||
|
bucketize<PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I32, Precision::I64, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I64, Precision::FP32, Precision::I32):
|
||||||
|
bucketize<PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::FP32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I64, Precision::FP32, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::FP32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I64, Precision::I32, Precision::I32):
|
||||||
|
bucketize<PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I64, Precision::I32, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I64, Precision::I64, Precision::I32):
|
||||||
|
bucketize<PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I32>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
case getPrecisionMask(Precision::I64, Precision::I64, Precision::I64):
|
||||||
|
bucketize<PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type,
|
||||||
|
PrecisionTrait<Precision::I64>::value_type>(inputs[0], inputs[1], outputs[0]);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return GENERAL_ERROR;
|
||||||
}
|
}
|
||||||
|
|
||||||
return OK;
|
return OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
template <typename T, typename T_BOUNDARIES, typename T_IND>
|
||||||
|
void bucketize(Blob::Ptr input, Blob::Ptr boundaries, Blob::Ptr output) {
|
||||||
|
const auto *input_data = input->cbuffer().as<const T *>();
|
||||||
|
const auto *boundaries_data = boundaries->cbuffer().as<const T_BOUNDARIES *>();
|
||||||
|
auto *output_data = output->buffer().as<T_IND *>();
|
||||||
|
|
||||||
|
if (with_bins == false) {
|
||||||
|
memset(output_data, 0, num_values * sizeof(T_IND));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// boundaries are assumed to be sorted and to have unique elements
|
||||||
|
parallel_for(num_values, [&](size_t ind) {
|
||||||
|
T value = input_data[ind];
|
||||||
|
if (with_right) {
|
||||||
|
auto low = std::lower_bound(boundaries_data, boundaries_data + num_bin_values, value);
|
||||||
|
output_data[ind] = static_cast<T_IND>(low - boundaries_data);
|
||||||
|
} else {
|
||||||
|
auto up = std::upper_bound(boundaries_data, boundaries_data + num_bin_values, value);
|
||||||
|
output_data[ind] = static_cast<T_IND>(up - boundaries_data);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
const size_t INPUT_TENSOR_PORT = 0;
|
const size_t INPUT_TENSOR_PORT = 0;
|
||||||
const size_t INPUT_BINS_PORT = 1;
|
const size_t INPUT_BINS_PORT = 1;
|
||||||
const size_t OUTPUT_TENSOR_PORT = 0;
|
const size_t OUTPUT_TENSOR_PORT = 0;
|
||||||
@ -120,6 +216,10 @@ private:
|
|||||||
size_t num_bin_values = 0;
|
size_t num_bin_values = 0;
|
||||||
bool with_right = false;
|
bool with_right = false;
|
||||||
bool with_bins = false;
|
bool with_bins = false;
|
||||||
|
|
||||||
|
Precision input_precision;
|
||||||
|
Precision boundaries_precision;
|
||||||
|
Precision output_precision;
|
||||||
};
|
};
|
||||||
|
|
||||||
REG_FACTORY_FOR(BucketizeImpl, Bucketize);
|
REG_FACTORY_FOR(BucketizeImpl, Bucketize);
|
||||||
|
@ -936,6 +936,7 @@ mo/utils/ir_engine/ir_engine.py
|
|||||||
mo/utils/ir_reader/__init__.py
|
mo/utils/ir_reader/__init__.py
|
||||||
mo/utils/ir_reader/extender.py
|
mo/utils/ir_reader/extender.py
|
||||||
mo/utils/ir_reader/extenders/binary_convolution_extender.py
|
mo/utils/ir_reader/extenders/binary_convolution_extender.py
|
||||||
|
mo/utils/ir_reader/extenders/bucketize_extender.py
|
||||||
mo/utils/ir_reader/extenders/conv_extender.py
|
mo/utils/ir_reader/extenders/conv_extender.py
|
||||||
mo/utils/ir_reader/extenders/convert_extender.py
|
mo/utils/ir_reader/extenders/convert_extender.py
|
||||||
mo/utils/ir_reader/extenders/deconvolution_extender.py
|
mo/utils/ir_reader/extenders/deconvolution_extender.py
|
||||||
|
@ -27,5 +27,5 @@ class BucketizeFrontExtractor(FrontExtractorOp):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def extract(cls, node):
|
def extract(cls, node):
|
||||||
boundaries = np.array(node.pb.attr['boundaries'].list.f, dtype=np.float)
|
boundaries = np.array(node.pb.attr['boundaries'].list.f, dtype=np.float)
|
||||||
Bucketize.update_node_stat(node, {'boundaries': boundaries, 'with_right_bound': False})
|
Bucketize.update_node_stat(node, {'boundaries': boundaries, 'with_right_bound': False, 'output_type': np.int32})
|
||||||
return cls.enabled
|
return cls.enabled
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mo.graph.graph import Node, Graph
|
from mo.graph.graph import Node, Graph
|
||||||
|
from mo.middle.passes.convert_data_type import np_data_type_to_destination_type
|
||||||
from mo.ops.op import Op
|
from mo.ops.op import Op
|
||||||
|
|
||||||
|
|
||||||
@ -28,28 +29,52 @@ class Bucketize(Op):
|
|||||||
'kind': 'op',
|
'kind': 'op',
|
||||||
'type': __class__.op,
|
'type': __class__.op,
|
||||||
'op': __class__.op,
|
'op': __class__.op,
|
||||||
'version': 'extension',
|
'version': 'opset3',
|
||||||
|
|
||||||
'type_infer': self.type_infer,
|
'type_infer': self.type_infer,
|
||||||
'infer': self.infer,
|
'infer': self.infer,
|
||||||
|
|
||||||
'in_ports_count': 2,
|
'in_ports_count': 2,
|
||||||
'out_ports_count': 1,
|
'out_ports_count': 1,
|
||||||
}
|
}
|
||||||
super().__init__(graph, mandatory_props, attrs)
|
super().__init__(graph, mandatory_props, attrs)
|
||||||
|
|
||||||
def supported_attrs(self):
|
def backend_attrs(self):
|
||||||
return ["with_right_bound"]
|
version = self.get_opset()
|
||||||
|
if version == "extension":
|
||||||
|
return ['with_right_bound']
|
||||||
|
else:
|
||||||
|
return [
|
||||||
|
'with_right_bound',
|
||||||
|
('output_type', lambda node: np_data_type_to_destination_type(node.output_type)),
|
||||||
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def type_infer(node):
|
def type_infer(node):
|
||||||
# the output is always integer since the layer outputs a bucket index
|
# the output is always integer since the layer outputs a bucket index
|
||||||
node.out_port(0).set_data_type(np.int32)
|
if node.get_opset() == "extension":
|
||||||
|
node.out_port(0).set_data_type(np.int32)
|
||||||
|
else:
|
||||||
|
assert node.output_type in [np.int64, np.int32], \
|
||||||
|
'Bucketize `output_type` attribute must be int32 or int64, `{}` found'.format(np.dtype(node.output_type).name)
|
||||||
|
node.out_port(0).set_data_type(node.output_type)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def infer(node: Node):
|
def infer(node: Node):
|
||||||
|
node_name = node.soft_get('name', node.id)
|
||||||
assert node.with_right_bound is not None, \
|
assert node.with_right_bound is not None, \
|
||||||
"Attribute \"with_right_bound\" is not defined"
|
"Attribute \"with_right_bound\" is not defined"
|
||||||
assert len(node.in_nodes()) == 2, \
|
assert len(node.in_nodes()) == 2, \
|
||||||
"Incorrect number of inputs for {} node".format(node.id)
|
"Incorrect number of inputs for {} node".format(node.id)
|
||||||
|
if node.get_opset() == "extension":
|
||||||
|
output_type = np.int32
|
||||||
|
else:
|
||||||
|
assert node.has_valid('output_type'), \
|
||||||
|
'`output_type` attribute is not set for Bucketize node `{}`'.format(node_name)
|
||||||
|
assert node.output_type in [np.int64, np.int32], \
|
||||||
|
'Bucketize `output_type` attribute must be int32 or int64, `{}` found'.format(np.dtype(node.output_type).name)
|
||||||
|
output_type = node.output_type
|
||||||
|
|
||||||
output_shape = node.in_port(0).data.get_shape()
|
output_shape = node.in_port(0).data.get_shape()
|
||||||
node.out_port(0).data.set_shape(output_shape)
|
node.out_port(0).data.set_shape(output_shape)
|
||||||
|
|
||||||
@ -58,4 +83,4 @@ class Bucketize(Op):
|
|||||||
|
|
||||||
# compute if all input is constant
|
# compute if all input is constant
|
||||||
if input_value is not None and buckets_value is not None:
|
if input_value is not None and buckets_value is not None:
|
||||||
node.out_port(0).data.set_value(np.digitize(input_value, buckets_value, right=node.with_right_bound))
|
node.out_port(0).data.set_value(np.array(np.digitize(input_value, buckets_value, right=node.with_right_bound), dtype=node.output_type))
|
||||||
|
@ -25,7 +25,7 @@ from mo.utils.unittest.graph import build_graph
|
|||||||
|
|
||||||
nodes_attributes = {'input_tensor': {'shape': None, 'value': None, 'kind': 'data'},
|
nodes_attributes = {'input_tensor': {'shape': None, 'value': None, 'kind': 'data'},
|
||||||
'input_buckets': {'shape': None, 'value': None, 'kind': 'data'},
|
'input_buckets': {'shape': None, 'value': None, 'kind': 'data'},
|
||||||
'bucketize_node': {'op': 'Bucketize', 'kind': 'op', 'with_right_bound': False},
|
'bucketize_node': {'op': 'Bucketize', 'kind': 'op', 'with_right_bound': False, 'output_type': np.int32},
|
||||||
'output': {'shape': None, 'value': None, 'kind': 'data'}}
|
'output': {'shape': None, 'value': None, 'kind': 'data'}}
|
||||||
|
|
||||||
# graph 1
|
# graph 1
|
||||||
|
@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
Copyright (C) 2020 Intel Corporation
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
from mo.middle.passes.convert_data_type import destination_type_to_np_data_type
|
||||||
|
|
||||||
|
from mo.utils.graph import Node
|
||||||
|
from mo.utils.ir_reader.extender import Extender
|
||||||
|
|
||||||
|
|
||||||
|
class BucketizeExtender(Extender):
|
||||||
|
op = 'Bucketize'
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extend(op: Node):
|
||||||
|
if op.get_opset() != "extension":
|
||||||
|
op['output_type'] = destination_type_to_np_data_type(op.output_type)
|
Loading…
Reference in New Issue
Block a user