large_batch_opt (#11951)
This commit is contained in:
parent
a4e6cda7e8
commit
1288706589
@ -1036,7 +1036,7 @@ void MVN::MVNJitExecutor::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, co
|
|||||||
size_t src_stride_size = static_cast<size_t>(blk_size * src_data_size);
|
size_t src_stride_size = static_cast<size_t>(blk_size * src_data_size);
|
||||||
size_t dst_stride_size = static_cast<size_t>(blk_size * dst_data_size);
|
size_t dst_stride_size = static_cast<size_t>(blk_size * dst_data_size);
|
||||||
|
|
||||||
for (size_t b = 0lu; b < N; b++) {
|
parallel_for(N, [&](int b) {
|
||||||
size_t cb = b * C3;
|
size_t cb = b * C3;
|
||||||
if (mvnAttrs.execAcrossChannels_) {
|
if (mvnAttrs.execAcrossChannels_) {
|
||||||
// Calculate mean value for one instance in batch
|
// Calculate mean value for one instance in batch
|
||||||
@ -1153,7 +1153,7 @@ void MVN::MVNJitExecutor::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, co
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
|
void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
|
||||||
@ -1166,7 +1166,7 @@ void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
|
|||||||
size_t C2 = C1 * D;
|
size_t C2 = C1 * D;
|
||||||
size_t C3 = C2 * C;
|
size_t C3 = C2 * C;
|
||||||
|
|
||||||
for (size_t b = 0lu; b < N; b++) {
|
parallel_for(N, [&](int b) {
|
||||||
size_t cb = b * C3;
|
size_t cb = b * C3;
|
||||||
if (mvnAttrs.execAcrossChannels_) {
|
if (mvnAttrs.execAcrossChannels_) {
|
||||||
// Parallel sum for each channel for mean
|
// Parallel sum for each channel for mean
|
||||||
@ -1251,7 +1251,7 @@ void MVN::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_) {
|
void MVN::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_) {
|
||||||
|
Loading…
Reference in New Issue
Block a user