[CPU] Fix ScaledDotProductAttention accuracy problem (#21217)
commit 91660b1c05
parent 493a338ad2

@@ -123,7 +123,7 @@ struct MHAKernel {
     if (auto_causal)
         ncausal = kv_len - q_len + m + 1;
     for (size_t n = 0; n < ncausal; n++) {
-        auto* k = &present_key.at({b, h, n, 0});
+        auto* k = &present_key.at({b, h, n, 0}, true);
         attn_score[n] = dot_product(q, k, head_size, k_stride_s) * d_scale;
 
         // apply alibi tensor
@@ -154,7 +154,7 @@ struct MHAKernel {
     // linearly combine value
     word_vec.assign(head_size, 0.0f);
     for (size_t n = 0; n < ncausal; n++) {
-        auto* v = &present_value.at({b, h, n, 0});
+        auto* v = &present_value.at({b, h, n, 0}, true);
         accumulate(word_vec.data(), v, head_size, attn_score[n]);
     }
 
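
Note: the extra `true` argument passed to `present_key.at()` and `present_value.at()` in the two hunks above reads as a broadcast flag, so a K/V cache holding fewer heads than the query (for example a single KV head) can still be indexed with the query-head index. The snippet below is a minimal standalone sketch of that indexing idea under that assumption; `TensorView4D` is a hypothetical type for illustration, not the PlainTensor implementation.

#include <array>
#include <cstddef>
#include <vector>

// Hypothetical 4-D view used only for this sketch (B, H, L, S layout).
struct TensorView4D {
    std::vector<float> data;
    std::array<size_t, 4> dims;     // e.g. {B, Hk, kv_len, S}
    std::array<size_t, 4> strides;  // element strides per dimension

    // When broadcast is true, a size-1 dimension accepts any index and maps
    // it to 0, so one cached KV head can serve every query head.
    float& at(const std::array<size_t, 4>& idx, bool broadcast = false) {
        size_t offset = 0;
        for (size_t d = 0; d < 4; d++) {
            size_t i = (broadcast && dims[d] == 1) ? 0 : idx[d];
            offset += i * strides[d];
        }
        return data[offset];
    }
};
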
@@ -183,7 +183,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
     using tag = dnnl::memory::format_tag;
     using dt = dnnl::memory::data_type;
 
-    void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
+    void prepare_prim(dnnl::stream strm, size_t B, size_t H, size_t Hk, size_t q_len, size_t kv_len, size_t S, bool has_out_transpose) {
         auto make_dnnl_dims = [](const std::vector<size_t>& dims) {
             dnnl::memory::dims dnnl_dims(dims.size());
             for (size_t i = 0; i < dims.size(); i++)
@@ -192,7 +192,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         };
         auto qkv_dt = precision_of<T>::value == ov::element::f32 ? dt::f32 : dt::bf16;
         dnnl::memory::desc cur_q_md(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
-        dnnl::memory::desc cur_k_md(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd);
+        dnnl::memory::desc cur_k_md(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
         if (cur_q_md == q_md && cur_k_md == k_md)
             return;
 
@@ -204,7 +204,7 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         qk_prim = dnnl::matmul(qk_pd);
 
         weight_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, kv_len}), qkv_dt, tag::abcd);
-        v_md = dnnl::memory::desc(make_dnnl_dims({B, H, kv_len, S}), qkv_dt, tag::abcd);
+        v_md = dnnl::memory::desc(make_dnnl_dims({B, Hk, kv_len, S}), qkv_dt, tag::abcd);
         out_md = dnnl::memory::desc(make_dnnl_dims({B, H, q_len, S}), qkv_dt, tag::abcd);
         if (has_out_transpose)
             out_md = out_md.permute_axes({0, 2, 1, 3});
@@ -259,12 +259,13 @@ struct MHAKernel<ScaledDotProductAttention::KT_ONEDNN, T> {
         auto H = query.size(1);
         auto q_len = query.size(2);
         auto head_size = query.size(3);
+        auto Hk = present_key.size(1);
         auto kv_len = present_key.size(2);
 
         if (d_scale == 0.0f)
             d_scale = 1.0f / sqrt(head_size);
 
-        prepare_prim(strm, B, H, q_len, kv_len, head_size, has_out_transpose);
+        prepare_prim(strm, B, H, Hk, q_len, kv_len, head_size, has_out_transpose);
         exec_qk(strm, query, present_key);
 
         PlainTensor<float> score;
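
Note: with the new `Hk` parameter, the oneDNN path describes K and V with their own head count while Q, the attention weights, and the output keep `H`, which is what multi-query and grouped-query KV caches need. A small sketch of the usual head mapping follows, assuming `H` is a whole multiple of `Hk`; the helper name and the `h / (H / Hk)` convention are illustrative, not taken from this code.

#include <cassert>
#include <cstddef>

// Query tensor:   {B, H,  q_len,  S}
// Key/Value:      {B, Hk, kv_len, S}
// Attn weights:   {B, H,  q_len,  kv_len}
// Maps a query head to the KV head that backs it; with Hk == 1 every query
// head shares the single cached K/V slice.
inline size_t kv_head_for(size_t h, size_t H, size_t Hk) {
    assert(Hk > 0 && H % Hk == 0);  // assumed grouped-query layout
    return h / (H / Hk);
}
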
@@ -495,7 +496,7 @@ struct MHASingleToken {
     std::vector<float*> cs(q_len);
     for (size_t pq = 0; pq < q_len; pq++) {
         as[pq] = &query.at({b, h, pq, 0});
-        bs[pq] = &present_key.at({b_kv, h, pk, 0});
+        bs[pq] = &present_key.at({b_kv, h, pk, 0}, true);
         cs[pq] = &m_attn_w.at({b, h, pq, pk});
     }
     attn_dot_products(reinterpret_cast<void**>(as.data()),
@@ -543,7 +544,7 @@ struct MHASingleToken {
     size_t idx = 0;
     for (size_t iwork = start; iwork < end; ++iwork) {
         auto b_kv = beams ? beams.at({b, pv}) : b;
-        auto* v = &present_value.at({b_kv, h, pv, 0});
+        auto* v = &present_value.at({b_kv, h, pv, 0}, true);
         for (size_t pq = 0; pq < q_len; pq++) {
             outs[idx] = &m_temp.at({ithr, b, pq, h, 0});
             weights[idx] = m_attn_w.at({b, h, pq, pv});
@@ -636,8 +637,8 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
     PlainTensor<T> present_key, present_value;
 
     q_input.assert_dims({B, H, L1, S});
-    k_input.assert_dims({B, H, L0 + L1, S});
-    v_input.assert_dims({B, H, L0 + L1, S});
+    k_input.assert_dims({B, 0, L0 + L1, S}, true);
+    v_input.assert_dims({B, 0, L0 + L1, S}, true);
     m_query_emb = q_input;
     present_key = k_input;
     present_value = v_input;
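
Note: in the relaxed `assert_dims` calls above, the head dimension is given as 0 and an extra flag is passed, which reads as "accept any size here" so K/V inputs with fewer heads than the query still pass the shape check. A sketch of such a check under that assumed wildcard convention; this is not the PlainTensor API.

#include <cstddef>
#include <initializer_list>
#include <stdexcept>
#include <string>
#include <vector>

// Sketch: 0 acts as a wildcard dimension when allow_wildcard is set.
inline void check_dims(const std::vector<size_t>& actual,
                       std::initializer_list<size_t> expected,
                       bool allow_wildcard = false) {
    if (actual.size() != expected.size())
        throw std::runtime_error("rank mismatch");
    size_t axis = 0;
    for (size_t e : expected) {
        if (!(allow_wildcard && e == 0) && actual[axis] != e)
            throw std::runtime_error("dim mismatch at axis " + std::to_string(axis));
        ++axis;
    }
}
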
@@ -657,9 +658,11 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
     // no attn_mask but has scale, there is a 1-d fake attn_mask
     if (input_num > 3 && attn_mask.m_rank > 1) {
         assert(attn_mask);
-        auto num = std::accumulate(attn_mask.m_dims, attn_mask.m_dims + attn_mask.m_rank, size_t{1}, std::multiplies<size_t>());
-        num /= B * (L0 + L1);
-        attn_mask = attn_mask.reshape({B, 1, num, L0 + L1});
+        // spec requires at least 3, but torch sl test does use rank 2
+        if (attn_mask.m_rank == 2)
+            attn_mask = attn_mask.reshape({1, 1, attn_mask.m_dims[0], attn_mask.m_dims[1]});
+        else if (attn_mask.m_rank == 3)
+            attn_mask = attn_mask.reshape({1, attn_mask.m_dims[0], attn_mask.m_dims[1], attn_mask.m_dims[2]});
         auto_causal = false;
         use_attn_mask = true;
     } else {
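
Note: the new branch normalizes lower-rank attention masks to the 4-D layout the kernels expect by prepending size-1 axes: rank 2 becomes {1, 1, L1, L0+L1} and rank 3 becomes {1, H, L1, L0+L1}. The same shape-level normalization as a standalone sketch (the real code reshapes the tensor itself):

#include <cstddef>
#include <stdexcept>
#include <vector>

// Left-pads a mask shape with 1s until it is 4-D, mirroring the reshape
// logic in the hunk above.
inline std::vector<size_t> normalize_mask_shape(std::vector<size_t> shape) {
    if (shape.size() < 2 || shape.size() > 4)
        throw std::runtime_error("unsupported attention mask rank");
    while (shape.size() < 4)
        shape.insert(shape.begin(), 1);
    return shape;
}
// normalize_mask_shape({L1, L0 + L1})    -> {1, 1, L1, L0 + L1}
// normalize_mask_shape({H, L1, L0 + L1}) -> {1, H, L1, L0 + L1}
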
@@ -758,12 +761,19 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
         errorMessage = "Only ScaledDotProductAttention operation are supported";
         return false;
     }
-    // expect shape: [B, H, L, S]
-    const auto inRank = op->get_input_partial_shape(0).size();
+    // expect shape of q: [B, H, L, S]
+    auto inRank = op->get_input_partial_shape(0).size();
     if (inRank != 4u) {
         errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
         return false;
     }
+    if (op->get_input_size() > 3) {
+        inRank = op->get_input_partial_shape(3).size();
+        if (inRank > 4u) {
+            errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
+            return false;
+        }
+    }
     // using mha should be better for static shapes
     if (!op->is_dynamic()) {
         errorMessage = "Only run in dynamic mode";
@@ -18,21 +18,25 @@ namespace CPULayerTestsDefinitions {
 std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo<ScaledAttnCPUTestParams>& obj) {
     CPUSpecificParams cpuParams;
     ElementType inType;
-    InputShape inputShape;
+    std::vector<InputShape> inputShapes;
     bool is_causal;
     bool has_attn;
     bool has_scale;
     std::string targetDevice;
-    std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
+    std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = obj.param;
 
     std::ostringstream result;
     result << "netPRC=" << inType << "_";
-    result << "IS=" << ov::test::utils::partialShape2str({inputShape.first}) << "_";
+    result << "IS=";
+    for (const auto& inputShape : inputShapes) {
+        result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
+    }
     result << "TS=";
-    for (const auto& shape : inputShape.second) {
-        result << "(";
+    for (const auto& shapes : inputShapes) {
+        for (const auto& shape : shapes.second) {
         result << ov::test::utils::vec2str(shape);
-        result << ")_";
+        result << "_";
+        }
     }
     result << "is_causal=" << is_causal << "_";
     result << "has_attn=" << has_attn << "_";
@@ -46,8 +50,8 @@ std::string ScaledAttnLayerCPUTest::getTestCaseName(const testing::TestParamInfo
 void ScaledAttnLayerCPUTest::SetUp() {
     ElementType inType;
     CPUSpecificParams cpuParams;
-    InputShape inputShape;
-    std::tie(inType, inputShape, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
+    std::vector<InputShape> inputShapes;
+    std::tie(inType, inputShapes, is_causal, has_attn, has_scale, targetDevice, cpuParams) = this->GetParam();
 
     std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
     if (selectedType.empty()) {
@@ -58,12 +62,12 @@ void ScaledAttnLayerCPUTest::SetUp() {
         rel_threshold = 2e-2f;
     }
     selectedType = makeSelectedTypeStr(selectedType, inType);
-    init_input_shapes({inputShape});
+    init_input_shapes(inputShapes);
     ov::ParameterVector inputParams;
     // q,k,v
     inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
-    inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
-    inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
+    inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
+    inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
     inputParams[0]->set_friendly_name("q");
     inputParams[1]->set_friendly_name("k");
     inputParams[2]->set_friendly_name("v");
@@ -77,9 +81,7 @@ void ScaledAttnLayerCPUTest::SetUp() {
         inputParams.back()->set_friendly_name("scale");
     } else {
         if (has_attn) {
-            // attention_mask:[B, 1, 1, L0+L1]
-            ov::PartialShape attnShape{inputDynamicShapes[0][0], 1, 1, -1};
-            inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, attnShape));
+            inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[2]));
             inputParams.back()->set_friendly_name("attention_mask");
         }
         if (has_scale) {
@@ -106,14 +108,14 @@ void ScaledAttnLayerCPUTest::SetUp() {
 void ScaledAttnLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) {
     std::vector<ov::Shape> shapes(3);
     shapes[0] = targetInputStaticShapes[0];
-    shapes[1] = targetInputStaticShapes[0];
-    shapes[2] = targetInputStaticShapes[0];
+    shapes[1] = targetInputStaticShapes[1];
+    shapes[2] = targetInputStaticShapes[1];
     if (!has_attn && has_scale) {
         shapes.push_back(ov::Shape{});
         shapes.push_back(ov::Shape{1});
     } else {
         if (has_attn) {
-            shapes.push_back(ov::Shape{targetInputStaticShapes[0][0], 1, 1, targetInputStaticShapes[0][2]});
+            shapes.push_back(targetInputStaticShapes[2]);
         }
         if (has_scale) {
             shapes.push_back(ov::Shape{1});
@@ -13,12 +13,12 @@ using namespace ov::test;
 
 namespace CPULayerTestsDefinitions {
 
 typedef std::tuple<ElementType, // netPrecision
-                   InputShape, // shape
+                   std::vector<InputShape>, // shape
                    bool, // is_causal
                    bool, // has_attn
                    bool, // has_scale
                    std::string, // targetDevice
                    CPUSpecificParams>
     ScaledAttnCPUTestParams;
 
@@ -14,10 +14,52 @@ namespace CPULayerTestsDefinitions {
 namespace ScaledAttn {
 const auto cpuSpec = CPUSpecificParams{{}, {}, {"ref_any"}, "ref_any"};
 
-const std::vector<InputShape> shapes{
-    {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
-        {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
-    }
+const std::vector<std::vector<InputShape>> shapes{
+    // normal case, shapes of q,k,v are same
+    {
+        // q shape
+        {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
+            {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
+        },
+        // kv shape
+        {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
+            {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
+        },
+        // attn shape: [B, 1, 1, L0+L1]
+        {ov::test::InputShape{ov::PartialShape{-1, 1, 1, -1},
+            {ov::Shape{1, 1, 1, 100}, ov::Shape{1, 1, 1, 1}, ov::Shape{2, 1, 1, 10}}}
+        },
+    },
+    // heads number of kv is 1, attn mask: [B, H, L1, L0+L1]
+    {
+        // q shape
+        {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
+            {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
+        },
+        // kv shape
+        {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
+            {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
+        },
+        // attn shape
+        {ov::test::InputShape{ov::PartialShape{-1, 8, -1, -1},
+            {ov::Shape{1, 8, 100, 100}, ov::Shape{1, 8, 1, 1}, ov::Shape{2, 8, 10, 10}}}
+        },
+    },
+    // heads number of kv is 1, attn mask: [H, L1, L0+L1]
+    {
+        // q shape
+        {ov::test::InputShape{ov::PartialShape{-1, 8, -1, 64},
+            {ov::Shape{1, 8, 100, 64}, ov::Shape{1, 8, 1, 64}, ov::Shape{2, 8, 10, 64}}}
+        },
+        // kv shape
+        {ov::test::InputShape{ov::PartialShape{-1, 1, -1, 64},
+            {ov::Shape{1, 1, 100, 64}, ov::Shape{1, 1, 1, 64}, ov::Shape{2, 1, 10, 64}}}
+        },
+        // attn shape
+        {ov::test::InputShape{ov::PartialShape{8, -1, -1},
+            {ov::Shape{8, 100, 100}, ov::Shape{8, 1, 1}, ov::Shape{8, 10, 10}}}
+        },
+    },
 };
 
 const auto params = testing::Combine(testing::Values(ElementType::f32, ElementType::bf16),
@@ -75,7 +75,8 @@ class TestTimmConvertModel(TestConvertModel):
 
     @pytest.mark.parametrize("name", ["mobilevitv2_050.cvnets_in1k",
                                       "poolformerv2_s12.sail_in1k",
-                                      "vit_base_patch8_224.augreg_in21k"])
+                                      "vit_base_patch8_224.augreg_in21k",
+                                      "beit_base_patch16_224.in22k_ft_in22k"])
     @pytest.mark.precommit
     def test_convert_model_precommit(self, name, ie_device):
         self.run(name, None, ie_device)