提交 8495d077 authored 作者: Frederic Bastien's avatar Frederic Bastien

For empty list/tuple, use floatX for the dtype instead of the numpy default (float64)

上级 421ca4bd
......@@ -390,7 +390,8 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
# it will only work if the long fits in int64. For NumPy 1.7.1+,
# it will work if the long fits in int64 or uint64.
x_ = numpy.asarray(x)
if x_.size == 0 and not hasattr(x, 'dtype'):
x_ = numpy.asarray(x, dtype=config.floatX)
assert type(x_) in [numpy.ndarray, numpy.memmap]
bcastable = [d == 1 for d in x_.shape]
......
......@@ -7118,6 +7118,16 @@ class T_as_tensor_variable(unittest.TestCase):
new_inp[...] = inp
x = as_tensor_variable(new_inp)
def test_empty_dtype(self):
old = theano.config.floatX
for dtype in ['float16', 'float32', 'float64']:
try:
theano.config.floatX = dtype
assert theano.tensor.as_tensor_variable(()).dtype == dtype
assert theano.tensor.as_tensor_variable([]).dtype == dtype
finally:
theano.config.floatX = old
class test_complex_mod(unittest.TestCase):
"""Make sure % fails on complex numbers."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论