提交 495182f1 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix MakeVector.make_node to upcast its input correctly. fix stack to don't…

fix MakeVector.make_node to upcast its input correctly. fix stack to don't downcast float* to int64 and to don't upcast inputs more then necessary.
上级 5dca0a9c
......@@ -2941,7 +2941,7 @@ def stack(*tensors):
isinstance(t.type, TensorType) and\
t.ndim==0 and t.type==tensors[0].type\
for t in tensors]):
return theano.tensor.opt.make_vector(*tensors)
return theano.tensor.opt.MakeVector(scal.upcast(*[i.dtype for i in tensors]))(*tensors)
return join(0, *[shape_padleft(t, 1) for t in tensors])
@constructor
......
......@@ -226,7 +226,7 @@ class MakeVector(T.Op):
return hash(type(self)) ^ hash(self.dtype)
def make_node(self, *inputs):
inputs = map(T.as_tensor_variable, inputs)
if not all(a.type == inputs[0].type for a in inputs):
if not all(a.type == inputs[0].type for a in inputs) or inputs[0].dtype != self.dtype:
dtype=theano.scalar.upcast(self.dtype,*[i.dtype for i in inputs])
#upcast the input to the determined dtype, but don't upcast downcast anything
assert dtype==self.dtype, "Upcast the input of MakeVector to dtype gived in init without precissino loss only."
......
......@@ -7,7 +7,7 @@ from theano.tensor import inplace
import unittest
from copy import copy
from theano import compile
from theano import compile, config
from theano import gradient
from theano import gof
from theano.gof.python25 import any, all
......@@ -1033,7 +1033,7 @@ class T_Join_and_Split(unittest.TestCase):
def test_stack_scalar_make_vector(self):
'''Test that calling stack() on scalars instantiates MakeVector,
not Join.'''
not Join. Test that the floatX dtype stay floatX, not down casted to int64'''
a = tensor.scalar('a')
b = tensor.scalar('b')
s = stack(a, b, a, b)
......@@ -1042,8 +1042,9 @@ class T_Join_and_Split(unittest.TestCase):
print val
self.failUnless(numpy.all(val == [1,2,1,2]))
e = f.maker.env.toposort()
assert len([n for n in e if n.op == opt.make_vector]) > 0
assert len([n for n in e if isinstance(n.op,opt.MakeVector)]) > 0
assert len([n for n in e if isinstance(n, Join)]) == 0
assert f.maker.env.outputs[0].dtype == config.floatX
def test_join_vector(self):
a = as_tensor_variable(numpy.array([1, 2, 3]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论