[PT FE] Add aten::log10 (#20621)
* Add log10 operator and test
* Fix
* Update test_log.py

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
04c766e9f1
commit
bc463e886b
@ -41,6 +41,18 @@ OutputVector translate_log2(const NodeContext& context) {
|
||||
return {res};
|
||||
};
|
||||
|
||||
OutputVector translate_log10(const NodeContext& context) {
    // aten::log10(input): element-wise base-10 logarithm, implemented via the
    // change-of-base identity log10(x) = ln(x) / ln(10), matching the approach
    // used by translate_log2 above.
    num_inputs_check(context, 1, 1);
    auto input = context.get_input(0);
    // Promote to f32 before taking the logarithm — presumably so that integer
    // inputs are handled as well (same promotion as the sibling log converters).
    input = context.mark_node(std::make_shared<v0::Convert>(input, element::f32));
    // Scalar constant 10 and its natural log form the change-of-base divisor.
    auto base = context.mark_node(v0::Constant::create(element::f32, Shape{}, {10}));
    auto ln_base = context.mark_node(std::make_shared<v0::Log>(base));
    auto ln_input = context.mark_node(std::make_shared<v0::Log>(input));
    auto result = context.mark_node(std::make_shared<v1::Divide>(ln_input, ln_base));
    return {result};
};
|
||||
|
||||
OutputVector translate_logsumexp(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto input = context.get_input(0);
|
||||
|
@ -104,6 +104,7 @@ OP_CONVERTER(translate_log);
|
||||
OP_CONVERTER(translate_log1p);
|
||||
OP_CONVERTER(translate_log_softmax);
|
||||
OP_CONVERTER(translate_log2);
|
||||
OP_CONVERTER(translate_log10);
|
||||
OP_CONVERTER(translate_logsumexp);
|
||||
OP_CONVERTER(translate_loop);
|
||||
OP_CONVERTER(translate_masked_fill);
|
||||
@ -387,6 +388,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::log1p_", op::inplace_op<op::translate_log1p>},
|
||||
{"aten::log2", op::translate_log2},
|
||||
{"aten::log2_", op::inplace_op<op::translate_log2>},
|
||||
{"aten::log10", op::translate_log10},
|
||||
{"aten::log10_", op::inplace_op<op::translate_log10>},
|
||||
{"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
|
||||
{"aten::masked_fill", op::translate_masked_fill},
|
||||
{"aten::masked_fill_", op::inplace_op<op::translate_masked_fill>},
|
||||
|
@ -18,6 +18,8 @@ class TestLog(PytorchLayerTest):
|
||||
"log_": torch.log_,
|
||||
"log2": torch.log2,
|
||||
"log2_": torch.log2_,
|
||||
"log10": torch.log10,
|
||||
"log10_": torch.log10_,
|
||||
"log1p": torch.log1p,
|
||||
"log1p_": torch.log1p_
|
||||
}
|
||||
@ -45,6 +47,9 @@ class TestLog(PytorchLayerTest):
|
||||
["log2", "float32"],
|
||||
["log2", "int32"],
|
||||
["log2_", "float32"],
|
||||
["log10", "float32"],
|
||||
["log10", "int32"],
|
||||
["log10_", "float32"],
|
||||
["log1p", "float32"],
|
||||
["log1p", "int32"],
|
||||
["log1p_", "float32"]])
|
||||
|
Loading…
Reference in New Issue
Block a user