提交 495182f1 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix MakeVector.make_node to upcast its input correctly. fix stack to not…

fix MakeVector.make_node to upcast its input correctly. fix stack to not downcast float* to int64 and to not upcast inputs more than necessary.
上级 5dca0a9c
...@@ -2941,7 +2941,7 @@ def stack(*tensors): ...@@ -2941,7 +2941,7 @@ def stack(*tensors):
isinstance(t.type, TensorType) and\ isinstance(t.type, TensorType) and\
t.ndim==0 and t.type==tensors[0].type\ t.ndim==0 and t.type==tensors[0].type\
for t in tensors]): for t in tensors]):
return theano.tensor.opt.make_vector(*tensors) return theano.tensor.opt.MakeVector(scal.upcast(*[i.dtype for i in tensors]))(*tensors)
return join(0, *[shape_padleft(t, 1) for t in tensors]) return join(0, *[shape_padleft(t, 1) for t in tensors])
@constructor @constructor
......
...@@ -226,7 +226,7 @@ class MakeVector(T.Op): ...@@ -226,7 +226,7 @@ class MakeVector(T.Op):
return hash(type(self)) ^ hash(self.dtype) return hash(type(self)) ^ hash(self.dtype)
def make_node(self, *inputs): def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs) inputs = map(T.as_tensor_variable, inputs)
if not all(a.type == inputs[0].type for a in inputs): if not all(a.type == inputs[0].type for a in inputs) or inputs[0].dtype != self.dtype:
dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs]) dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs])
#upcast the input to the determined dtype, but don't upcast downcast anything #upcast the input to the determined dtype, but don't upcast downcast anything
assert dtype==self.dtype, "Upcast the input of MakeVector to dtype gived in init without precissino loss only." assert dtype==self.dtype, "Upcast the input of MakeVector to dtype gived in init without precissino loss only."
......
...@@ -7,7 +7,7 @@ from theano.tensor import inplace ...@@ -7,7 +7,7 @@ from theano.tensor import inplace
import unittest import unittest
from copy import copy from copy import copy
from theano import compile from theano import compile, config
from theano import gradient from theano import gradient
from theano import gof from theano import gof
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
...@@ -1033,7 +1033,7 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -1033,7 +1033,7 @@ class T_Join_and_Split(unittest.TestCase):
def test_stack_scalar_make_vector(self): def test_stack_scalar_make_vector(self):
'''Test that calling stack() on scalars instantiates MakeVector, '''Test that calling stack() on scalars instantiates MakeVector,
not Join.''' not Join. Test that the floatX dtype stay floatX, not down casted to int64'''
a = tensor.scalar('a') a = tensor.scalar('a')
b = tensor.scalar('b') b = tensor.scalar('b')
s = stack(a, b, a, b) s = stack(a, b, a, b)
...@@ -1042,8 +1042,9 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -1042,8 +1042,9 @@ class T_Join_and_Split(unittest.TestCase):
print val print val
self.failUnless(numpy.all(val == [1,2,1,2])) self.failUnless(numpy.all(val == [1,2,1,2]))
e = f.maker.env.toposort() e = f.maker.env.toposort()
assert len([n for n in e if n.op == opt.make_vector]) > 0 assert len([n for n in e if isinstance(n.op,opt.MakeVector)]) > 0
assert len([n for n in e if isinstance(n, Join)]) == 0 assert len([n for n in e if isinstance(n, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
def test_join_vector(self): def test_join_vector(self):
a = as_tensor_variable(numpy.array([1, 2, 3])) a = as_tensor_variable(numpy.array([1, 2, 3]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论