[GPU] Added operator== for cldnn primitives (#15736)
This commit is contained in:
parent
59542d5cd3
commit
efb51b058c
@ -125,6 +125,18 @@ struct activation : public primitive_base<activation> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const activation>(rhs);
|
||||
|
||||
return activation_function == rhs_casted.activation_function &&
|
||||
additional_params.a == rhs_casted.additional_params.a &&
|
||||
additional_params.b == rhs_casted.additional_params.b &&
|
||||
additional_params_input.empty() == rhs_casted.additional_params_input.empty();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
if (additional_params_input.empty())
|
||||
|
@ -58,6 +58,17 @@ struct adaptive_pooling : public primitive_base<adaptive_pooling> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const adaptive_pooling>(rhs);
|
||||
|
||||
return mode == rhs_casted.mode &&
|
||||
indices_output == rhs_casted.indices_output &&
|
||||
index_element_type == rhs_casted.index_element_type;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -82,5 +82,18 @@ struct arg_max_min : public primitive_base<arg_max_min> {
|
||||
seed = hash_combine(seed, values_first);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const arg_max_min>(rhs);
|
||||
|
||||
return mode == rhs_casted.mode &&
|
||||
top_k == rhs_casted.top_k &&
|
||||
axis == rhs_casted.axis &&
|
||||
sort == rhs_casted.sort &&
|
||||
values_first == rhs_casted.values_first;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -21,14 +21,23 @@ struct assign : public primitive_base<assign> {
|
||||
/// @param variable_id Variable id
|
||||
/// @param output_layout Memory layout
|
||||
assign(const primitive_id &id,
|
||||
const std::vector<input_info>& inputs,
|
||||
const std::string& variable_id,
|
||||
const layout& output_layout)
|
||||
: primitive_base(id, inputs, {padding()}, {optional_data_type{output_layout.data_type}}),
|
||||
variable_id{variable_id},
|
||||
output_layout{output_layout} {}
|
||||
const std::vector<input_info>& inputs,
|
||||
const std::string& variable_id,
|
||||
const layout& output_layout)
|
||||
: primitive_base(id, inputs, {padding()}, {optional_data_type{output_layout.data_type}}),
|
||||
variable_id{variable_id},
|
||||
output_layout{output_layout} {}
|
||||
|
||||
std::string variable_id;
|
||||
layout output_layout;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const assign>(rhs);
|
||||
|
||||
return variable_id == rhs_casted.variable_id;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -71,5 +71,16 @@ struct batch_to_space : public primitive_base<batch_to_space> {
|
||||
seed = hash_combine(seed, crops_end.hash());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const batch_to_space>(rhs);
|
||||
|
||||
return block_shape == rhs_casted.block_shape &&
|
||||
crops_begin == rhs_casted.crops_begin &&
|
||||
crops_end == rhs_casted.crops_end;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -76,6 +76,20 @@ struct binary_convolution : public primitive_base<binary_convolution> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const binary_convolution>(rhs);
|
||||
|
||||
return pad == rhs_casted.pad &&
|
||||
stride == rhs_casted.stride &&
|
||||
dilation == rhs_casted.dilation &&
|
||||
groups == rhs_casted.groups &&
|
||||
pad_value == rhs_casted.pad_value &&
|
||||
weights.size() == rhs_casted.weights.size();
|
||||
}
|
||||
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
ret.reserve(weights.size());
|
||||
|
@ -80,5 +80,17 @@ struct border : public primitive_base<border> {
|
||||
seed = hash_combine(seed, pad_value);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const border>(rhs);
|
||||
|
||||
return pads_begin == rhs_casted.pads_begin &&
|
||||
pads_end == rhs_casted.pads_end &&
|
||||
pad_mode == rhs_casted.pad_mode &&
|
||||
pad_value == rhs_casted.pad_value;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -135,5 +135,16 @@ struct broadcast : public primitive_base<broadcast> {
|
||||
seed = hash_range(seed, axes_mapping.begin(), axes_mapping.end());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const broadcast>(rhs);
|
||||
|
||||
return axes_mapping == rhs_casted.axes_mapping &&
|
||||
broadcast_mode == rhs_casted.broadcast_mode &&
|
||||
broadcast_sizes == rhs_casted.broadcast_sizes;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -31,6 +31,15 @@ struct bucketize : primitive_base<bucketize> {
|
||||
seed = hash_combine(seed, with_right_bound);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const bucketize>(rhs);
|
||||
|
||||
return with_right_bound == rhs_casted.with_right_bound;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -64,5 +64,14 @@ struct concatenation : public primitive_base<concatenation> {
|
||||
seed = hash_combine(seed, axis);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const concatenation>(rhs);
|
||||
|
||||
return axis == rhs_casted.axis;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -26,9 +26,9 @@ struct condition : public primitive_base<condition> {
|
||||
/// @param id An identifier of new primitive.
|
||||
/// @param input An identifier of primitive which is an input for newly created
|
||||
/// condition primitive.
|
||||
/// @param topology_true Topolgoy containg primitives, which will be executed when comparsion results
|
||||
/// @param topology_true Topology containg primitives, which will be executed when comparsion results
|
||||
/// true.
|
||||
/// @param topology_false Topolgoy containg primitives, which will be executed when comparsion results
|
||||
/// @param topology_false Topology containg primitives, which will be executed when comparsion results
|
||||
/// false..
|
||||
/// @param compare_Data An identifier of primitive which contains compare values
|
||||
/// @param func Used function during comparison.
|
||||
|
@ -58,5 +58,17 @@ struct convert_color : public primitive_base<convert_color> {
|
||||
seed = hash_combine(seed, mem_type);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const convert_color>(rhs);
|
||||
|
||||
return input_color_format == rhs_casted.input_color_format &&
|
||||
output_color_format == rhs_casted.output_color_format &&
|
||||
mem_type == rhs_casted.mem_type &&
|
||||
output_layout == rhs_casted.output_layout;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -781,6 +781,31 @@ struct convolution : public primitive_base<convolution> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const convolution>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(pad) &&
|
||||
cmp_fields(stride) &&
|
||||
cmp_fields(dilation) &&
|
||||
cmp_fields(groups) &&
|
||||
cmp_fields(deformable_groups) &&
|
||||
cmp_fields(padding_above) &&
|
||||
cmp_fields(padding_below) &&
|
||||
cmp_fields(deformable_mode) &&
|
||||
cmp_fields(bilinear_interpolation_pad) &&
|
||||
cmp_fields(grouped_weights_shape) &&
|
||||
cmp_fields(weights.size()) &&
|
||||
cmp_fields(bias.size()) &&
|
||||
cmp_fields(weights_zero_points.size()) &&
|
||||
cmp_fields(activations_zero_points.size()) &&
|
||||
cmp_fields(compensation.size());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
ret.reserve(weights.size() + bias.size() + weights_zero_points.size() +
|
||||
@ -858,6 +883,25 @@ struct deformable_interp : public primitive_base<deformable_interp> {
|
||||
seed = cldnn::hash_combine(seed, bilinear_interpolation_pad);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const deformable_interp>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(pad) &&
|
||||
cmp_fields(stride) &&
|
||||
cmp_fields(dilation) &&
|
||||
cmp_fields(kernel_size) &&
|
||||
cmp_fields(groups) &&
|
||||
cmp_fields(deformable_groups) &&
|
||||
cmp_fields(padding_above) &&
|
||||
cmp_fields(padding_below) &&
|
||||
cmp_fields(bilinear_interpolation_pad);
|
||||
#undef cmp_fields
|
||||
}
|
||||
};
|
||||
|
||||
struct deformable_conv : public primitive_base<deformable_conv> {
|
||||
@ -893,6 +937,17 @@ struct deformable_conv : public primitive_base<deformable_conv> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const deformable_conv>(rhs);
|
||||
|
||||
return groups == rhs_casted.groups &&
|
||||
weights.size() == rhs_casted.weights.size() &&
|
||||
bias.size() == rhs_casted.bias.size();
|
||||
}
|
||||
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
ret.reserve(weights.size() + bias.size());
|
||||
|
@ -134,5 +134,18 @@ struct crop : public primitive_base<crop> {
|
||||
seed = hash_combine(seed, op_mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const crop>(rhs);
|
||||
|
||||
return reference_input == rhs_casted.reference_input &&
|
||||
offsets == rhs_casted.offsets &&
|
||||
output_idx == rhs_casted.output_idx &&
|
||||
num_splits == rhs_casted.num_splits &&
|
||||
op_mode == rhs_casted.op_mode;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -39,5 +39,16 @@ struct ctc_greedy_decoder : public primitive_base<ctc_greedy_decoder> {
|
||||
seed = hash_combine(seed, second_output.empty());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const ctc_greedy_decoder>(rhs);
|
||||
|
||||
return blank_index == rhs_casted.blank_index &&
|
||||
ctc_merge_repeated == rhs_casted.ctc_merge_repeated &&
|
||||
second_output.empty() == rhs_casted.second_output.empty();
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -41,6 +41,17 @@ struct ctc_loss : primitive_base<ctc_loss> {
|
||||
seed = hash_combine(seed, unique);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const ctc_loss>(rhs);
|
||||
|
||||
return preprocess_collapse_repeated == rhs_casted.preprocess_collapse_repeated &&
|
||||
ctc_merge_repeated == rhs_casted.ctc_merge_repeated &&
|
||||
unique == rhs_casted.unique;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -40,5 +40,16 @@ struct cum_sum : public primitive_base<cum_sum> {
|
||||
seed = hash_combine(seed, reverse);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const cum_sum>(rhs);
|
||||
|
||||
return axis == rhs_casted.axis &&
|
||||
exclusive == rhs_casted.exclusive &&
|
||||
reverse == rhs_casted.reverse;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -80,5 +80,16 @@ struct custom_gpu_primitive : public primitive_base<custom_gpu_primitive> {
|
||||
seed = hash_combine(seed, kernels_code.size());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const custom_gpu_primitive>(rhs);
|
||||
|
||||
return kernel_entry_point == rhs_casted.kernel_entry_point &&
|
||||
build_options == rhs_casted.build_options &&
|
||||
kernels_code.size() == rhs_casted.kernels_code.size();
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -387,6 +387,27 @@ struct deconvolution : public primitive_base<deconvolution> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const deconvolution>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(pad) &&
|
||||
cmp_fields(stride) &&
|
||||
cmp_fields(dilations) &&
|
||||
cmp_fields(groups) &&
|
||||
cmp_fields(pads_begin) &&
|
||||
cmp_fields(pads_end) &&
|
||||
cmp_fields(out_padding) &&
|
||||
cmp_fields(grouped_weights_shape) &&
|
||||
cmp_fields(weights.size()) &&
|
||||
cmp_fields(bias.size()) &&
|
||||
cmp_fields(output_shape_id.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -45,5 +45,15 @@ struct depth_to_space : public primitive_base<depth_to_space> {
|
||||
seed = hash_combine(seed, mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const depth_to_space>(rhs);
|
||||
|
||||
return block_size == rhs_casted.block_size &&
|
||||
mode == rhs_casted.mode;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -143,6 +143,34 @@ struct detection_output : public primitive_base<detection_output> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const detection_output>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(num_classes) &&
|
||||
cmp_fields(keep_top_k) &&
|
||||
cmp_fields(share_location) &&
|
||||
cmp_fields(background_label_id) &&
|
||||
cmp_fields(nms_threshold) &&
|
||||
cmp_fields(top_k) &&
|
||||
cmp_fields(eta) &&
|
||||
cmp_fields(code_type) &&
|
||||
cmp_fields(variance_encoded_in_target) &&
|
||||
cmp_fields(confidence_threshold) &&
|
||||
cmp_fields(prior_info_size) &&
|
||||
cmp_fields(prior_coordinates_offset) &&
|
||||
cmp_fields(prior_is_normalized) &&
|
||||
cmp_fields(input_width) &&
|
||||
cmp_fields(input_height) &&
|
||||
cmp_fields(decrease_label_id) &&
|
||||
cmp_fields(clip_before_nms) &&
|
||||
cmp_fields(clip_after_nms);
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
};
|
||||
|
||||
|
@ -64,6 +64,18 @@ struct dft : public primitive_base<dft> {
|
||||
seed = hash_combine(seed, mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const dft>(rhs);
|
||||
|
||||
return axes == rhs_casted.axes &&
|
||||
signal_size == rhs_casted.signal_size &&
|
||||
direction == rhs_casted.direction &&
|
||||
mode == rhs_casted.mode;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -177,5 +177,17 @@ struct eltwise : public primitive_base<eltwise> {
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const eltwise>(rhs);
|
||||
|
||||
return mode == rhs_casted.mode &&
|
||||
coefficients == rhs_casted.coefficients &&
|
||||
broadcast_spec == rhs_casted.broadcast_spec &&
|
||||
stride == rhs_casted.stride;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -45,5 +45,15 @@ struct embedding_bag : public primitive_base<embedding_bag> {
|
||||
seed = hash_combine(seed, default_index);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const embedding_bag>(rhs);
|
||||
|
||||
return type == rhs_casted.type &&
|
||||
default_index == rhs_casted.default_index;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -86,6 +86,26 @@ struct experimental_detectron_detection_output : public primitive_base<experimen
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const experimental_detectron_detection_output>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(score_threshold) &&
|
||||
cmp_fields(nms_threshold) &&
|
||||
cmp_fields(num_classes) &&
|
||||
cmp_fields(post_nms_count) &&
|
||||
cmp_fields(max_detections_per_image) &&
|
||||
cmp_fields(class_agnostic_box_regression) &&
|
||||
cmp_fields(max_delta_log_wh) &&
|
||||
cmp_fields(deltas_weights) &&
|
||||
cmp_fields(output_classes.empty()) &&
|
||||
cmp_fields(output_scores.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -58,6 +58,19 @@ struct experimental_detectron_generate_proposals_single_image
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const experimental_detectron_generate_proposals_single_image>(rhs);
|
||||
|
||||
return min_size == rhs_casted.min_size &&
|
||||
nms_threshold == rhs_casted.nms_threshold &&
|
||||
pre_nms_count == rhs_casted.pre_nms_count &&
|
||||
post_nms_count == rhs_casted.post_nms_count &&
|
||||
output_roi_scores.empty() == rhs_casted.output_roi_scores.empty();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -60,6 +60,23 @@ struct experimental_detectron_prior_grid_generator
|
||||
seed = hash_combine(seed, image_width);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const experimental_detectron_prior_grid_generator>(rhs);
|
||||
|
||||
return flatten == rhs_casted.flatten &&
|
||||
h == rhs_casted.h &&
|
||||
w == rhs_casted.w &&
|
||||
stride_x == rhs_casted.stride_x &&
|
||||
stride_y == rhs_casted.stride_y &&
|
||||
featmap_height == rhs_casted.featmap_height &&
|
||||
featmap_width == rhs_casted.featmap_width &&
|
||||
image_height == rhs_casted.image_height &&
|
||||
image_width == rhs_casted.image_width;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -51,6 +51,20 @@ struct experimental_detectron_roi_feature_extractor : public primitive_base<expe
|
||||
seed = hash_combine(seed, aligned);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const experimental_detectron_roi_feature_extractor>(rhs);
|
||||
|
||||
return output_dim == rhs_casted.output_dim &&
|
||||
pooled_height == rhs_casted.pooled_height &&
|
||||
pooled_width == rhs_casted.pooled_width &&
|
||||
pyramid_scales == rhs_casted.pyramid_scales &&
|
||||
sampling_ratio == rhs_casted.sampling_ratio &&
|
||||
aligned == rhs_casted.aligned;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -37,6 +37,15 @@ struct experimental_detectron_topk_rois : public primitive_base<experimental_det
|
||||
seed = hash_combine(seed, max_rois);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const experimental_detectron_topk_rois>(rhs);
|
||||
|
||||
return max_rois == rhs_casted.max_rois;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -63,5 +63,17 @@ struct extract_image_patches : public primitive_base<extract_image_patches> {
|
||||
seed = hash_combine(seed, auto_pad);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const extract_image_patches>(rhs);
|
||||
|
||||
return sizes == rhs_casted.sizes &&
|
||||
strides == rhs_casted.strides &&
|
||||
rates == rhs_casted.rates &&
|
||||
auto_pad == rhs_casted.auto_pad;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -36,5 +36,14 @@ struct eye : public primitive_base<eye> {
|
||||
seed = hash_combine(seed, shift);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const eye>(rhs);
|
||||
|
||||
return shift == rhs_casted.shift;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -82,6 +82,16 @@ struct fully_connected : public primitive_base<fully_connected> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const fully_connected>(rhs);
|
||||
|
||||
return input_size == rhs_casted.input_size &&
|
||||
bias.empty() == rhs_casted.bias.empty();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -52,5 +52,16 @@ struct gather : public primitive_base<gather> {
|
||||
seed = hash_combine(seed, support_neg_ind);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const gather>(rhs);
|
||||
|
||||
return axis == rhs_casted.axis &&
|
||||
batch_dim == rhs_casted.batch_dim &&
|
||||
support_neg_ind == rhs_casted.support_neg_ind;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -49,5 +49,15 @@ struct gather_elements : public primitive_base<gather_elements> {
|
||||
seed = hash_combine(seed, axis);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const gather_elements>(rhs);
|
||||
|
||||
return output_format == rhs_casted.output_format &&
|
||||
axis == rhs_casted.axis;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -57,5 +57,17 @@ struct gather_nd : public primitive_base<gather_nd> {
|
||||
seed = hash_combine(seed, batch_merged_output);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const gather_nd>(rhs);
|
||||
|
||||
return input_rank == rhs_casted.input_rank &&
|
||||
indices_rank == rhs_casted.indices_rank &&
|
||||
batch_dims == rhs_casted.batch_dims &&
|
||||
batch_merged_output == rhs_casted.batch_merged_output;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -13,22 +13,26 @@ namespace cldnn {
|
||||
struct gather_tree : public primitive_base<gather_tree> {
|
||||
CLDNN_DECLARE_PRIMITIVE(gather_tree)
|
||||
|
||||
/// @brief Constructs gather tree primitive / layer.
|
||||
///
|
||||
/// @param id An identifier of new primitive.
|
||||
/// @param step_input An identifier of primitive which is an step input
|
||||
/// @param parent_input An identifier of primitive which is an parent input
|
||||
/// @param step_seq_len_input An identifier of primitive which is an input that contains
|
||||
/// lengths of step sequence (per batch) to perform
|
||||
/// @param end_token An identifier of primitive which is an input that contains
|
||||
/// a value of the end_token
|
||||
/// @param output_padding Optional padding for output from primitive
|
||||
gather_tree(const primitive_id& id,
|
||||
const input_info& step_input,
|
||||
const input_info& parent_input,
|
||||
const input_info& max_seq_len_input,
|
||||
const input_info& end_token,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, { step_input, parent_input, max_seq_len_input, end_token }, {output_padding}) {}
|
||||
/// @brief Constructs gather tree primitive / layer.
|
||||
///
|
||||
/// @param id An identifier of new primitive.
|
||||
/// @param step_input An identifier of primitive which is an step input
|
||||
/// @param parent_input An identifier of primitive which is an parent input
|
||||
/// @param step_seq_len_input An identifier of primitive which is an input that contains
|
||||
/// lengths of step sequence (per batch) to perform
|
||||
/// @param end_token An identifier of primitive which is an input that contains
|
||||
/// a value of the end_token
|
||||
/// @param output_padding Optional padding for output from primitive
|
||||
gather_tree(const primitive_id& id,
|
||||
const input_info& step_input,
|
||||
const input_info& parent_input,
|
||||
const input_info& max_seq_len_input,
|
||||
const input_info& end_token,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, { step_input, parent_input, max_seq_len_input, end_token }, {output_padding}) {}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
return compare_common_params(rhs);
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -76,6 +76,20 @@ struct gemm : public primitive_base<gemm> {
|
||||
seed = hash_combine(seed, beta);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const gemm>(rhs);
|
||||
|
||||
return transpose_input0 == rhs_casted.transpose_input0 &&
|
||||
transpose_input1 == rhs_casted.transpose_input1 &&
|
||||
alpha == rhs_casted.alpha &&
|
||||
beta == rhs_casted.beta &&
|
||||
input_rank == rhs_casted.input_rank &&
|
||||
weight_rank == rhs_casted.weight_rank;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -73,6 +73,25 @@ struct generate_proposals
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const generate_proposals>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(min_size) &&
|
||||
cmp_fields(nms_threshold) &&
|
||||
cmp_fields(pre_nms_count) &&
|
||||
cmp_fields(post_nms_count) &&
|
||||
cmp_fields(normalized) &&
|
||||
cmp_fields(nms_eta) &&
|
||||
cmp_fields(roi_num_type) &&
|
||||
cmp_fields(output_rois_scores.empty()) &&
|
||||
cmp_fields(output_rois_num.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -38,6 +38,17 @@ struct grid_sample : primitive_base<grid_sample> {
|
||||
seed = hash_combine(seed, attributes.padding_mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const grid_sample>(rhs);
|
||||
|
||||
return attributes.align_corners == rhs_casted.attributes.align_corners &&
|
||||
attributes.mode == rhs_casted.attributes.mode &&
|
||||
attributes.padding_mode == rhs_casted.attributes.padding_mode;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -32,5 +32,14 @@ struct grn : public primitive_base<grn> {
|
||||
seed = hash_combine(seed, bias);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const grn>(rhs);
|
||||
|
||||
return bias == rhs_casted.bias;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -70,5 +70,18 @@ struct lrn : public primitive_base<lrn> {
|
||||
seed = hash_combine(seed, norm_region);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lrn>(rhs);
|
||||
|
||||
return size == rhs_casted.size &&
|
||||
k == rhs_casted.k &&
|
||||
alpha == rhs_casted.alpha &&
|
||||
beta == rhs_casted.beta &&
|
||||
norm_region == rhs_casted.norm_region;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -143,6 +143,32 @@ struct lstm : public primitive_base<lstm> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lstm>(rhs);
|
||||
|
||||
bool act_params_eq = activation_params.size() == rhs_casted.activation_params.size();
|
||||
for (size_t i = 0; i < activation_params.size(); ++i) {
|
||||
act_params_eq &= activation_params[i].a == rhs_casted.activation_params[i].a &&
|
||||
activation_params[i].b == rhs_casted.activation_params[i].b;
|
||||
}
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return act_params_eq &&
|
||||
cmp_fields(clip) &&
|
||||
cmp_fields(input_forget) &&
|
||||
cmp_fields(activations) &&
|
||||
cmp_fields(output_selection) &&
|
||||
cmp_fields(offset_order) &&
|
||||
cmp_fields(initial_hidden.empty()) &&
|
||||
cmp_fields(initial_cell.empty()) &&
|
||||
cmp_fields(peepholes.empty()) &&
|
||||
cmp_fields(bias.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
@ -205,6 +231,17 @@ struct lstm_gemm : public primitive_base<lstm_gemm> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lstm_gemm>(rhs);
|
||||
|
||||
return direction == rhs_casted.direction &&
|
||||
bias.empty() == rhs_casted.bias.empty() &&
|
||||
hidden.empty() == rhs_casted.hidden.empty();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
@ -282,6 +319,29 @@ struct lstm_elt : public primitive_base<lstm_elt> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lstm_elt>(rhs);
|
||||
|
||||
bool act_params_eq = activation_params.size() == rhs_casted.activation_params.size();
|
||||
for (size_t i = 0; i < activation_params.size(); ++i) {
|
||||
act_params_eq &= activation_params[i].a == rhs_casted.activation_params[i].a &&
|
||||
activation_params[i].b == rhs_casted.activation_params[i].b;
|
||||
}
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return act_params_eq &&
|
||||
cmp_fields(clip) &&
|
||||
cmp_fields(input_forget) &&
|
||||
cmp_fields(activations) &&
|
||||
cmp_fields(offset_order) &&
|
||||
cmp_fields(direction) &&
|
||||
cmp_fields(cell.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -91,6 +91,23 @@ struct lstm_dynamic : public primitive_base<lstm_dynamic> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lstm_dynamic>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(clip) &&
|
||||
cmp_fields(input_forget) &&
|
||||
cmp_fields(last_hidden_state.empty()) &&
|
||||
cmp_fields(last_cell_state.empty()) &&
|
||||
cmp_fields(initial_hidden.empty()) &&
|
||||
cmp_fields(initial_cell.empty()) &&
|
||||
cmp_fields(bias.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -48,6 +48,15 @@ struct lstm_dynamic_input : public primitive_base<lstm_dynamic_input> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lstm_dynamic_input>(rhs);
|
||||
|
||||
return bias.empty() == rhs_casted.bias.empty();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -79,6 +79,22 @@ struct lstm_dynamic_timeloop
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const lstm_dynamic_timeloop>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(clip) &&
|
||||
cmp_fields(input_forget) &&
|
||||
cmp_fields(last_hidden_state.empty()) &&
|
||||
cmp_fields(last_cell_state.empty()) &&
|
||||
cmp_fields(initial_hidden.empty()) &&
|
||||
cmp_fields(initial_cell.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -133,6 +133,26 @@ struct matrix_nms : public primitive_base<matrix_nms> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const matrix_nms>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(attribs.sort_type) &&
|
||||
cmp_fields(attribs.sort_result_across_batch) &&
|
||||
cmp_fields(attribs.score_threshold) &&
|
||||
cmp_fields(attribs.nms_top_k) &&
|
||||
cmp_fields(attribs.keep_top_k) &&
|
||||
cmp_fields(attribs.background_class) &&
|
||||
cmp_fields(attribs.decay) &&
|
||||
cmp_fields(attribs.gaussian_sigma) &&
|
||||
cmp_fields(attribs.post_threshold) &&
|
||||
cmp_fields(attribs.normalized);
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
private:
|
||||
static cldnn::matrix_nms::decay_function from(ngraph::op::v8::MatrixNms::DecayFunction decay) {
|
||||
switch (decay) {
|
||||
|
@ -141,6 +141,27 @@ struct multiclass_nms : public primitive_base<multiclass_nms> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const multiclass_nms>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(has_roisnum) &&
|
||||
cmp_fields(attrs.background_class) &&
|
||||
cmp_fields(attrs.indices_output_type) &&
|
||||
cmp_fields(attrs.iou_threshold) &&
|
||||
cmp_fields(attrs.keep_top_k) &&
|
||||
cmp_fields(attrs.nms_eta) &&
|
||||
cmp_fields(attrs.nms_top_k) &&
|
||||
cmp_fields(attrs.normalized) &&
|
||||
cmp_fields(attrs.score_threshold) &&
|
||||
cmp_fields(attrs.sort_result) &&
|
||||
cmp_fields(attrs.sort_result_across_batch);
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -49,5 +49,17 @@ struct mvn : public primitive_base<mvn> {
|
||||
seed = hash_combine(seed, across_channels);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const mvn>(rhs);
|
||||
|
||||
return normalize_variance == rhs_casted.normalize_variance &&
|
||||
epsilon == rhs_casted.epsilon &&
|
||||
eps_inside_sqrt == rhs_casted.eps_inside_sqrt &&
|
||||
across_channels == rhs_casted.across_channels;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -81,6 +81,25 @@ struct non_max_suppression : public primitive_base<non_max_suppression> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const non_max_suppression>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(selected_indices_num) &&
|
||||
cmp_fields(center_point_box) &&
|
||||
cmp_fields(sort_result_descending) &&
|
||||
cmp_fields(num_select_per_class.empty()) &&
|
||||
cmp_fields(iou_threshold.empty()) &&
|
||||
cmp_fields(score_threshold.empty()) &&
|
||||
cmp_fields(soft_nms_sigma.empty()) &&
|
||||
cmp_fields(second_output.empty()) &&
|
||||
cmp_fields(third_output.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
if (!num_select_per_class.empty())
|
||||
|
@ -18,6 +18,10 @@ struct count_nonzero : public primitive_base<count_nonzero> {
|
||||
const input_info& data,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data}, {output_padding}) {}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
return compare_common_params(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
struct gather_nonzero : public primitive_base<gather_nonzero> {
|
||||
@ -32,6 +36,10 @@ struct gather_nonzero : public primitive_base<gather_nonzero> {
|
||||
const input_info& output_shape,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base(id, {data, output_shape}, {output_padding}) {}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
return compare_common_params(rhs);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -62,6 +62,16 @@ struct normalize : public primitive_base<normalize> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const normalize>(rhs);
|
||||
|
||||
return across_spatial == rhs_casted.across_spatial &&
|
||||
epsilon == rhs_casted.epsilon;
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override { return {scale_input}; }
|
||||
};
|
||||
|
@ -95,5 +95,17 @@ struct one_hot : public primitive_base<one_hot> {
|
||||
seed = hash_combine(seed, off_value);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const one_hot>(rhs);
|
||||
|
||||
return one_hot_axis == rhs_casted.one_hot_axis &&
|
||||
depth == rhs_casted.depth &&
|
||||
on_value == rhs_casted.on_value &&
|
||||
off_value == rhs_casted.off_value;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -37,5 +37,14 @@ struct permute : public primitive_base<permute> {
|
||||
seed = hash_range(seed, permute_order.begin(), permute_order.end());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const permute>(rhs);
|
||||
|
||||
return permute_order == rhs_casted.permute_order;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -176,6 +176,28 @@ struct pooling : public primitive_base<pooling> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const pooling>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(mode) &&
|
||||
cmp_fields(size) &&
|
||||
cmp_fields(stride) &&
|
||||
cmp_fields(dilation) &&
|
||||
cmp_fields(pads_begin) &&
|
||||
cmp_fields(pads_end) &&
|
||||
cmp_fields(auto_pad) &&
|
||||
cmp_fields(rounding_type) &&
|
||||
cmp_fields(axis) &&
|
||||
cmp_fields(index_element_type) &&
|
||||
cmp_fields(maxPoolOpset8Features) &&
|
||||
cmp_fields(indices_output.empty());
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
std::vector<std::reference_wrapper<const primitive_id>> ret;
|
||||
|
@ -103,7 +103,40 @@ public:
|
||||
return seed;
|
||||
}
|
||||
|
||||
/// @brief Implicit conversion to primiitive id.
|
||||
bool compare_common_params(const primitive& rhs) const {
|
||||
if (type != rhs.type)
|
||||
return false;
|
||||
|
||||
if (num_outputs != rhs.num_outputs)
|
||||
return false;
|
||||
|
||||
if (dependencies().size() != rhs.dependencies().size())
|
||||
return false;
|
||||
|
||||
if (output_data_types.size() != rhs.output_data_types.size())
|
||||
return false;
|
||||
|
||||
for (size_t i = 0; i < output_data_types.size(); ++i) {
|
||||
if (output_data_types[i].value_or(data_types::bin) != rhs.output_data_types[i].value_or(data_types::bin))
|
||||
return false;
|
||||
}
|
||||
|
||||
if (output_paddings.size() != rhs.output_paddings.size())
|
||||
return false;
|
||||
|
||||
for (size_t i = 0; i < output_paddings.size(); ++i) {
|
||||
if (output_paddings[i] != rhs.output_paddings[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
virtual bool operator==(const primitive& rhs) const { return false; }
|
||||
|
||||
bool operator!=(const primitive& rhs) const { return !(*this == rhs); }
|
||||
|
||||
/// @brief Implicit conversion to primitive id.
|
||||
operator primitive_id() const { return id; }
|
||||
|
||||
/// @brief Primitive's type id.
|
||||
|
@ -205,6 +205,36 @@ struct prior_box : public primitive_base<prior_box> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const prior_box>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(img_size) &&
|
||||
cmp_fields(min_sizes) &&
|
||||
cmp_fields(max_sizes) &&
|
||||
cmp_fields(aspect_ratios) &&
|
||||
cmp_fields(flip) &&
|
||||
cmp_fields(clip) &&
|
||||
cmp_fields(variance) &&
|
||||
cmp_fields(step_width) &&
|
||||
cmp_fields(step_height) &&
|
||||
cmp_fields(offset) &&
|
||||
cmp_fields(scale_all_sizes) &&
|
||||
cmp_fields(fixed_ratio) &&
|
||||
cmp_fields(fixed_size) &&
|
||||
cmp_fields(density) &&
|
||||
cmp_fields(support_opset8) &&
|
||||
cmp_fields(step) &&
|
||||
cmp_fields(min_max_aspect_ratios_order) &&
|
||||
cmp_fields(widths) &&
|
||||
cmp_fields(heights) &&
|
||||
cmp_fields(clustered);
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
private:
|
||||
bool clustered;
|
||||
|
||||
|
@ -186,6 +186,36 @@ struct proposal : public primitive_base<proposal> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const proposal>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(max_proposals) &&
|
||||
cmp_fields(iou_threshold) &&
|
||||
cmp_fields(base_bbox_size) &&
|
||||
cmp_fields(min_bbox_size) &&
|
||||
cmp_fields(feature_stride) &&
|
||||
cmp_fields(pre_nms_topn) &&
|
||||
cmp_fields(post_nms_topn) &&
|
||||
cmp_fields(ratios) &&
|
||||
cmp_fields(scales) &&
|
||||
cmp_fields(coordinates_offset) &&
|
||||
cmp_fields(box_coordinate_scale) &&
|
||||
cmp_fields(box_size_scale) &&
|
||||
cmp_fields(for_deformable) &&
|
||||
cmp_fields(swap_xy) &&
|
||||
cmp_fields(initial_clip) &&
|
||||
cmp_fields(clip_before_nms) &&
|
||||
cmp_fields(clip_after_nms) &&
|
||||
cmp_fields(round_ratios) &&
|
||||
cmp_fields(shift_anchors) &&
|
||||
cmp_fields(normalize);
|
||||
#undef cmp_fields
|
||||
}
|
||||
|
||||
void save(BinaryOutputBuffer& ob) const override {
|
||||
ob << max_proposals;
|
||||
ob << iou_threshold;
|
||||
|
@ -66,5 +66,17 @@ struct pyramid_roi_align : public primitive_base<pyramid_roi_align> {
|
||||
seed = hash_combine(seed, pyramid_starting_level);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const pyramid_roi_align>(rhs);
|
||||
|
||||
return output_size == rhs_casted.output_size &&
|
||||
sampling_ratio == rhs_casted.sampling_ratio &&
|
||||
pyramid_scales == rhs_casted.pyramid_scales &&
|
||||
pyramid_starting_level == rhs_casted.pyramid_starting_level;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -36,5 +36,14 @@ struct quantize : public primitive_base<quantize> {
|
||||
seed = cldnn::hash_combine(seed, levels);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const quantize>(rhs);
|
||||
|
||||
return levels == rhs_casted.levels;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -49,6 +49,16 @@ struct random_uniform : public primitive_base<random_uniform> {
|
||||
seed = hash_combine(seed, op_seed);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const random_uniform>(rhs);
|
||||
|
||||
return global_seed == rhs_casted.global_seed &&
|
||||
op_seed == rhs_casted.op_seed;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -29,5 +29,9 @@ struct range: public primitive_base<range> {
|
||||
|
||||
/// @brief requested range output layout
|
||||
layout output_layout;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
return compare_common_params(rhs);
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -30,5 +30,14 @@ struct read_value : public primitive_base<read_value> {
|
||||
|
||||
std::string variable_id;
|
||||
layout output_layout;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const read_value>(rhs);
|
||||
|
||||
return variable_id == rhs_casted.variable_id;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -68,5 +68,16 @@ struct reduce : public primitive_base<reduce> {
|
||||
seed = hash_combine(seed, keep_dims);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const reduce>(rhs);
|
||||
|
||||
return mode == rhs_casted.mode &&
|
||||
axes == rhs_casted.axes &&
|
||||
keep_dims == rhs_casted.keep_dims;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -51,6 +51,19 @@ struct region_yolo : public primitive_base<region_yolo> {
|
||||
seed = hash_combine(seed, do_softmax);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const region_yolo>(rhs);
|
||||
|
||||
return coords == rhs_casted.coords &&
|
||||
classes == rhs_casted.classes &&
|
||||
num == rhs_casted.num &&
|
||||
mask_size == rhs_casted.mask_size &&
|
||||
do_softmax == rhs_casted.do_softmax;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
#pragma once
|
||||
|
@ -165,6 +165,19 @@ struct reorder : public primitive_base<reorder> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const reorder>(rhs);
|
||||
|
||||
return subtract_per_feature == rhs_casted.subtract_per_feature &&
|
||||
mean_mode == rhs_casted.mean_mode &&
|
||||
input_mem_type == rhs_casted.input_mem_type &&
|
||||
truncate == rhs_casted.truncate &&
|
||||
mean.empty() == rhs_casted.mean.empty();
|
||||
}
|
||||
|
||||
protected:
|
||||
std::vector<std::reference_wrapper<const primitive_id>> get_dependencies() const override {
|
||||
if (mean.empty())
|
||||
|
@ -34,6 +34,15 @@ struct reorg_yolo : public primitive_base<reorg_yolo> {
|
||||
seed = hash_combine(seed, stride);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const reorg_yolo>(rhs);
|
||||
|
||||
return stride == rhs_casted.stride;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
#pragma once
|
||||
|
@ -165,5 +165,27 @@ struct resample : public primitive_base<resample> {
|
||||
seed = hash_combine(seed, round_mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const resample>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(num_filter) &&
|
||||
cmp_fields(sizes) &&
|
||||
cmp_fields(scales) &&
|
||||
cmp_fields(axes) &&
|
||||
cmp_fields(pads_begin) &&
|
||||
cmp_fields(pads_end) &&
|
||||
cmp_fields(operation_type) &&
|
||||
cmp_fields(shape_calc_mode) &&
|
||||
cmp_fields(antialias) &&
|
||||
cmp_fields(cube_coeff) &&
|
||||
cmp_fields(coord_trans_mode) &&
|
||||
cmp_fields(round_mode);
|
||||
#undef cmp_fields
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -79,6 +79,16 @@ struct reshape : public primitive_base<reshape> {
|
||||
ov::PartialShape output_partial_shape;
|
||||
|
||||
reshape_mode mode;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const reshape>(rhs);
|
||||
|
||||
return special_zero == rhs_casted.special_zero &&
|
||||
mode == rhs_casted.mode;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -33,5 +33,14 @@ struct reverse : public primitive_base<reverse> {
|
||||
seed = hash_combine(seed, mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const reverse>(rhs);
|
||||
|
||||
return mode == rhs_casted.mode;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -58,5 +58,15 @@ struct reverse_sequence : public primitive_base<reverse_sequence> {
|
||||
seed = hash_combine(seed, batch_axis);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const reverse_sequence>(rhs);
|
||||
|
||||
return seq_axis == rhs_casted.seq_axis &&
|
||||
batch_axis == rhs_casted.batch_axis;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -68,5 +68,19 @@ struct roi_align : public primitive_base<roi_align> {
|
||||
seed = hash_combine(seed, aligned_mode);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const roi_align>(rhs);
|
||||
|
||||
return pooled_h == rhs_casted.pooled_h &&
|
||||
pooled_w == rhs_casted.pooled_w &&
|
||||
sampling_ratio == rhs_casted.sampling_ratio &&
|
||||
spatial_scale == rhs_casted.spatial_scale &&
|
||||
pooling_mode == rhs_casted.pooling_mode &&
|
||||
aligned_mode == rhs_casted.aligned_mode;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -96,6 +96,28 @@ struct roi_pooling : public primitive_base<roi_pooling> {
|
||||
seed = hash_combine(seed, spatial_bins_y);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const roi_pooling>(rhs);
|
||||
|
||||
#define cmp_fields(name) name == rhs_casted.name
|
||||
return cmp_fields(mode) &&
|
||||
cmp_fields(position_sensitive) &&
|
||||
cmp_fields(pooled_width) &&
|
||||
cmp_fields(pooled_height) &&
|
||||
cmp_fields(spatial_scale) &&
|
||||
cmp_fields(trans_std) &&
|
||||
cmp_fields(no_trans) &&
|
||||
cmp_fields(output_dim) &&
|
||||
cmp_fields(part_size) &&
|
||||
cmp_fields(group_size) &&
|
||||
cmp_fields(spatial_bins_x) &&
|
||||
cmp_fields(spatial_bins_y);
|
||||
#undef cmp_fields
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -32,6 +32,15 @@ struct roll : primitive_base<roll> {
|
||||
seed = hash_combine(seed, shift.hash());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const roll>(rhs);
|
||||
|
||||
return shift == rhs_casted.shift;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cldnn
|
||||
|
@ -34,5 +34,14 @@ struct scatter_elements_update : public primitive_base<scatter_elements_update>
|
||||
seed = hash_combine(seed, axis);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const scatter_elements_update>(rhs);
|
||||
|
||||
return axis == rhs_casted.axis;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -34,5 +34,14 @@ struct scatter_nd_update : public primitive_base<scatter_nd_update> {
|
||||
seed = hash_combine(seed, indices_rank);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const scatter_nd_update>(rhs);
|
||||
|
||||
return indices_rank == rhs_casted.indices_rank;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -43,5 +43,14 @@ struct scatter_update : public primitive_base<scatter_update> {
|
||||
seed = hash_combine(seed, axis);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const scatter_update>(rhs);
|
||||
|
||||
return axis == rhs_casted.axis;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -40,5 +40,14 @@ struct select : public primitive_base<select> {
|
||||
|
||||
/// @brief Define auto broadcast rule specification.
|
||||
ov::op::AutoBroadcastSpec broadcast_spec;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const select>(rhs);
|
||||
|
||||
return broadcast_spec == rhs_casted.broadcast_spec;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -36,5 +36,14 @@ struct shape_of : public primitive_base<shape_of> {
|
||||
, output_rank(0) {}
|
||||
|
||||
size_t output_rank;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const shape_of>(rhs);
|
||||
|
||||
return output_rank == rhs_casted.output_rank;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -36,5 +36,15 @@ struct shuffle_channels : public primitive_base<shuffle_channels> {
|
||||
seed = hash_combine(seed, axis);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const shuffle_channels>(rhs);
|
||||
|
||||
return group == rhs_casted.group &&
|
||||
axis == rhs_casted.axis;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -16,13 +16,17 @@ struct slice : public primitive_base<slice> {
|
||||
/// @param id This primitive id.
|
||||
/// @param inputs List of primitive ids.
|
||||
slice(const primitive_id& id,
|
||||
const std::vector<input_info>& inputs,
|
||||
const tensor output_shape,
|
||||
const padding& output_padding = padding())
|
||||
const std::vector<input_info>& inputs,
|
||||
const tensor output_shape,
|
||||
const padding& output_padding = padding())
|
||||
: primitive_base{id, inputs, {output_padding}},
|
||||
output_shape {output_shape}
|
||||
{}
|
||||
|
||||
tensor output_shape;
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
return compare_common_params(rhs);
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -44,5 +44,14 @@ struct softmax : public primitive_base<softmax> {
|
||||
seed = hash_combine(seed, dimension);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const softmax>(rhs);
|
||||
|
||||
return dimension == rhs_casted.dimension;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -68,5 +68,16 @@ struct space_to_batch : public primitive_base<space_to_batch> {
|
||||
seed = hash_combine(seed, pads_end.hash());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const space_to_batch>(rhs);
|
||||
|
||||
return block_shape == rhs_casted.block_shape &&
|
||||
pads_begin == rhs_casted.pads_begin &&
|
||||
pads_end == rhs_casted.pads_end;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -71,5 +71,15 @@ struct space_to_depth : public primitive_base<space_to_depth> {
|
||||
seed = hash_combine(seed, block_size);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const space_to_depth>(rhs);
|
||||
|
||||
return mode == rhs_casted.mode &&
|
||||
block_size == rhs_casted.block_size;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -57,6 +57,15 @@ struct split : public primitive_base<split> {
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const split>(rhs);
|
||||
|
||||
return output_offsets == rhs_casted.output_offsets;
|
||||
}
|
||||
|
||||
protected:
|
||||
static std::vector<primitive_id> extract_primitive_vector(
|
||||
const std::vector<std::pair<primitive_id, tensor> >& stor) {
|
||||
|
@ -111,5 +111,21 @@ struct strided_slice : public primitive_base<strided_slice> {
|
||||
seed = hash_range(seed, shrink_axis_mask.begin(), shrink_axis_mask.end());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const strided_slice>(rhs);
|
||||
|
||||
return begin == rhs_casted.begin &&
|
||||
end == rhs_casted.end &&
|
||||
strides == rhs_casted.strides &&
|
||||
begin_mask == rhs_casted.begin_mask &&
|
||||
end_mask == rhs_casted.end_mask &&
|
||||
new_axis_mask == rhs_casted.new_axis_mask &&
|
||||
shrink_axis_mask == rhs_casted.shrink_axis_mask &&
|
||||
ellipsis_mask == rhs_casted.ellipsis_mask;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -38,5 +38,14 @@ struct tile : public primitive_base<tile> {
|
||||
seed = hash_range(seed, repeats.begin(), repeats.end());
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool operator==(const primitive& rhs) const override {
|
||||
if (!compare_common_params(rhs))
|
||||
return false;
|
||||
|
||||
auto rhs_casted = downcast<const tile>(rhs);
|
||||
|
||||
return repeats == rhs_casted.repeats;
|
||||
}
|
||||
};
|
||||
} // namespace cldnn
|
||||
|
@ -140,7 +140,7 @@ inline derived_type& downcast(base_type& base) {
|
||||
} catch (std::bad_cast& /* ex */) {
|
||||
throw std::runtime_error("Unable to cast reference from base to derived type");
|
||||
}
|
||||
throw std::runtime_error("downcast failed with unhadnled exception");
|
||||
throw std::runtime_error("downcast failed with unhandled exception");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -0,0 +1,111 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils.h"
|
||||
|
||||
#include <intel_gpu/primitives/range.hpp>
|
||||
#include <intel_gpu/primitives/convolution.hpp>
|
||||
#include <intel_gpu/primitives/gemm.hpp>
|
||||
#include <intel_gpu/primitives/fully_connected.hpp>
|
||||
#include <intel_gpu/primitives/gather.hpp>
|
||||
#include <intel_gpu/primitives/permute.hpp>
|
||||
|
||||
using namespace cldnn;
|
||||
using namespace ::tests;
|
||||
|
||||
TEST(primitive_comparison, common_params) {
|
||||
auto def_inputs = {input_info("input0"), input_info("input1"), input_info("input2")};
|
||||
auto def_shape = ov::PartialShape{1, 2, 3, 4};
|
||||
auto def_data_type = data_types::f32;
|
||||
auto def_format = format::bfyx;
|
||||
auto def_padding = padding({1, 1, 1, 1});
|
||||
|
||||
auto fc_prim = fully_connected("fc", input_info("input"), "weights");
|
||||
|
||||
auto range_prim = range("range", def_inputs, layout{def_shape, def_data_type, def_format, def_padding});
|
||||
|
||||
auto range_prim_inputs = range("range", {input_info("input0"), input_info("input1")}, layout{def_shape, def_data_type, def_format, def_padding});
|
||||
|
||||
auto range_prim_data_type = range("range", def_inputs, layout{def_shape, data_types::f16, def_format, def_padding});
|
||||
|
||||
auto range_prim_padding_values = range("range", def_inputs, layout{def_shape, def_data_type, def_format, padding({1, 1, 1, 2})});
|
||||
|
||||
auto range_prim_padding_fill_value = range("range", def_inputs, layout{def_shape, def_data_type, def_format, padding({1, 1, 1, 1}, 1.f)});
|
||||
|
||||
ASSERT_NE(range_prim, fc_prim);
|
||||
ASSERT_NE(range_prim, range_prim_inputs);
|
||||
ASSERT_NE(range_prim, range_prim_data_type);
|
||||
ASSERT_NE(range_prim, range_prim_padding_values);
|
||||
ASSERT_NE(range_prim, range_prim_padding_fill_value);
|
||||
}
|
||||
|
||||
TEST(primitive_comparison, convolution) {
|
||||
auto conv_prim = convolution("conv", input_info("input"), {"weights"}, {"bias"}, 1,
|
||||
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
|
||||
|
||||
auto conv_prim_eq = convolution("conv_eq", input_info("input_eq"), {"weights_eq"}, {"bias_eq"}, 1,
|
||||
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
|
||||
|
||||
auto conv_prim_stride = convolution("conv", input_info("input"), {"weights"}, {"bias"}, 1,
|
||||
{1, 1}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
|
||||
|
||||
auto conv_prim_no_bias = convolution("conv", input_info("input"), {"weights"}, {}, 1,
|
||||
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, false);
|
||||
|
||||
auto conv_prim_grouped = convolution("conv", input_info("input"), {"weights"}, {"bias"}, 2,
|
||||
{2, 2}, {0, 0}, {1, 1}, {1, 3, 224, 224}, data_types::f32, true);
|
||||
|
||||
ASSERT_EQ(conv_prim, conv_prim_eq);
|
||||
ASSERT_NE(conv_prim, conv_prim_stride);
|
||||
ASSERT_NE(conv_prim, conv_prim_no_bias);
|
||||
ASSERT_NE(conv_prim, conv_prim_grouped);
|
||||
}
|
||||
|
||||
TEST(primitive_comparison, gemm) {
|
||||
auto def_inputs = {input_info("input0"), input_info("input1")};
|
||||
|
||||
auto gemm_prim = gemm("gemm", def_inputs, data_types::f32);
|
||||
auto gemm_prim_eq = gemm("gemm_eq", {input_info("input0_eq"), input_info("input1_eq")}, data_types::f32);
|
||||
auto gemm_prim_rank = gemm("gemm", def_inputs, data_types::f32, false, false, 1.0f, 0.0f, 2, 2);
|
||||
auto gemm_prim_alpha = gemm("gemm", def_inputs, data_types::f32, false, false, 1.5f);
|
||||
auto gemm_prim_transpose = gemm("gemm", def_inputs, data_types::f32, true);
|
||||
|
||||
ASSERT_EQ(gemm_prim, gemm_prim_eq);
|
||||
ASSERT_NE(gemm_prim, gemm_prim_rank);
|
||||
ASSERT_NE(gemm_prim, gemm_prim_alpha);
|
||||
ASSERT_NE(gemm_prim, gemm_prim_transpose);
|
||||
}
|
||||
|
||||
TEST(primitive_comparison, fully_connected) {
|
||||
auto fc_prim = fully_connected("fc", input_info("input"), "weights", "bias", {}, 2);
|
||||
auto fc_prim_eq = fully_connected("fc_eq", input_info("input_eq"), "weights_eq", "bias_eq", {}, 2);
|
||||
auto fc_prim_bias = fully_connected("fc", input_info("input"), "weights", "", {}, 2);
|
||||
auto fc_prim_input_size = fully_connected("fc", input_info("input"), "weights", "bias", {}, 4);
|
||||
|
||||
ASSERT_EQ(fc_prim, fc_prim_eq);
|
||||
ASSERT_NE(fc_prim, fc_prim_bias);
|
||||
ASSERT_NE(fc_prim, fc_prim_input_size);
|
||||
}
|
||||
|
||||
TEST(primitive_comparison, gather) {
|
||||
auto gather_prim = gather("gather", input_info("input0"), input_info("input1"), 2, {1, 3, 224, 224}, 1, true);
|
||||
auto gather_prim_eq = gather("gather_eq", input_info("input0_eq"), input_info("input1_eq"), 2, {1, 3, 224, 224}, 1, true);
|
||||
auto gather_prim_axis = gather("gather", input_info("input0"), input_info("input1"), 3, {1, 3, 224, 224}, 1, true);
|
||||
auto gather_prim_batch_dim = gather("gather", input_info("input0"), input_info("input1"), 2, {1, 3, 224, 224}, 2, true);
|
||||
auto gather_prim_support_neg_ind = gather("gather", input_info("input0"), input_info("input1"), 2, {1, 3, 224, 224}, 1, false);
|
||||
|
||||
ASSERT_EQ(gather_prim, gather_prim_eq);
|
||||
ASSERT_NE(gather_prim, gather_prim_axis);
|
||||
ASSERT_NE(gather_prim, gather_prim_batch_dim);
|
||||
ASSERT_NE(gather_prim, gather_prim_support_neg_ind);
|
||||
}
|
||||
|
||||
TEST(primitive_comparison, permute) {
|
||||
auto permute_prim = permute("permute", input_info("input"), {0, 1, 2, 3});
|
||||
auto permute_prim_eq = permute("permute_eq", input_info("input_eq"), {0, 1, 2, 3});
|
||||
auto permute_prim_order = permute("permute", input_info("input"), {3, 2, 1, 0});
|
||||
|
||||
ASSERT_EQ(permute_prim, permute_prim_eq);
|
||||
ASSERT_NE(permute_prim, permute_prim_order);
|
||||
}
|
Loading…
Reference in New Issue
Block a user