提交 7cc17ae7 authored 作者: Frederic Bastien's avatar Frederic Bastien

make tensor.stack generate MakeVector instead of Join more often. test it.

上级 d1c3d45e
...@@ -3271,7 +3271,8 @@ def stack(*tensors): ...@@ -3271,7 +3271,8 @@ def stack(*tensors):
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts # It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
if numpy.all([isinstance(t, Variable) and\ if numpy.all([isinstance(t, Variable) and\
isinstance(t.type, TensorType) and\ isinstance(t.type, TensorType) and\
t.ndim==0 and t.type==tensors[0].type\ t.ndim==0 and \
t.type.__class__==tensors[0].type.__class__\
for t in tensors]): for t in tensors]):
return theano.tensor.opt.MakeVector(scal.upcast(*[i.dtype for i in tensors]))(*tensors) return theano.tensor.opt.MakeVector(scal.upcast(*[i.dtype for i in tensors]))(*tensors)
return join(0, *[shape_padleft(t, 1) for t in tensors]) return join(0, *[shape_padleft(t, 1) for t in tensors])
......
...@@ -1552,6 +1552,20 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -1552,6 +1552,20 @@ class T_Join_and_Split(unittest.TestCase):
assert len([n for n in e if isinstance(n, Join)]) == 0 assert len([n for n in e if isinstance(n, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX assert f.maker.env.outputs[0].dtype == config.floatX
def test_stack_scalar_make_vector_dtype(self):
'''Test that calling stack() on scalars instantiates MakeVector,
event when the scalar don't have the same dtype.'''
a = tensor.iscalar('a')
b = tensor.lscalar('b')
s = stack(a, b, a, b)
f = function([a,b], s)
val = f(1,2)
self.failUnless(numpy.all(val == [1,2,1,2]))
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op,opt.MakeVector)]) > 0
assert len([n for n in e if isinstance(n, Join)]) == 0
assert f.maker.env.outputs[0].dtype == 'int64'
def test_join_vector(self): def test_join_vector(self):
a = as_tensor_variable(numpy.array([1, 2, 3])) a = as_tensor_variable(numpy.array([1, 2, 3]))
b = as_tensor_variable(numpy.array([7, 8, 9])) b = as_tensor_variable(numpy.array([7, 8, 9]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论