提交 5a0e695e authored 作者: Frederic Bastien's avatar Frederic Bastien

make tensor.stack() work when the first element is a python/numpy int, float, complex

上级 84898822
......@@ -33,6 +33,9 @@ def _info(*msg):
def _warn(*msg):
_logger.warn(' '.join(msg))
#This is needed as we will hide it later
python_complex=complex
def check_equal_numpy(x, y):
"""
Returns True iff x and y are equal (checks the dtype and
......@@ -3269,11 +3272,14 @@ def stack(*tensors):
raise Exception('theano.tensor.stack(*tensors) must have at least one parameter')
# If all tensors are scalars of the same type, call make_vector.
# It makes the graph simpler, by not adding DimShuffles and Rebroadcasts
if numpy.all([(isinstance(t, Variable) and
isinstance(t.type, TensorType) and
t.ndim==0 and
t.type.__class__==tensors[0].type.__class__)
or isinstance(t, (numpy.number, float, int, complex))#in case their is direct int
if isinstance(tensors[0], (numpy.number, float, int, python_complex)):
tensors=list(tensors)
tensors[0]=as_tensor_variable(tensors[0])
if numpy.all([isinstance(t, (numpy.number, float, int, python_complex))#in case their is direct int
or (isinstance(t, Variable) and
isinstance(t.type, TensorType) and
t.ndim==0 and
t.type.__class__==tensors[0].type.__class__)
for t in tensors]):
tensors = map(as_tensor_variable,tensors)#in case their is direct int
dtype = scal.upcast(*[i.dtype for i in tensors])
......
......@@ -1571,10 +1571,10 @@ class T_Join_and_Split(unittest.TestCase):
event when the scalar are simple int type.'''
a = tensor.iscalar('a')
b = tensor.lscalar('b')
s = stack(a,b,10, numpy.int8(3))
s = stack(10,a,b, numpy.int8(3))
f = function([a,b], s)
val = f(1,2)
self.failUnless(numpy.all(val == [1,2,10,3]))
self.failUnless(numpy.all(val == [10,1,2,3]))
e = f.maker.env.toposort()
assert len([n for n in e if isinstance(n.op,opt.MakeVector)]) > 0
assert len([n for n in e if isinstance(n, Join)]) == 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论