Reverted GetConvData helpers
This commit is contained in:
parent
d78f7e7fc6
commit
8b81867d18
@ -18,6 +18,54 @@ namespace intel_gna {
|
||||
namespace pass {
|
||||
namespace helper {
|
||||
|
||||
void GetConvData(std::shared_ptr<ngraph::opset7::Convolution> conv, ConvData& conv_data) {
|
||||
OPENVINO_ASSERT(conv);
|
||||
conv_data.output_height = conv->get_output_shape(0)[2];
|
||||
conv_data.output_width = conv->get_output_shape(0)[3];
|
||||
conv_data.input_channel_count = conv->input_value(0).get_shape()[1];
|
||||
conv_data.input_height = conv->input_value(0).get_shape()[2];
|
||||
conv_data.input_width = conv->input_value(0).get_shape()[3];
|
||||
conv_data.filter_count = conv->input_value(1).get_shape()[0];
|
||||
conv_data.filter_channel_count = conv->input_value(1).get_shape()[1];
|
||||
conv_data.filter_height = conv->input_value(1).get_shape()[2];
|
||||
conv_data.filter_width = conv->input_value(1).get_shape()[3];
|
||||
conv_data.filter_dilation_height = conv->get_dilations()[0];
|
||||
conv_data.filter_dilation_width = conv->get_dilations()[1];
|
||||
conv_data.filter_stride_height = conv->get_strides()[0];
|
||||
conv_data.filter_stride_width = conv->get_strides()[1];
|
||||
conv_data.output_channel_count = conv_data.filter_count;
|
||||
conv_data.pads_begin_height = conv->get_pads_begin()[0];
|
||||
conv_data.pads_begin_width = conv->get_pads_begin()[1];
|
||||
conv_data.pads_end_height = conv->get_pads_end()[0];
|
||||
conv_data.pads_end_width = conv->get_pads_end()[1];
|
||||
conv_data.padding_type = conv->get_auto_pad();
|
||||
conv_data.element_type = conv->get_element_type();
|
||||
}
|
||||
|
||||
void GetConvData(std::shared_ptr<ov::intel_gna::op::GNAConvolution> conv, ConvData& conv_data) {
|
||||
OPENVINO_ASSERT(conv);
|
||||
conv_data.output_height = conv->get_output_shape(0)[2];
|
||||
conv_data.output_width = conv->get_output_shape(0)[3];
|
||||
conv_data.input_channel_count = conv->input_value(0).get_shape()[3];
|
||||
conv_data.input_height = conv->input_value(0).get_shape()[1];
|
||||
conv_data.input_width = conv->input_value(0).get_shape()[2];
|
||||
conv_data.filter_count = conv->input_value(1).get_shape()[0];
|
||||
conv_data.filter_channel_count = conv->input_value(1).get_shape()[3];
|
||||
conv_data.filter_height = conv->input_value(1).get_shape()[1];
|
||||
conv_data.filter_width = conv->input_value(1).get_shape()[2];
|
||||
conv_data.filter_dilation_height = conv->get_dilations()[0];
|
||||
conv_data.filter_dilation_width = conv->get_dilations()[1];
|
||||
conv_data.filter_stride_height = conv->get_strides()[0];
|
||||
conv_data.filter_stride_width = conv->get_strides()[1];
|
||||
conv_data.output_channel_count = conv_data.filter_count;
|
||||
conv_data.pads_begin_height = conv->get_pads_begin()[0];
|
||||
conv_data.pads_begin_width = conv->get_pads_begin()[1];
|
||||
conv_data.pads_end_height = conv->get_pads_end()[0];
|
||||
conv_data.pads_end_width = conv->get_pads_end()[1];
|
||||
conv_data.padding_type = conv->get_auto_pad();
|
||||
conv_data.element_type = conv->get_element_type();
|
||||
}
|
||||
|
||||
std::function<bool(ngraph::Output<ngraph::Node>)> consumers_and_rank(const size_t expected_count,
|
||||
const ngraph::Dimension& expected_rank) {
|
||||
return [=](ngraph::Output<ngraph::Node> output) -> bool {
|
||||
|
@ -42,30 +42,15 @@ struct ConvData {
|
||||
* @param conv_data convolution data structure to put data into
|
||||
* @return void
|
||||
*/
|
||||
template <class T>
|
||||
void GetConvData(const T& conv, ConvData& conv_data) {
|
||||
OPENVINO_ASSERT(conv);
|
||||
conv_data.output_height = conv->get_output_shape(0)[2];
|
||||
conv_data.output_width = conv->get_output_shape(0)[3];
|
||||
conv_data.input_channel_count = conv->input_value(0).get_shape()[1];
|
||||
conv_data.input_height = conv->input_value(0).get_shape()[2];
|
||||
conv_data.input_width = conv->input_value(0).get_shape()[3];
|
||||
conv_data.filter_count = conv->input_value(1).get_shape()[0];
|
||||
conv_data.filter_channel_count = conv->input_value(1).get_shape()[1];
|
||||
conv_data.filter_height = conv->input_value(1).get_shape()[2];
|
||||
conv_data.filter_width = conv->input_value(1).get_shape()[3];
|
||||
conv_data.filter_dilation_height = conv->get_dilations()[0];
|
||||
conv_data.filter_dilation_width = conv->get_dilations()[1];
|
||||
conv_data.filter_stride_height = conv->get_strides()[0];
|
||||
conv_data.filter_stride_width = conv->get_strides()[1];
|
||||
conv_data.output_channel_count = conv_data.filter_count;
|
||||
conv_data.pads_begin_height = conv->get_pads_begin()[0];
|
||||
conv_data.pads_begin_width = conv->get_pads_begin()[1];
|
||||
conv_data.pads_end_height = conv->get_pads_end()[0];
|
||||
conv_data.pads_end_width = conv->get_pads_end()[1];
|
||||
conv_data.padding_type = conv->get_auto_pad();
|
||||
conv_data.element_type = conv->get_element_type();
|
||||
}
|
||||
void GetConvData(std::shared_ptr<ngraph::opset7::Convolution> conv, ConvData& conv_data);
|
||||
|
||||
/**
|
||||
* @brief gets all convolution related data into a struct for further processing
|
||||
* @param conv GNA custom convolution node to get data of
|
||||
* @param conv_data convolution data structure to put data into
|
||||
* @return void
|
||||
*/
|
||||
void GetConvData(std::shared_ptr<ov::intel_gna::op::GNAConvolution> conv, ConvData& conv_data);
|
||||
|
||||
/**
|
||||
* @brief ngraph matcher predicate fusing existing predicates for consumers count and rank of a layer
|
||||
|
Loading…
Reference in New Issue
Block a user