提交 71b73411 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

make test easier to read

上级 4a241779
......@@ -1055,7 +1055,8 @@ class test_fusion(unittest.TestCase):
if gpu:
import theano.sandbox.cuda as cuda
topo_ = [x for x in topo if not isinstance(
x.op,cuda.basic_ops.GpuFromHost) and not isinstance(x.op,cuda.basic_ops.HostFromGpu)]
x.op, (cuda.basic_ops.GpuFromHost, cuda.basic_ops.HostFromGpu))]
gpu_ = [x for x in topo if isinstance(x.op,
cuda.basic_ops.GpuFromHost)]
if not len(gpu_) == len(sym_inputs):
......@@ -1066,13 +1067,16 @@ class test_fusion(unittest.TestCase):
if not len(topo_) == nb_elemwise:
fail3.append((id, topo_, nb_elemwise))
if nb_elemwise == 1:
# check that the number of input to the Composite Elemwise is ok
# when there is not variable that appear multiple time the in input
# of g
assert ((numpy.sum([not isinstance(x, theano.gof.Constant)
for x in topo_[0].inputs]) ==
len(sym_inputs)) or
len(set(g.owner.inputs)) != len(g.owner.inputs))
# if no variable appears multiple times in the
# input of g,
# check that the number of input to the Composite
# Elemwise is ok
if len(set(g.owner.inputs)) == len(g.owner.inputs):
expected_len_sym_inputs = numpy.sum(
[not isinstance(x, theano.gof.Constant)
for x in topo_[0].inputs])
assert expected_len_sym_inputs == len(sym_inputs)
if not out_dtype == out.dtype:
fail4.append((id, out_dtype, out.dtype))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论