提交 84898822 authored 作者: Frederic Bastien's avatar Frederic Bastien

make tensor.stack generate MakeVector when he receive python/numpy scalar.

上级 83ff805b
...@@ -3269,12 +3269,15 @@ def stack(*tensors): ...@@ -3269,12 +3269,15 @@ def stack(*tensors):
raise Exception('theano.tensor.stack(*tensors) must have at least one parameter') raise Exception('theano.tensor.stack(*tensors) must have at least one parameter')
# If all tensors are scalars of the same type, call make_vector. # If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts # It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
if numpy.all([isinstance(t, Variable) and\ if numpy.all([(isinstance(t, Variable) and
isinstance(t.type, TensorType) and\ isinstance(t.type, TensorType) and
t.ndim==0 and \ t.ndim==0 and
t.type.__class__==tensors[0].type.__class__\ t.type.__class__==tensors[0].type.__class__)
or isinstance(t, (numpy.number, float, int, complex))#in case their is direct int
for t in tensors]): for t in tensors]):
return theano.tensor.opt.MakeVector(scal.upcast(*[i.dtype for i in tensors]))(*tensors) tensors = map(as_tensor_variable,tensors)#in case their is direct int
dtype = scal.upcast(*[i.dtype for i in tensors])
return theano.tensor.opt.MakeVector(dtype)(*tensors)
return join(0, *[shape_padleft(t, 1) for t in tensors]) return join(0, *[shape_padleft(t, 1) for t in tensors])
@constructor @constructor
......
...@@ -1566,6 +1566,20 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -1566,6 +1566,20 @@ class T_Join_and_Split(unittest.TestCase):
assert len([n for n in e if isinstance(n, Join)]) == 0 assert len([n for n in e if isinstance(n, Join)]) == 0
assert f.maker.env.outputs[0].dtype == 'int64' assert f.maker.env.outputs[0].dtype == 'int64'
def test_stack_scalar_make_vector_constant(self):
'''Test that calling stack() on scalars instantiates MakeVector,
event when the scalar are simple int type.'''
a = tensor.iscalar('a')
b = tensor.lscalar('b')
s = stack(a,b,10, numpy.int8(3))
f = function([a,b], s)
val = f(1,2)
self.failUnless(numpy.all(val == [1,2,10,3]))
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op,opt.MakeVector)]) > 0
assert len([n for n in e if isinstance(n, Join)]) == 0
assert f.maker.env.outputs[0].dtype == 'int64'
def test_join_vector(self): def test_join_vector(self):
a = as_tensor_variable(numpy.array([1, 2, 3])) a = as_tensor_variable(numpy.array([1, 2, 3]))
b = as_tensor_variable(numpy.array([7, 8, 9])) b = as_tensor_variable(numpy.array([7, 8, 9]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论