提交 fe6a777b authored 作者: James Bergstra's avatar James Bergstra

Modified casting to work correctly with complex-valued types. Casting from

complex->real is forbidden by the cast() function.
上级 5d44e7db
......@@ -882,18 +882,6 @@ class ScalarFromTensor(Op):
scalar_from_tensor = ScalarFromTensor()
@constructor
def cast(t, dtype):
mapping = {'int8': convert_to_int8,
'int16': convert_to_int16,
'int32': convert_to_int32,
'int64': convert_to_int64,
'float32': convert_to_float32,
'float64': convert_to_float64,
'complex64': convert_to_complex64,
'complex128': convert_to_complex128}
return mapping[dtype](t)
#to be removed as we get the epydoc routine-documenting thing going -JB 20080924
def _conversion(real_value, name):
__oplist_tag(real_value, 'casting')
......@@ -901,30 +889,52 @@ def _conversion(real_value, name):
pprint.assign(real_value, printing.FunctionPrinter(name))
return real_value
convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), 'int8')
#
# These _conver_to_<type> functions have leading underscores to indicate that they should not
# be called directly. They do not perform sanity checks about what types you are casting to
# what. That logic is implemented by the `cast()` function below.
#
_convert_to_int8 = _conversion(elemwise.Elemwise(scal.convert_to_int8), 'int8')
"""Cast to 8-bit integer"""
convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), 'int16')
_convert_to_int16 = _conversion(elemwise.Elemwise(scal.convert_to_int16), 'int16')
"""Cast to 16-bit integer"""
convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), 'int32')
_convert_to_int32 = _conversion(elemwise.Elemwise(scal.convert_to_int32), 'int32')
"""Cast to 32-bit integer"""
convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), 'int64')
_convert_to_int64 = _conversion(elemwise.Elemwise(scal.convert_to_int64), 'int64')
"""Cast to 64-bit integer"""
convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), 'float32')
_convert_to_float32 = _conversion(elemwise.Elemwise(scal.convert_to_float32), 'float32')
"""Cast to single-precision floating point"""
convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), 'float64')
_convert_to_float64 = _conversion(elemwise.Elemwise(scal.convert_to_float64), 'float64')
"""Cast to double-precision floating point"""
convert_to_complex64 = _conversion(elemwise.Elemwise(scal.convert_to_complex64), 'complex64')
_convert_to_complex64 = _conversion(elemwise.Elemwise(scal.convert_to_complex64), 'complex64')
"""Cast to single-precision complex"""
convert_to_complex128 = _conversion(elemwise.Elemwise(scal.convert_to_complex128), 'complex128')
_convert_to_complex128 = _conversion(elemwise.Elemwise(scal.convert_to_complex128), 'complex128')
"""Cast to double-precision complex"""
_cast_mapping = {'int8': _convert_to_int8,
'int16': _convert_to_int16,
'int32': _convert_to_int32,
'int64': _convert_to_int64,
'float32': _convert_to_float32,
'float64': _convert_to_float64,
'complex64': _convert_to_complex64,
'complex128': _convert_to_complex128}
@constructor
def cast(x, dtype):
"""Symbolically cast `x` to a Tensor of type `dtype`."""
if x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
raise TypeError('Casting from complex to real is ambiguous: consider real(), imag(), angle() or abs()')
return _cast_mapping[dtype](x)
##########################
......@@ -1140,7 +1150,6 @@ def abs_(a):
pprint.assign(abs_, printing.PatternPrinter(('|%(0)s|', -1000)))
@_scal_elemwise
def exp(a):
"""e^`a`"""
......
import unittest
from theano import function
from theano.tensor import *
class test_casting(unittest.TestCase):
def test_0(self):
for op_fn in convert_to_int32, convert_to_float32, convert_to_float64:
for type_fn in bvector, ivector, fvector, dvector:
x = type_fn()
f = function([x], op_fn(x))
xval = numpy.asarray(numpy.random.rand(10)*10, dtype=type_fn.dtype)
yval = f(xval)
assert str(yval.dtype) == op_fn.scalar_op.output_types_preference.spec[0].dtype
def test_illegal(self):
try:
x = zmatrix()
function([x], convert_to_float64(x))(numpy.ones((2,3), dtype='complex128'))
except TypeError:
return
assert 0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论