提交 1df1f91a authored 作者: Frederic Bastien's avatar Frederic Bastien

Simplified assertion following code review

上级 b89dd3da
......@@ -79,12 +79,8 @@ def test_gemv1():
assert numpy.allclose(no_gpu_f(), gpu_f(), atol = atol)
assert numpy.allclose(no_gpu_f(), gpu_f2(), atol = atol)
# Assert that the gpu version actually uses gpu
assert sum([isinstance(node.op, blasop.GpuGemm) for node in
gpu_f.maker.env.toposort() ]) == 1
assert sum([isinstance(node.op, blasop.GpuGemm) for node in
gpu_f2.maker.env.toposort() ]) == 1
assert any([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f2.maker.env.toposort()])
assert any([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f.maker.env.toposort()])
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f2.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f.maker.env.toposort()]) == 1
def test_gemv2():
......@@ -103,12 +99,8 @@ def test_gemv2():
assert numpy.allclose(no_gpu_f(), gpu_f(), atol = atol)
assert numpy.allclose(no_gpu_f(), gpu_f2(), atol = atol)
# Assert that the gpu version actually uses gpu
assert sum([isinstance(node.op, blasop.GpuGemm) for node in
gpu_f.maker.env.toposort() ]) == 1
assert sum([isinstance(node.op, blasop.GpuGemm) for node in
gpu_f2.maker.env.toposort() ]) == 1
assert any([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f2.maker.env.toposort()])
assert any([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f.maker.env.toposort()])
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f2.maker.env.toposort()]) == 1
assert sum([node.op is cuda.blas.gpu_gemm_inplace for node in gpu_f.maker.env.toposort()]) == 1
if __name__=='__main__':
test_dot_vm()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论