提交 ed797624 authored 作者: Adam Becker's avatar Adam Becker

add dtype checks for topk

上级 c89cab31
......@@ -250,6 +250,7 @@ class Test_TopK(unittest.TestCase):
xval = np.asarray([1]).astype(dtype)
yval = fn(xval)
assert yval == np.asarray([0], dtype=idx_dtype)
assert yval.dtype == np.dtype(idx_dtype)
@utt.parameterized.expand(product(
_dtypes, [-1, 0, None]))
......@@ -259,6 +260,7 @@ class Test_TopK(unittest.TestCase):
xval = np.asarray([1]).astype(dtype)
yval = fn(xval)
assert yval == xval
assert yval.dtype == xval.dtype
@utt.parameterized.expand(product(
_dtypes, _int_dtypes, [-1, 0, None]))
......@@ -270,6 +272,8 @@ class Test_TopK(unittest.TestCase):
yvval, yival = fn(xval)
assert yival == np.asarray([0], dtype=idx_dtype)
assert np.allclose(xval, yvval)
assert yvval.dtype == xval.dtype
assert yival.dtype == np.dtype(idx_dtype)
@utt.parameterized.expand(chain(
product(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论