Skip to content

Commit 4f56c95

Browse files
committed
Add ONNX torch_topk pnnx regression test
1 parent e6f594c commit 4f56c95

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

tools/pnnx/tests/onnx/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ pnnx_onnx_add_test(torch_split)
191191
pnnx_onnx_add_test(torch_squeeze)
192192
pnnx_onnx_add_test(torch_stack)
193193
pnnx_onnx_add_test(torch_sum)
194+
pnnx_onnx_add_test(torch_topk)
194195
pnnx_onnx_add_test(torch_transpose)
195196
pnnx_onnx_add_test(torch_unbind)
196197
pnnx_onnx_add_test(torch_unsqueeze)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2026 Tencent
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
import torch
5+
import torch.nn as nn
6+
7+
8+
class Model(nn.Module):
9+
def __init__(self):
10+
super(Model, self).__init__()
11+
12+
def forward(self, x, y, z):
13+
x_values, x_indices = torch.topk(
14+
x, 2, dim=1, largest=True, sorted=True
15+
)
16+
y_values, y_indices = torch.topk(
17+
y, 4, dim=3, largest=False, sorted=True
18+
)
19+
z_values, z_indices = torch.topk(
20+
z, 3, dim=0, largest=True, sorted=True
21+
)
22+
return x_values, x_indices, y_values, y_indices, z_values, z_indices
23+
24+
25+
def test():
26+
net = Model()
27+
net.eval()
28+
29+
torch.manual_seed(0)
30+
x = torch.rand(1, 3, 16)
31+
y = torch.rand(1, 5, 9, 11)
32+
z = torch.rand(14, 8, 5, 9, 10)
33+
34+
a = net(x, y, z)
35+
36+
# export onnx
37+
torch.onnx.export(net, (x, y, z), "test_torch_topk.onnx")
38+
39+
# onnx to pnnx
40+
import os
41+
42+
os.system(
43+
"../../src/pnnx test_torch_topk.onnx "
44+
"inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]"
45+
)
46+
47+
# pnnx inference
48+
import test_torch_topk_pnnx
49+
b = test_torch_topk_pnnx.test_inference()
50+
51+
for a0, b0 in zip(a, b):
52+
if not torch.equal(a0, b0):
53+
return False
54+
return True
55+
56+
57+
if __name__ == "__main__":
58+
if test():
59+
exit(0)
60+
else:
61+
exit(1)

0 commit comments

Comments
 (0)