[GPU] Added operator== for cldnn primitives (#15736)

This commit is contained in:
Roman Lyamin 2023-02-17 19:09:12 +04:00 committed by GitHub
parent 59542d5cd3
commit efb51b058c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
88 changed files with 1313 additions and 30 deletions

View File

@ -125,6 +125,18 @@ struct activation : public primitive_base<activation> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const activation>(rhs);
return activation_function == rhs_casted.activation_function &&
additional_params.a == rhs_casted.additional_params.a &&
additional_params.b == rhs_casted.additional_params.b &&
additional_params_input.empty() == rhs_casted.additional_params_input.empty();
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
if (additional_params_input.empty())

View File

@ -58,6 +58,17 @@ struct adaptive_pooling : public primitive_base<adaptive_pooling> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const adaptive_pooling>(rhs);
return mode == rhs_casted.mode &&
indices_output == rhs_casted.indices_output &&
index_element_type == rhs_casted.index_element_type;
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -82,5 +82,18 @@ struct arg_max_min : public primitive_base<arg_max_min> {
seed = hash_combine(seed, values_first);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const arg_max_min>(rhs);
return mode == rhs_casted.mode &&
top_k == rhs_casted.top_k &&
axis == rhs_casted.axis &&
sort == rhs_casted.sort &&
values_first == rhs_casted.values_first;
}
};
} // namespace cldnn

View File

@ -21,14 +21,23 @@ struct assign : public primitive_base<assign> {
/// @param variable_id Variable id
/// @param output_layout Memory layout
assign(const primitive_id &id,
const std::vector<input_info>& inputs,
const std::string& variable_id,
const layout& output_layout)
: primitive_base(id, inputs, {padding()}, {optional_data_type{output_layout.data_type}}),
variable_id{variable_id},
output_layout{output_layout} {}
const std::vector<input_info>& inputs,
const std::string& variable_id,
const layout& output_layout)
: primitive_base(id, inputs, {padding()}, {optional_data_type{output_layout.data_type}}),
variable_id{variable_id},
output_layout{output_layout} {}
std::string variable_id;
layout output_layout;
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const assign>(rhs);
return variable_id == rhs_casted.variable_id;
}
};
} // namespace cldnn

View File

@ -71,5 +71,16 @@ struct batch_to_space : public primitive_base<batch_to_space> {
seed = hash_combine(seed, crops_end.hash());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const batch_to_space>(rhs);
return block_shape == rhs_casted.block_shape &&
crops_begin == rhs_casted.crops_begin &&
crops_end == rhs_casted.crops_end;
}
};
} // namespace cldnn

View File

@ -76,6 +76,20 @@ struct binary_convolution : public primitive_base<binary_convolution> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const binary_convolution>(rhs);
return pad == rhs_casted.pad &&
stride == rhs_casted.stride &&
dilation == rhs_casted.dilation &&
groups == rhs_casted.groups &&
pad_value == rhs_casted.pad_value &&
weights.size() == rhs_casted.weights.size();
}
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
ret.reserve(weights.size());

View File

@ -80,5 +80,17 @@ struct border : public primitive_base<border> {
seed = hash_combine(seed, pad_value);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const border>(rhs);
return pads_begin == rhs_casted.pads_begin &&
pads_end == rhs_casted.pads_end &&
pad_mode == rhs_casted.pad_mode &&
pad_value == rhs_casted.pad_value;
}
};
} // namespace cldnn

View File

@ -135,5 +135,16 @@ struct broadcast : public primitive_base<broadcast> {
seed = hash_range(seed, axes_mapping.begin(), axes_mapping.end());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const broadcast>(rhs);
return axes_mapping == rhs_casted.axes_mapping &&
broadcast_mode == rhs_casted.broadcast_mode &&
broadcast_sizes == rhs_casted.broadcast_sizes;
}
};
} // namespace cldnn

View File

@ -31,6 +31,15 @@ struct bucketize : primitive_base<bucketize> {
seed = hash_combine(seed, with_right_bound);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const bucketize>(rhs);
return with_right_bound == rhs_casted.with_right_bound;
}
};
} // namespace cldnn

View File

@ -64,5 +64,14 @@ struct concatenation : public primitive_base<concatenation> {
seed = hash_combine(seed, axis);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const concatenation>(rhs);
return axis == rhs_casted.axis;
}
};
} // namespace cldnn

View File

@ -26,9 +26,9 @@ struct condition : public primitive_base<condition> {
/// @param id An identifier of new primitive.
/// @param input An identifier of primitive which is an input for newly created
/// condition primitive.
/// @param topology_true Topolgoy containg primitives, which will be executed when comparsion results
/// @param topology_true Topology containg primitives, which will be executed when comparsion results
/// true.
/// @param topology_false Topolgoy containg primitives, which will be executed when comparsion results
/// @param topology_false Topology containg primitives, which will be executed when comparsion results
/// false..
/// @param compare_Data An identifier of primitive which contains compare values
/// @param func Used function during comparison.

View File

@ -58,5 +58,17 @@ struct convert_color : public primitive_base<convert_color> {
seed = hash_combine(seed, mem_type);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const convert_color>(rhs);
return input_color_format == rhs_casted.input_color_format &&
output_color_format == rhs_casted.output_color_format &&
mem_type == rhs_casted.mem_type &&
output_layout == rhs_casted.output_layout;
}
};
} // namespace cldnn

View File

@ -781,6 +781,31 @@ struct convolution : public primitive_base<convolution> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const convolution>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(pad) &&
cmp_fields(stride) &&
cmp_fields(dilation) &&
cmp_fields(groups) &&
cmp_fields(deformable_groups) &&
cmp_fields(padding_above) &&
cmp_fields(padding_below) &&
cmp_fields(deformable_mode) &&
cmp_fields(bilinear_interpolation_pad) &&
cmp_fields(grouped_weights_shape) &&
cmp_fields(weights.size()) &&
cmp_fields(bias.size()) &&
cmp_fields(weights_zero_points.size()) &&
cmp_fields(activations_zero_points.size()) &&
cmp_fields(compensation.size());
#undef cmp_fields
}
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
ret.reserve(weights.size() + bias.size() + weights_zero_points.size() +
@ -858,6 +883,25 @@ struct deformable_interp : public primitive_base<deformable_interp> {
seed = cldnn::hash_combine(seed, bilinear_interpolation_pad);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const deformable_interp>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(pad) &&
cmp_fields(stride) &&
cmp_fields(dilation) &&
cmp_fields(kernel_size) &&
cmp_fields(groups) &&
cmp_fields(deformable_groups) &&
cmp_fields(padding_above) &&
cmp_fields(padding_below) &&
cmp_fields(bilinear_interpolation_pad);
#undef cmp_fields
}
};
struct deformable_conv : public primitive_base<deformable_conv> {
@ -893,6 +937,17 @@ struct deformable_conv : public primitive_base<deformable_conv> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const deformable_conv>(rhs);
return groups == rhs_casted.groups &&
weights.size() == rhs_casted.weights.size() &&
bias.size() == rhs_casted.bias.size();
}
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
ret.reserve(weights.size() + bias.size());

View File

@ -134,5 +134,18 @@ struct crop : public primitive_base<crop> {
seed = hash_combine(seed, op_mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const crop>(rhs);
return reference_input == rhs_casted.reference_input &&
offsets == rhs_casted.offsets &&
output_idx == rhs_casted.output_idx &&
num_splits == rhs_casted.num_splits &&
op_mode == rhs_casted.op_mode;
}
};
} // namespace cldnn

View File

@ -39,5 +39,16 @@ struct ctc_greedy_decoder : public primitive_base<ctc_greedy_decoder> {
seed = hash_combine(seed, second_output.empty());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const ctc_greedy_decoder>(rhs);
return blank_index == rhs_casted.blank_index &&
ctc_merge_repeated == rhs_casted.ctc_merge_repeated &&
second_output.empty() == rhs_casted.second_output.empty();
}
};
} // namespace cldnn

View File

@ -41,6 +41,17 @@ struct ctc_loss : primitive_base<ctc_loss> {
seed = hash_combine(seed, unique);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const ctc_loss>(rhs);
return preprocess_collapse_repeated == rhs_casted.preprocess_collapse_repeated &&
ctc_merge_repeated == rhs_casted.ctc_merge_repeated &&
unique == rhs_casted.unique;
}
};
} // namespace cldnn

View File

@ -40,5 +40,16 @@ struct cum_sum : public primitive_base<cum_sum> {
seed = hash_combine(seed, reverse);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const cum_sum>(rhs);
return axis == rhs_casted.axis &&
exclusive == rhs_casted.exclusive &&
reverse == rhs_casted.reverse;
}
};
} // namespace cldnn

View File

@ -80,5 +80,16 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
seed = hash_combine(seed, kernels_code.size());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const custom_gpu_primitive>(rhs);
return kernel_entry_point == rhs_casted.kernel_entry_point &&
build_options == rhs_casted.build_options &&
kernels_code.size() == rhs_casted.kernels_code.size();
}
};
} // namespace cldnn

View File

@ -387,6 +387,27 @@ struct deconvolution : public primitive_base<deconvolution> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const deconvolution>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(pad) &&
cmp_fields(stride) &&
cmp_fields(dilations) &&
cmp_fields(groups) &&
cmp_fields(pads_begin) &&
cmp_fields(pads_end) &&
cmp_fields(out_padding) &&
cmp_fields(grouped_weights_shape) &&
cmp_fields(weights.size()) &&
cmp_fields(bias.size()) &&
cmp_fields(output_shape_id.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -45,5 +45,15 @@ struct depth_to_space : public primitive_base<depth_to_space> {
seed = hash_combine(seed, mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const depth_to_space>(rhs);
return block_size == rhs_casted.block_size &&
mode == rhs_casted.mode;
}
};
} // namespace cldnn

View File

@ -143,6 +143,34 @@ struct detection_output : public primitive_base<detection_output> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const detection_output>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(num_classes) &&
cmp_fields(keep_top_k) &&
cmp_fields(share_location) &&
cmp_fields(background_label_id) &&
cmp_fields(nms_threshold) &&
cmp_fields(top_k) &&
cmp_fields(eta) &&
cmp_fields(code_type) &&
cmp_fields(variance_encoded_in_target) &&
cmp_fields(confidence_threshold) &&
cmp_fields(prior_info_size) &&
cmp_fields(prior_coordinates_offset) &&
cmp_fields(prior_is_normalized) &&
cmp_fields(input_width) &&
cmp_fields(input_height) &&
cmp_fields(decrease_label_id) &&
cmp_fields(clip_before_nms) &&
cmp_fields(clip_after_nms);
#undef cmp_fields
}
protected:
};

View File

@ -64,6 +64,18 @@ struct dft : public primitive_base<dft> {
seed = hash_combine(seed, mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const dft>(rhs);
return axes == rhs_casted.axes &&
signal_size == rhs_casted.signal_size &&
direction == rhs_casted.direction &&
mode == rhs_casted.mode;
}
};
} // namespace cldnn

View File

@ -177,5 +177,17 @@ struct eltwise : public primitive_base<eltwise> {
}
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const eltwise>(rhs);
return mode == rhs_casted.mode &&
coefficients == rhs_casted.coefficients &&
broadcast_spec == rhs_casted.broadcast_spec &&
stride == rhs_casted.stride;
}
};
} // namespace cldnn

View File

@ -45,5 +45,15 @@ struct embedding_bag : public primitive_base<embedding_bag> {
seed = hash_combine(seed, default_index);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const embedding_bag>(rhs);
return type == rhs_casted.type &&
default_index == rhs_casted.default_index;
}
};
} // namespace cldnn

View File

@ -86,6 +86,26 @@ struct experimental_detectron_detection_output : public primitive_base<experimen
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const experimental_detectron_detection_output>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(score_threshold) &&
cmp_fields(nms_threshold) &&
cmp_fields(num_classes) &&
cmp_fields(post_nms_count) &&
cmp_fields(max_detections_per_image) &&
cmp_fields(class_agnostic_box_regression) &&
cmp_fields(max_delta_log_wh) &&
cmp_fields(deltas_weights) &&
cmp_fields(output_classes.empty()) &&
cmp_fields(output_scores.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -58,6 +58,19 @@ struct experimental_detectron_generate_proposals_single_image
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const experimental_detectron_generate_proposals_single_image>(rhs);
return min_size == rhs_casted.min_size &&
nms_threshold == rhs_casted.nms_threshold &&
pre_nms_count == rhs_casted.pre_nms_count &&
post_nms_count == rhs_casted.post_nms_count &&
output_roi_scores.empty() == rhs_casted.output_roi_scores.empty();
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -60,6 +60,23 @@ struct experimental_detectron_prior_grid_generator
seed = hash_combine(seed, image_width);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const experimental_detectron_prior_grid_generator>(rhs);
return flatten == rhs_casted.flatten &&
h == rhs_casted.h &&
w == rhs_casted.w &&
stride_x == rhs_casted.stride_x &&
stride_y == rhs_casted.stride_y &&
featmap_height == rhs_casted.featmap_height &&
featmap_width == rhs_casted.featmap_width &&
image_height == rhs_casted.image_height &&
image_width == rhs_casted.image_width;
}
};
} // namespace cldnn

View File

@ -51,6 +51,20 @@ struct experimental_detectron_roi_feature_extractor : public primitive_base<expe
seed = hash_combine(seed, aligned);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const experimental_detectron_roi_feature_extractor>(rhs);
return output_dim == rhs_casted.output_dim &&
pooled_height == rhs_casted.pooled_height &&
pooled_width == rhs_casted.pooled_width &&
pyramid_scales == rhs_casted.pyramid_scales &&
sampling_ratio == rhs_casted.sampling_ratio &&
aligned == rhs_casted.aligned;
}
};
} // namespace cldnn

View File

@ -37,6 +37,15 @@ struct experimental_detectron_topk_rois : public primitive_base<experimental_det
seed = hash_combine(seed, max_rois);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const experimental_detectron_topk_rois>(rhs);
return max_rois == rhs_casted.max_rois;
}
};
} // namespace cldnn

View File

@ -63,5 +63,17 @@ struct extract_image_patches : public primitive_base<extract_image_patches> {
seed = hash_combine(seed, auto_pad);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const extract_image_patches>(rhs);
return sizes == rhs_casted.sizes &&
strides == rhs_casted.strides &&
rates == rhs_casted.rates &&
auto_pad == rhs_casted.auto_pad;
}
};
} // namespace cldnn

View File

@ -36,5 +36,14 @@ struct eye : public primitive_base<eye> {
seed = hash_combine(seed, shift);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const eye>(rhs);
return shift == rhs_casted.shift;
}
};
} // namespace cldnn

View File

@ -82,6 +82,16 @@ struct fully_connected : public primitive_base<fully_connected> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const fully_connected>(rhs);
return input_size == rhs_casted.input_size &&
bias.empty() == rhs_casted.bias.empty();
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -52,5 +52,16 @@ struct gather : public primitive_base<gather> {
seed = hash_combine(seed, support_neg_ind);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const gather>(rhs);
return axis == rhs_casted.axis &&
batch_dim == rhs_casted.batch_dim &&
support_neg_ind == rhs_casted.support_neg_ind;
}
};
} // namespace cldnn

View File

@ -49,5 +49,15 @@ struct gather_elements : public primitive_base<gather_elements> {
seed = hash_combine(seed, axis);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const gather_elements>(rhs);
return output_format == rhs_casted.output_format &&
axis == rhs_casted.axis;
}
};
} // namespace cldnn

View File

@ -57,5 +57,17 @@ struct gather_nd : public primitive_base<gather_nd> {
seed = hash_combine(seed, batch_merged_output);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const gather_nd>(rhs);
return input_rank == rhs_casted.input_rank &&
indices_rank == rhs_casted.indices_rank &&
batch_dims == rhs_casted.batch_dims &&
batch_merged_output == rhs_casted.batch_merged_output;
}
};
} // namespace cldnn

View File

@ -13,22 +13,26 @@ namespace cldnn {
struct gather_tree : public primitive_base<gather_tree> {
CLDNN_DECLARE_PRIMITIVE(gather_tree)
/// @brief Constructs gather tree primitive / layer.
///
/// @param id An identifier of new primitive.
/// @param step_input An identifier of primitive which is an step input
/// @param parent_input An identifier of primitive which is an parent input
/// @param step_seq_len_input An identifier of primitive which is an input that contains
/// lengths of step sequence (per batch) to perform
/// @param end_token An identifier of primitive which is an input that contains
/// a value of the end_token
/// @param output_padding Optional padding for output from primitive
gather_tree(const primitive_id& id,
const input_info& step_input,
const input_info& parent_input,
const input_info& max_seq_len_input,
const input_info& end_token,
const padding& output_padding = padding())
: primitive_base(id, { step_input, parent_input, max_seq_len_input, end_token }, {output_padding}) {}
/// @brief Constructs gather tree primitive / layer.
///
/// @param id An identifier of new primitive.
/// @param step_input An identifier of primitive which is an step input
/// @param parent_input An identifier of primitive which is an parent input
/// @param step_seq_len_input An identifier of primitive which is an input that contains
/// lengths of step sequence (per batch) to perform
/// @param end_token An identifier of primitive which is an input that contains
/// a value of the end_token
/// @param output_padding Optional padding for output from primitive
gather_tree(const primitive_id& id,
const input_info& step_input,
const input_info& parent_input,
const input_info& max_seq_len_input,
const input_info& end_token,
const padding& output_padding = padding())
: primitive_base(id, { step_input, parent_input, max_seq_len_input, end_token }, {output_padding}) {}
bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
};
} // namespace cldnn

View File

@ -76,6 +76,20 @@ struct gemm : public primitive_base<gemm> {
seed = hash_combine(seed, beta);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const gemm>(rhs);
return transpose_input0 == rhs_casted.transpose_input0 &&
transpose_input1 == rhs_casted.transpose_input1 &&
alpha == rhs_casted.alpha &&
beta == rhs_casted.beta &&
input_rank == rhs_casted.input_rank &&
weight_rank == rhs_casted.weight_rank;
}
};
} // namespace cldnn

View File

@ -73,6 +73,25 @@ struct generate_proposals
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const generate_proposals>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(min_size) &&
cmp_fields(nms_threshold) &&
cmp_fields(pre_nms_count) &&
cmp_fields(post_nms_count) &&
cmp_fields(normalized) &&
cmp_fields(nms_eta) &&
cmp_fields(roi_num_type) &&
cmp_fields(output_rois_scores.empty()) &&
cmp_fields(output_rois_num.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -38,6 +38,17 @@ struct grid_sample : primitive_base<grid_sample> {
seed = hash_combine(seed, attributes.padding_mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const grid_sample>(rhs);
return attributes.align_corners == rhs_casted.attributes.align_corners &&
attributes.mode == rhs_casted.attributes.mode &&
attributes.padding_mode == rhs_casted.attributes.padding_mode;
}
};
} // namespace cldnn

View File

@ -32,5 +32,14 @@ struct grn : public primitive_base<grn> {
seed = hash_combine(seed, bias);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const grn>(rhs);
return bias == rhs_casted.bias;
}
};
} // namespace cldnn

View File

@ -70,5 +70,18 @@ struct lrn : public primitive_base<lrn> {
seed = hash_combine(seed, norm_region);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lrn>(rhs);
return size == rhs_casted.size &&
k == rhs_casted.k &&
alpha == rhs_casted.alpha &&
beta == rhs_casted.beta &&
norm_region == rhs_casted.norm_region;
}
};
} // namespace cldnn

View File

@ -143,6 +143,32 @@ struct lstm : public primitive_base<lstm> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lstm>(rhs);
bool act_params_eq = activation_params.size() == rhs_casted.activation_params.size();
for (size_t i = 0; i < activation_params.size(); ++i) {
act_params_eq &= activation_params[i].a == rhs_casted.activation_params[i].a &&
activation_params[i].b == rhs_casted.activation_params[i].b;
}
#define cmp_fields(name) name == rhs_casted.name
return act_params_eq &&
cmp_fields(clip) &&
cmp_fields(input_forget) &&
cmp_fields(activations) &&
cmp_fields(output_selection) &&
cmp_fields(offset_order) &&
cmp_fields(initial_hidden.empty()) &&
cmp_fields(initial_cell.empty()) &&
cmp_fields(peepholes.empty()) &&
cmp_fields(bias.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
@ -205,6 +231,17 @@ struct lstm_gemm : public primitive_base<lstm_gemm> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lstm_gemm>(rhs);
return direction == rhs_casted.direction &&
bias.empty() == rhs_casted.bias.empty() &&
hidden.empty() == rhs_casted.hidden.empty();
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
@ -282,6 +319,29 @@ struct lstm_elt : public primitive_base<lstm_elt> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lstm_elt>(rhs);
bool act_params_eq = activation_params.size() == rhs_casted.activation_params.size();
for (size_t i = 0; i < activation_params.size(); ++i) {
act_params_eq &= activation_params[i].a == rhs_casted.activation_params[i].a &&
activation_params[i].b == rhs_casted.activation_params[i].b;
}
#define cmp_fields(name) name == rhs_casted.name
return act_params_eq &&
cmp_fields(clip) &&
cmp_fields(input_forget) &&
cmp_fields(activations) &&
cmp_fields(offset_order) &&
cmp_fields(direction) &&
cmp_fields(cell.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -91,6 +91,23 @@ struct lstm_dynamic : public primitive_base<lstm_dynamic> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lstm_dynamic>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(clip) &&
cmp_fields(input_forget) &&
cmp_fields(last_hidden_state.empty()) &&
cmp_fields(last_cell_state.empty()) &&
cmp_fields(initial_hidden.empty()) &&
cmp_fields(initial_cell.empty()) &&
cmp_fields(bias.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -48,6 +48,15 @@ struct lstm_dynamic_input : public primitive_base<lstm_dynamic_input> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lstm_dynamic_input>(rhs);
return bias.empty() == rhs_casted.bias.empty();
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -79,6 +79,22 @@ struct lstm_dynamic_timeloop
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const lstm_dynamic_timeloop>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(clip) &&
cmp_fields(input_forget) &&
cmp_fields(last_hidden_state.empty()) &&
cmp_fields(last_cell_state.empty()) &&
cmp_fields(initial_hidden.empty()) &&
cmp_fields(initial_cell.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -133,6 +133,26 @@ struct matrix_nms : public primitive_base<matrix_nms> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const matrix_nms>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(attribs.sort_type) &&
cmp_fields(attribs.sort_result_across_batch) &&
cmp_fields(attribs.score_threshold) &&
cmp_fields(attribs.nms_top_k) &&
cmp_fields(attribs.keep_top_k) &&
cmp_fields(attribs.background_class) &&
cmp_fields(attribs.decay) &&
cmp_fields(attribs.gaussian_sigma) &&
cmp_fields(attribs.post_threshold) &&
cmp_fields(attribs.normalized);
#undef cmp_fields
}
private:
static cldnn::matrix_nms::decay_function from(ngraph::op::v8::MatrixNms::DecayFunction decay) {
switch (decay) {

View File

@ -141,6 +141,27 @@ struct multiclass_nms : public primitive_base<multiclass_nms> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const multiclass_nms>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(has_roisnum) &&
cmp_fields(attrs.background_class) &&
cmp_fields(attrs.indices_output_type) &&
cmp_fields(attrs.iou_threshold) &&
cmp_fields(attrs.keep_top_k) &&
cmp_fields(attrs.nms_eta) &&
cmp_fields(attrs.nms_top_k) &&
cmp_fields(attrs.normalized) &&
cmp_fields(attrs.score_threshold) &&
cmp_fields(attrs.sort_result) &&
cmp_fields(attrs.sort_result_across_batch);
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -49,5 +49,17 @@ struct mvn : public primitive_base<mvn> {
seed = hash_combine(seed, across_channels);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const mvn>(rhs);
return normalize_variance == rhs_casted.normalize_variance &&
epsilon == rhs_casted.epsilon &&
eps_inside_sqrt == rhs_casted.eps_inside_sqrt &&
across_channels == rhs_casted.across_channels;
}
};
} // namespace cldnn

View File

@ -81,6 +81,25 @@ struct non_max_suppression : public primitive_base<non_max_suppression> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const non_max_suppression>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(selected_indices_num) &&
cmp_fields(center_point_box) &&
cmp_fields(sort_result_descending) &&
cmp_fields(num_select_per_class.empty()) &&
cmp_fields(iou_threshold.empty()) &&
cmp_fields(score_threshold.empty()) &&
cmp_fields(soft_nms_sigma.empty()) &&
cmp_fields(second_output.empty()) &&
cmp_fields(third_output.empty());
#undef cmp_fields
}
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;
if (!num_select_per_class.empty())

View File

@ -18,6 +18,10 @@ struct count_nonzero : public primitive_base<count_nonzero> {
const input_info& data,
const padding& output_padding = padding())
: primitive_base(id, {data}, {output_padding}) {}
bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
};
struct gather_nonzero : public primitive_base<gather_nonzero> {
@ -32,6 +36,10 @@ struct gather_nonzero : public primitive_base<gather_nonzero> {
const input_info& output_shape,
const padding& output_padding = padding())
: primitive_base(id, {data, output_shape}, {output_padding}) {}
bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
};
} // namespace cldnn

View File

@ -62,6 +62,16 @@ struct normalize : public primitive_base<normalize> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const normalize>(rhs);
return across_spatial == rhs_casted.across_spatial &&
epsilon == rhs_casted.epsilon;
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override { return {scale_input}; }
};

View File

@ -95,5 +95,17 @@ struct one_hot : public primitive_base<one_hot> {
seed = hash_combine(seed, off_value);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const one_hot>(rhs);
return one_hot_axis == rhs_casted.one_hot_axis &&
depth == rhs_casted.depth &&
on_value == rhs_casted.on_value &&
off_value == rhs_casted.off_value;
}
};
} // namespace cldnn

View File

@ -37,5 +37,14 @@ struct permute : public primitive_base<permute> {
seed = hash_range(seed, permute_order.begin(), permute_order.end());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const permute>(rhs);
return permute_order == rhs_casted.permute_order;
}
};
} // namespace cldnn

View File

@ -176,6 +176,28 @@ struct pooling : public primitive_base<pooling> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const pooling>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(mode) &&
cmp_fields(size) &&
cmp_fields(stride) &&
cmp_fields(dilation) &&
cmp_fields(pads_begin) &&
cmp_fields(pads_end) &&
cmp_fields(auto_pad) &&
cmp_fields(rounding_type) &&
cmp_fields(axis) &&
cmp_fields(index_element_type) &&
cmp_fields(maxPoolOpset8Features) &&
cmp_fields(indices_output.empty());
#undef cmp_fields
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
std::vector<std::reference_wrapper<const primitive_id>> ret;

View File

@ -103,7 +103,40 @@ public:
return seed;
}
/// @brief Implicit conversion to primiitive id.
bool compare_common_params(const primitive& rhs) const {
if (type != rhs.type)
return false;
if (num_outputs != rhs.num_outputs)
return false;
if (dependencies().size() != rhs.dependencies().size())
return false;
if (output_data_types.size() != rhs.output_data_types.size())
return false;
for (size_t i = 0; i < output_data_types.size(); ++i) {
if (output_data_types[i].value_or(data_types::bin) != rhs.output_data_types[i].value_or(data_types::bin))
return false;
}
if (output_paddings.size() != rhs.output_paddings.size())
return false;
for (size_t i = 0; i < output_paddings.size(); ++i) {
if (output_paddings[i] != rhs.output_paddings[i])
return false;
}
return true;
}
virtual bool operator==(const primitive& rhs) const { return false; }
bool operator!=(const primitive& rhs) const { return !(*this == rhs); }
/// @brief Implicit conversion to primitive id.
operator primitive_id() const { return id; }
/// @brief Primitive's type id.

View File

@ -205,6 +205,36 @@ struct prior_box : public primitive_base<prior_box> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const prior_box>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(img_size) &&
cmp_fields(min_sizes) &&
cmp_fields(max_sizes) &&
cmp_fields(aspect_ratios) &&
cmp_fields(flip) &&
cmp_fields(clip) &&
cmp_fields(variance) &&
cmp_fields(step_width) &&
cmp_fields(step_height) &&
cmp_fields(offset) &&
cmp_fields(scale_all_sizes) &&
cmp_fields(fixed_ratio) &&
cmp_fields(fixed_size) &&
cmp_fields(density) &&
cmp_fields(support_opset8) &&
cmp_fields(step) &&
cmp_fields(min_max_aspect_ratios_order) &&
cmp_fields(widths) &&
cmp_fields(heights) &&
cmp_fields(clustered);
#undef cmp_fields
}
private:
bool clustered;

View File

@ -186,6 +186,36 @@ struct proposal : public primitive_base<proposal> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const proposal>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(max_proposals) &&
cmp_fields(iou_threshold) &&
cmp_fields(base_bbox_size) &&
cmp_fields(min_bbox_size) &&
cmp_fields(feature_stride) &&
cmp_fields(pre_nms_topn) &&
cmp_fields(post_nms_topn) &&
cmp_fields(ratios) &&
cmp_fields(scales) &&
cmp_fields(coordinates_offset) &&
cmp_fields(box_coordinate_scale) &&
cmp_fields(box_size_scale) &&
cmp_fields(for_deformable) &&
cmp_fields(swap_xy) &&
cmp_fields(initial_clip) &&
cmp_fields(clip_before_nms) &&
cmp_fields(clip_after_nms) &&
cmp_fields(round_ratios) &&
cmp_fields(shift_anchors) &&
cmp_fields(normalize);
#undef cmp_fields
}
void save(BinaryOutputBuffer& ob) const override {
ob << max_proposals;
ob << iou_threshold;

View File

@ -66,5 +66,17 @@ struct pyramid_roi_align : public primitive_base<pyramid_roi_align> {
seed = hash_combine(seed, pyramid_starting_level);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const pyramid_roi_align>(rhs);
return output_size == rhs_casted.output_size &&
sampling_ratio == rhs_casted.sampling_ratio &&
pyramid_scales == rhs_casted.pyramid_scales &&
pyramid_starting_level == rhs_casted.pyramid_starting_level;
}
};
} // namespace cldnn

View File

@ -36,5 +36,14 @@ struct quantize : public primitive_base<quantize> {
seed = cldnn::hash_combine(seed, levels);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const quantize>(rhs);
return levels == rhs_casted.levels;
}
};
} // namespace cldnn

View File

@ -49,6 +49,16 @@ struct random_uniform : public primitive_base<random_uniform> {
seed = hash_combine(seed, op_seed);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const random_uniform>(rhs);
return global_seed == rhs_casted.global_seed &&
op_seed == rhs_casted.op_seed;
}
};
} // namespace cldnn

View File

@ -29,5 +29,9 @@ struct range: public primitive_base<range> {
/// @brief requested range output layout
layout output_layout;
bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
};
} // namespace cldnn

View File

@ -30,5 +30,14 @@ struct read_value : public primitive_base<read_value> {
std::string variable_id;
layout output_layout;
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const read_value>(rhs);
return variable_id == rhs_casted.variable_id;
}
};
} // namespace cldnn

View File

@ -68,5 +68,16 @@ struct reduce : public primitive_base<reduce> {
seed = hash_combine(seed, keep_dims);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const reduce>(rhs);
return mode == rhs_casted.mode &&
axes == rhs_casted.axes &&
keep_dims == rhs_casted.keep_dims;
}
};
} // namespace cldnn

View File

@ -51,6 +51,19 @@ struct region_yolo : public primitive_base<region_yolo> {
seed = hash_combine(seed, do_softmax);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const region_yolo>(rhs);
return coords == rhs_casted.coords &&
classes == rhs_casted.classes &&
num == rhs_casted.num &&
mask_size == rhs_casted.mask_size &&
do_softmax == rhs_casted.do_softmax;
}
};
} // namespace cldnn
#pragma once

View File

@ -165,6 +165,19 @@ struct reorder : public primitive_base<reorder> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const reorder>(rhs);
return subtract_per_feature == rhs_casted.subtract_per_feature &&
mean_mode == rhs_casted.mean_mode &&
input_mem_type == rhs_casted.input_mem_type &&
truncate == rhs_casted.truncate &&
mean.empty() == rhs_casted.mean.empty();
}
protected:
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
if (mean.empty())

View File

@ -34,6 +34,15 @@ struct reorg_yolo : public primitive_base<reorg_yolo> {
seed = hash_combine(seed, stride);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const reorg_yolo>(rhs);
return stride == rhs_casted.stride;
}
};
} // namespace cldnn
#pragma once

View File

@ -165,5 +165,27 @@ struct resample : public primitive_base<resample> {
seed = hash_combine(seed, round_mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const resample>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(num_filter) &&
cmp_fields(sizes) &&
cmp_fields(scales) &&
cmp_fields(axes) &&
cmp_fields(pads_begin) &&
cmp_fields(pads_end) &&
cmp_fields(operation_type) &&
cmp_fields(shape_calc_mode) &&
cmp_fields(antialias) &&
cmp_fields(cube_coeff) &&
cmp_fields(coord_trans_mode) &&
cmp_fields(round_mode);
#undef cmp_fields
}
};
} // namespace cldnn

View File

@ -79,6 +79,16 @@ struct reshape : public primitive_base<reshape> {
ov::PartialShape output_partial_shape;
reshape_mode mode;
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const reshape>(rhs);
return special_zero == rhs_casted.special_zero &&
mode == rhs_casted.mode;
}
};
} // namespace cldnn

View File

@ -33,5 +33,14 @@ struct reverse : public primitive_base<reverse> {
seed = hash_combine(seed, mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const reverse>(rhs);
return mode == rhs_casted.mode;
}
};
} // namespace cldnn

View File

@ -58,5 +58,15 @@ struct reverse_sequence : public primitive_base<reverse_sequence> {
seed = hash_combine(seed, batch_axis);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const reverse_sequence>(rhs);
return seq_axis == rhs_casted.seq_axis &&
batch_axis == rhs_casted.batch_axis;
}
};
} // namespace cldnn

View File

@ -68,5 +68,19 @@ struct roi_align : public primitive_base<roi_align> {
seed = hash_combine(seed, aligned_mode);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const roi_align>(rhs);
return pooled_h == rhs_casted.pooled_h &&
pooled_w == rhs_casted.pooled_w &&
sampling_ratio == rhs_casted.sampling_ratio &&
spatial_scale == rhs_casted.spatial_scale &&
pooling_mode == rhs_casted.pooling_mode &&
aligned_mode == rhs_casted.aligned_mode;
}
};
} // namespace cldnn

View File

@ -96,6 +96,28 @@ struct roi_pooling : public primitive_base<roi_pooling> {
seed = hash_combine(seed, spatial_bins_y);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const roi_pooling>(rhs);
#define cmp_fields(name) name == rhs_casted.name
return cmp_fields(mode) &&
cmp_fields(position_sensitive) &&
cmp_fields(pooled_width) &&
cmp_fields(pooled_height) &&
cmp_fields(spatial_scale) &&
cmp_fields(trans_std) &&
cmp_fields(no_trans) &&
cmp_fields(output_dim) &&
cmp_fields(part_size) &&
cmp_fields(group_size) &&
cmp_fields(spatial_bins_x) &&
cmp_fields(spatial_bins_y);
#undef cmp_fields
}
};
} // namespace cldnn

View File

@ -32,6 +32,15 @@ struct roll : primitive_base<roll> {
seed = hash_combine(seed, shift.hash());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const roll>(rhs);
return shift == rhs_casted.shift;
}
};
} // namespace cldnn

View File

@ -34,5 +34,14 @@ struct scatter_elements_update : public primitive_base<scatter_elements_update>
seed = hash_combine(seed, axis);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const scatter_elements_update>(rhs);
return axis == rhs_casted.axis;
}
};
} // namespace cldnn

View File

@ -34,5 +34,14 @@ struct scatter_nd_update : public primitive_base<scatter_nd_update> {
seed = hash_combine(seed, indices_rank);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const scatter_nd_update>(rhs);
return indices_rank == rhs_casted.indices_rank;
}
};
} // namespace cldnn

View File

@ -43,5 +43,14 @@ struct scatter_update : public primitive_base<scatter_update> {
seed = hash_combine(seed, axis);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const scatter_update>(rhs);
return axis == rhs_casted.axis;
}
};
} // namespace cldnn

View File

@ -40,5 +40,14 @@ struct select : public primitive_base<select> {
/// @brief Define auto broadcast rule specification.
ov::op::AutoBroadcastSpec broadcast_spec;
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const select>(rhs);
return broadcast_spec == rhs_casted.broadcast_spec;
}
};
} // namespace cldnn

View File

@ -36,5 +36,14 @@ struct shape_of : public primitive_base<shape_of> {
, output_rank(0) {}
size_t output_rank;
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const shape_of>(rhs);
return output_rank == rhs_casted.output_rank;
}
};
} // namespace cldnn

View File

@ -36,5 +36,15 @@ struct shuffle_channels : public primitive_base<shuffle_channels> {
seed = hash_combine(seed, axis);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const shuffle_channels>(rhs);
return group == rhs_casted.group &&
axis == rhs_casted.axis;
}
};
} // namespace cldnn

View File

@ -16,13 +16,17 @@ struct slice : public primitive_base<slice> {
/// @param id This primitive id.
/// @param inputs List of primitive ids.
slice(const primitive_id& id,
const std::vector<input_info>& inputs,
const tensor output_shape,
const padding& output_padding = padding())
const std::vector<input_info>& inputs,
const tensor output_shape,
const padding& output_padding = padding())
: primitive_base{id, inputs, {output_padding}},
output_shape {output_shape}
{}
tensor output_shape;
bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
};
} // namespace cldnn

View File

@ -44,5 +44,14 @@ struct softmax : public primitive_base<softmax> {
seed = hash_combine(seed, dimension);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const softmax>(rhs);
return dimension == rhs_casted.dimension;
}
};
} // namespace cldnn

View File

@ -68,5 +68,16 @@ struct space_to_batch : public primitive_base<space_to_batch> {
seed = hash_combine(seed, pads_end.hash());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const space_to_batch>(rhs);
return block_shape == rhs_casted.block_shape &&
pads_begin == rhs_casted.pads_begin &&
pads_end == rhs_casted.pads_end;
}
};
} // namespace cldnn

View File

@ -71,5 +71,15 @@ struct space_to_depth : public primitive_base<space_to_depth> {
seed = hash_combine(seed, block_size);
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const space_to_depth>(rhs);
return mode == rhs_casted.mode &&
block_size == rhs_casted.block_size;
}
};
} // namespace cldnn

View File

@ -57,6 +57,15 @@ struct split : public primitive_base<split> {
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const split>(rhs);
return output_offsets == rhs_casted.output_offsets;
}
protected:
static std::vector<primitive_id> extract_primitive_vector(
const std::vector<std::pair<primitive_id, tensor> >& stor) {

View File

@ -111,5 +111,21 @@ struct strided_slice : public primitive_base<strided_slice> {
seed = hash_range(seed, shrink_axis_mask.begin(), shrink_axis_mask.end());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const strided_slice>(rhs);
return begin == rhs_casted.begin &&
end == rhs_casted.end &&
strides == rhs_casted.strides &&
begin_mask == rhs_casted.begin_mask &&
end_mask == rhs_casted.end_mask &&
new_axis_mask == rhs_casted.new_axis_mask &&
shrink_axis_mask == rhs_casted.shrink_axis_mask &&
ellipsis_mask == rhs_casted.ellipsis_mask;
}
};
} // namespace cldnn

View File

@ -38,5 +38,14 @@ struct tile : public primitive_base<tile> {
seed = hash_range(seed, repeats.begin(), repeats.end());
return seed;
}
bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;
auto rhs_casted = downcast<const tile>(rhs);
return repeats == rhs_casted.repeats;
}
};
} // namespace cldnn

View File

@ -140,7 +140,7 @@ inline derived_type& downcast(base_type& base) {
} catch (std::bad_cast& /* ex */) {
throw std::runtime_error("Unable to cast reference from base to derived type");
}
throw std::runtime_error("downcast failed with unhadnled exception");
throw std::runtime_error("downcast failed with unhandled exception");
}
template <typename T>

View File

@ -0,0 +1,111 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils.h"
#include <intel_gpu/primitives/range.hpp>
#include <intel_gpu/primitives/convolution.hpp>
#include <intel_gpu/primitives/gemm.hpp>
#include <intel_gpu/primitives/fully_connected.hpp>
#include <intel_gpu/primitives/gather.hpp>
#include <intel_gpu/primitives/permute.hpp>
using namespace cldnn;
using namespace ::tests;
TEST(primitive_comparison, common_params) {
auto def_inputs = {input_info("input0"), input_info("input1"), input_info("input2")};
auto def_shape = ov::PartialShape{1, 2, 3, 4};
auto def_data_type = data_types::f32;
auto def_format = format::bfyx;
auto def_padding = padding({1, 1, 1, 1});
auto fc_prim = fully_connected("fc", input_info("input"), "weights");
auto range_prim = range("range", def_inputs, layout{def_shape, def_data_type, def_format, def_padding});
auto range_prim_inputs = range("range", {input_info("input0"), input_info("input1")}, layout{def_shape, def_data_type, def_format, def_padding});
auto range_prim_data_type = range("range", def_inputs, layout{def_shape, data_types::f16, def_format, def_padding});
auto range_prim_padding_values = range("range", def_inputs, layout{def_shape, def_data_type, def_format, padding({1, 1, 1, 2})});
auto range_prim_padding_fill_value = range("range", def_inputs, layout{def_shape, def_data_type, def_format, padding({1, 1, 1, 1}, 1.f)});
ASSERT_NE(range_prim, fc_prim);
ASSERT_NE(range_prim, range_prim_inputs);
ASSERT_NE(range_prim, range_prim_data_type);
ASSERT_NE(range_prim, range_prim_padding_values);
ASSERT_NE(range_prim, range_prim_padding_fill_value);
}
TEST(primitive_comparison, convolution) {
auto conv_prim = convolution("conv", input_info("input"), {"weights"}, {"bias"}, 1,
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
auto conv_prim_eq = convolution("conv_eq", input_info("input_eq"), {"weights_eq"}, {"bias_eq"}, 1,
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
auto conv_prim_stride = convolution("conv", input_info("input"), {"weights"}, {"bias"}, 1,
{1, 1}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
auto conv_prim_no_bias = convolution("conv", input_info("input"), {"weights"}, {}, 1,
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
auto conv_prim_grouped = convolution("conv", input_info("input"), {"weights"}, {"bias"}, 2,
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, true);
ASSERT_EQ(conv_prim, conv_prim_eq);
ASSERT_NE(conv_prim, conv_prim_stride);
ASSERT_NE(conv_prim, conv_prim_no_bias);
ASSERT_NE(conv_prim, conv_prim_grouped);
}
TEST(primitive_comparison, gemm) {
auto def_inputs = {input_info("input0"), input_info("input1")};
auto gemm_prim = gemm("gemm", def_inputs, data_types::f32);
auto gemm_prim_eq = gemm("gemm_eq", {input_info("input0_eq"), input_info("input1_eq")}, data_types::f32);
auto gemm_prim_rank = gemm("gemm", def_inputs, data_types::f32, false, false, 1.0f, 0.0f, 2, 2);
auto gemm_prim_alpha = gemm("gemm", def_inputs, data_types::f32, false, false, 1.5f);
auto gemm_prim_transpose = gemm("gemm", def_inputs, data_types::f32, true);
ASSERT_EQ(gemm_prim, gemm_prim_eq);
ASSERT_NE(gemm_prim, gemm_prim_rank);
ASSERT_NE(gemm_prim, gemm_prim_alpha);
ASSERT_NE(gemm_prim, gemm_prim_transpose);
}
TEST(primitive_comparison, fully_connected) {
auto fc_prim = fully_connected("fc", input_info("input"), "weights", "bias", {}, 2);
auto fc_prim_eq = fully_connected("fc_eq", input_info("input_eq"), "weights_eq", "bias_eq", {}, 2);
auto fc_prim_bias = fully_connected("fc", input_info("input"), "weights", "", {}, 2);
auto fc_prim_input_size = fully_connected("fc", input_info("input"), "weights", "bias", {}, 4);
ASSERT_EQ(fc_prim, fc_prim_eq);
ASSERT_NE(fc_prim, fc_prim_bias);
ASSERT_NE(fc_prim, fc_prim_input_size);
}
TEST(primitive_comparison, gather) {
auto gather_prim = gather("gather", input_info("input0"), input_info("input1"), 2, {1, 3, 224, 224}, 1, true);
auto gather_prim_eq = gather("gather_eq", input_info("input0_eq"), input_info("input1_eq"), 2, {1, 3, 224, 224}, 1, true);
auto gather_prim_axis = gather("gather", input_info("input0"), input_info("input1"), 3, {1, 3, 224, 224}, 1, true);
auto gather_prim_batch_dim = gather("gather", input_info("input0"), input_info("input1"), 2, {1, 3, 224, 224}, 2, true);
auto gather_prim_support_neg_ind = gather("gather", input_info("input0"), input_info("input1"), 2, {1, 3, 224, 224}, 1, false);
ASSERT_EQ(gather_prim, gather_prim_eq);
ASSERT_NE(gather_prim, gather_prim_axis);
ASSERT_NE(gather_prim, gather_prim_batch_dim);
ASSERT_NE(gather_prim, gather_prim_support_neg_ind);
}
TEST(primitive_comparison, permute) {
auto permute_prim = permute("permute", input_info("input"), {0, 1, 2, 3});
auto permute_prim_eq = permute("permute_eq", input_info("input_eq"), {0, 1, 2, 3});
auto permute_prim_order = permute("permute", input_info("input"), {3, 2, 1, 0});
ASSERT_EQ(permute_prim, permute_prim_eq);
ASSERT_NE(permute_prim, permute_prim_order);
}