提交 18dd2955 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5082 from nouiz/constant_floatX

For empty list/tuple, use floatX for the dtype instead of the numpy d…
...@@ -390,7 +390,8 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None): ...@@ -390,7 +390,8 @@ def constant_or_value(x, rtype, name=None, ndim=None, dtype=None):
# it will only work if the long fits in int64. For NumPy 1.7.1+, # it will only work if the long fits in int64. For NumPy 1.7.1+,
# it will work if the long fits in int64 or uint64. # it will work if the long fits in int64 or uint64.
x_ = numpy.asarray(x) x_ = numpy.asarray(x)
if x_.size == 0 and not hasattr(x, 'dtype'):
x_ = numpy.asarray(x, dtype=config.floatX)
assert type(x_) in [numpy.ndarray, numpy.memmap] assert type(x_) in [numpy.ndarray, numpy.memmap]
bcastable = [d == 1 for d in x_.shape] bcastable = [d == 1 for d in x_.shape]
......
...@@ -7127,6 +7127,16 @@ class T_as_tensor_variable(unittest.TestCase): ...@@ -7127,6 +7127,16 @@ class T_as_tensor_variable(unittest.TestCase):
new_inp[...] = inp new_inp[...] = inp
x = as_tensor_variable(new_inp) x = as_tensor_variable(new_inp)
def test_empty_dtype(self):
old = theano.config.floatX
for dtype in ['float16', 'float32', 'float64']:
try:
theano.config.floatX = dtype
assert theano.tensor.as_tensor_variable(()).dtype == dtype
assert theano.tensor.as_tensor_variable([]).dtype == dtype
finally:
theano.config.floatX = old
class test_complex_mod(unittest.TestCase): class test_complex_mod(unittest.TestCase):
"""Make sure % fails on complex numbers.""" """Make sure % fails on complex numbers."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论