提交 463dad0c authored 作者: Frederic's avatar Frederic

Fix opt crash of local_gpu_lazy_ifelse.

Reported by Ryan Price
上级 2faeb62c
......@@ -403,7 +403,12 @@ def local_gpu_lazy_ifelse(node):
host_input = node.inputs[0]
if (host_input.owner and
isinstance(host_input.owner.op, theano.ifelse.IfElse) and
not host_input.owner.op.gpu):
not host_input.owner.op.gpu and
# If there is more than 1 output, we can't replace it
# here with a local optimizer as we replace the
# GpuFromHost node and the other outputs of the if won't be
# replaced.
host_input.owner.op.n_outs == 1):
gpu_ifelse = theano.ifelse.IfElse(host_input.owner.op.n_outs,
gpu=True)
......
......@@ -159,6 +159,49 @@ class test_ifelse(unittest.TestCase, utt.TestOptimizationMixin):
assert numpy.all(outs_0[2] == 1.)
assert numpy.all(outs_0[3] == 1.)
def test_multiple_out_crash(self):
    """Regression test for an ifelse with four outputs.

    Builds an ``ifelse`` over four (symbolic input, shared variable)
    pairs, uses only its first output in an update, compiles the
    function with ``self.mode``, checks the compiled function against
    ``self.get_ifelse(4)`` and finally calls it once.

    This test failed up to commit 2faeb62c38.
    """
    def random_array(shape):
        # Fresh random data of the requested shape, cast to self.dtype.
        return numpy.asarray(numpy.random.random(shape),
                             dtype=self.dtype)

    # Shared parameters; they form the "else" branch of the ifelse.
    p = [self.shared(random_array(shape))
         for shape in ([4, 8], 8, [8, 3], 3)]

    # In the code that triggered the crash these variables were the
    # result of applying scan.
    symbolic_inputs = [
        tensor.ftensor3('ft0'),
        tensor.fmatrix('fm1'),
        tensor.ftensor3('ft2'),
        tensor.fmatrix('fm3'),
    ]
    # Only the last iteration of each one is kept.
    fsub = [variable[-1] for variable in symbolic_inputs]

    acc = theano.tensor.constant(1, 'int8') >= 0
    new_positions = theano.ifelse.ifelse(acc, fsub, p)
    # Only the first of the four ifelse outputs is used.
    new_updates = [(p[0], new_positions[0])]

    f = theano.function(symbolic_inputs, [],
                        updates=new_updates, mode=self.mode)
    self.assertFunctionContains1(f, self.get_ifelse(4))

    # Concrete values, one per symbolic input, in the same order.
    values = [random_array(shape)
              for shape in ([19, 4, 8], [19, 8], [19, 8, 3], [19, 3])]
    f(*values)
def test_dtype_mismatch(self):
rng = numpy.random.RandomState(utt.fetch_seed())
data = rng.rand(5).astype(self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论