提交 af3d84bd authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix problems in optimization.

上级 7d898822
...@@ -10,6 +10,7 @@ from theano.compile.ops import shape_i ...@@ -10,6 +10,7 @@ from theano.compile.ops import shape_i
from theano.gof import (local_optimizer, EquilibriumDB, from theano.gof import (local_optimizer, EquilibriumDB,
SequenceDB, Optimizer, toolbox) SequenceDB, Optimizer, toolbox)
from theano.gof.optdb import LocalGroupDB from theano.gof.optdb import LocalGroupDB
from theano.ifelse import IfElse
from theano.scalar.basic import Scalar, Pow, Cast from theano.scalar.basic import Scalar, Pow, Cast
from theano.scan_module import scan_utils, scan_op, scan_opt from theano.scan_module import scan_utils, scan_op, scan_opt
...@@ -529,9 +530,9 @@ def local_gpu_pdbbreakpoint_op(node): ...@@ -529,9 +530,9 @@ def local_gpu_pdbbreakpoint_op(node):
def local_gpua_lazy_ifelse(node, context_name): def local_gpua_lazy_ifelse(node, context_name):
if node.op.gpu: if node.op.gpu:
return return
c = nodes.inputs[0] c = node.inputs[0]
outs = [as_gpuarray_variable(v, context_name) for v in node.inputs[1:]] inps = [as_gpuarray_variable(v, context_name) for v in node.inputs[1:]]
return IfElse(node.op.n_outs, gpu=True)(c, *outs, return_list=True) return IfElse(node.op.n_outs, gpu=True)(c, *inps, return_list=True)
@register_opt('fast_compile') @register_opt('fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论