提交 04e9d2a9 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #4898 from nouiz/gpu_assert

Lift more assert
...@@ -2224,7 +2224,7 @@ def local_gpualloc(node): ...@@ -2224,7 +2224,7 @@ def local_gpualloc(node):
@register_opt() @register_opt()
@local_optimizer([theano.tensor.opt.Assert]) @local_optimizer([theano.tensor.opt.Assert, GpuFromHost])
def local_assert(node): def local_assert(node):
if (isinstance(node.op, theano.tensor.opt.Assert) and if (isinstance(node.op, theano.tensor.opt.Assert) and
node.inputs[0].owner and node.inputs[0].owner and
...@@ -2232,6 +2232,13 @@ def local_assert(node): ...@@ -2232,6 +2232,13 @@ def local_assert(node):
HostFromGpu)): HostFromGpu)):
return [host_from_gpu(node.op(node.inputs[0].owner.inputs[0], return [host_from_gpu(node.op(node.inputs[0].owner.inputs[0],
*node.inputs[1:]))] *node.inputs[1:]))]
elif (isinstance(node.op, GpuFromHost) and
node.inputs[0].owner and
isinstance(node.inputs[0].owner.op,
theano.tensor.opt.Assert)):
a = node.inputs[0].owner
new = a.op(gpu_from_host(a.inputs[0]), *a.inputs[1:])
return [new]
@register_opt() @register_opt()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论