提交 8c607e8f authored 作者: Frederic Bastien's avatar Frederic Bastien

Remove opt warning with mixed dtype

上级 554be9cd
...@@ -1949,9 +1949,10 @@ def local_gpu_join(node): ...@@ -1949,9 +1949,10 @@ def local_gpu_join(node):
# print "OPT: axis_and_tensors=", axis_and_tensors # print "OPT: axis_and_tensors=", axis_and_tensors
matches = [(t.owner is not None and matches = [t.dtype == 'float32' and
((t.owner is not None and
isinstance(t.owner.op, HostFromGpu)) or isinstance(t.owner.op, HostFromGpu)) or
isinstance(t, gof.Constant) for t in axis_and_tensors[1:]] isinstance(t, gof.Constant)) for t in axis_and_tensors[1:]]
# print "OPT: matches =", matches # print "OPT: matches =", matches
# if all input tensors are host_from_gpu'ified # if all input tensors are host_from_gpu'ified
......
...@@ -300,6 +300,21 @@ def test_opt_gpujoin_onlyajoin(): ...@@ -300,6 +300,21 @@ def test_opt_gpujoin_onlyajoin():
assert numpy.all(f() == numpy.concatenate([_a, _b], axis=1)) assert numpy.all(f() == numpy.concatenate([_a, _b], axis=1))
# test mixed dtype
_b = numpy.asarray([[5, 6, 7], [8, 9, 10]], dtype='float64')
b = theano.tensor.constant(_b)
c = tensor.join(1, a, b)
f = theano.function([], c, mode=mode_with_gpu)
f()
graph_nodes = f.maker.fgraph.toposort()
assert isinstance(graph_nodes[-1].op, theano.tensor.Join)
assert numpy.all(f() == numpy.concatenate([_a, _b], axis=1))
def test_opt_gpujoin_joinvectors_elemwise_then_minusone(): def test_opt_gpujoin_joinvectors_elemwise_then_minusone():
# from a bug in gpu normal sampling # from a bug in gpu normal sampling
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论