提交 858cf0a3 authored 作者: James Bergstra's avatar James Bergstra

fixing bugreport by Josh Snyder - complex is shadowed in scalar/basic

上级 a4db34e1
...@@ -9,6 +9,10 @@ from theano.gof import Op, utils, Variable, Constant, Type, Apply, Env ...@@ -9,6 +9,10 @@ from theano.gof import Op, utils, Variable, Constant, Type, Apply, Env
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
from theano.configparser import config from theano.configparser import config
builtin_complex = complex
builtin_int = int
builtin_float = float
def upcast(dtype, *dtypes): def upcast(dtype, *dtypes):
z = numpy.zeros((), dtype = dtype) z = numpy.zeros((), dtype = dtype)
for dtype in dtypes: for dtype in dtypes:
...@@ -31,7 +35,7 @@ def as_scalar(x, name = None): ...@@ -31,7 +35,7 @@ def as_scalar(x, name = None):
raise TypeError("Cannot convert %s to Scalar" % x, type(x)) raise TypeError("Cannot convert %s to Scalar" % x, type(x))
def constant(x): def constant(x):
if isinstance(x, float): if isinstance(x, builtin_float):
for dtype in ['float32', 'float64']: for dtype in ['float32', 'float64']:
x_ = theano._asarray(x, dtype=dtype) x_ = theano._asarray(x, dtype=dtype)
if numpy.all(x == x_): if numpy.all(x == x_):
...@@ -39,7 +43,7 @@ def constant(x): ...@@ -39,7 +43,7 @@ def constant(x):
x_ = None x_ = None
assert x_ is not None assert x_ is not None
return ScalarConstant(Scalar(str(x_.dtype)), x) return ScalarConstant(Scalar(str(x_.dtype)), x)
if isinstance(x, int): if isinstance(x, builtin_int):
for dtype in ['int8', 'int16', 'int32', 'int64']: for dtype in ['int8', 'int16', 'int32', 'int64']:
x_ = theano._asarray(x, dtype=dtype) x_ = theano._asarray(x, dtype=dtype)
if numpy.all(x == x_): if numpy.all(x == x_):
...@@ -47,7 +51,8 @@ def constant(x): ...@@ -47,7 +51,8 @@ def constant(x):
x_ = None x_ = None
assert x_ is not None assert x_ is not None
return ScalarConstant(Scalar(str(x_.dtype)), x) return ScalarConstant(Scalar(str(x_.dtype)), x)
if isinstance(x, complex): if isinstance(x, builtin_complex):
#TODO: We have added the complex type, so this should be tested
raise NotImplementedError() raise NotImplementedError()
raise TypeError(x) raise TypeError(x)
#return ScalarConstant(float64, float(x)) #return ScalarConstant(float64, float(x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论