提交 0044b695 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small. Fix test in DebugMode (the transfer was complaining about inf values,…

Small. Fix test in DebugMode (the transfer was complaining about inf values, also speed up optimization time). Test gpubatchgemm inplace opt. flake8
上级 a6174c22
...@@ -1735,6 +1735,7 @@ def local_inplace_gpuagemm(node, inputs): ...@@ -1735,6 +1735,7 @@ def local_inplace_gpuagemm(node, inputs):
def local_inplace_gpuager(node, inputs): def local_inplace_gpuager(node, inputs):
return [gpuger_inplace(*inputs)] return [gpuger_inplace(*inputs)]
@inplace_allocempty(GpuGemmBatch, 0) @inplace_allocempty(GpuGemmBatch, 0)
def local_inplace_gpuagemmbatch(node, inputs): def local_inplace_gpuagemmbatch(node, inputs):
return [gpugemmbatch_inplace(*inputs)] return [gpugemmbatch_inplace(*inputs)]
......
...@@ -1235,7 +1235,8 @@ def local_gpua_gemmbatch(op, context_name, inputs, outputs): ...@@ -1235,7 +1235,8 @@ def local_gpua_gemmbatch(op, context_name, inputs, outputs):
if b.dtype != out_dtype: if b.dtype != out_dtype:
b = gpu_cast_op(b) b = gpu_cast_op(b)
c = tensor.AllocEmpty(out_dtype)(a.shape[0], a.shape[1], b.shape[2]) c = GpuAllocEmpty(out_dtype, context_name)(
a.shape[0], a.shape[1], b.shape[2])
out = gpugemmbatch_no_inplace(c, np.asarray(1.0, dtype=out_dtype), out = gpugemmbatch_no_inplace(c, np.asarray(1.0, dtype=out_dtype),
a, b, np.asarray(0.0, dtype=out_dtype)) a, b, np.asarray(0.0, dtype=out_dtype))
if len(output_dims) != 3: if len(output_dims) != 3:
......
...@@ -166,6 +166,7 @@ class TestGpuGemmBatchStrided(TestCase): ...@@ -166,6 +166,7 @@ class TestGpuGemmBatchStrided(TestCase):
x_num = np.arange(32 * 19 * 600, dtype=config.floatX).reshape((32, 19, 600)) x_num = np.arange(32 * 19 * 600, dtype=config.floatX).reshape((32, 19, 600))
y_num = np.arange(7 * 32 * 600, dtype=config.floatX).reshape((32, 7, 600)) y_num = np.arange(7 * 32 * 600, dtype=config.floatX).reshape((32, 7, 600))
f(x_num, y_num) f(x_num, y_num)
assert f.maker.fgraph.toposort()[-2].op.inplace
class TestGpuSger(TestGer): class TestGpuSger(TestGer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论