提交 0ef2abb8 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Re-using lists of types already defined in tensor

Similar lists were added recently in elemwise.py: it is better to re-use existing ones.
上级 fda86e28
...@@ -37,9 +37,10 @@ python_any = any ...@@ -37,9 +37,10 @@ python_any = any
python_all = all python_all = all
# Define common subsets of dtypes (as strings). # Define common subsets of dtypes (as strings).
int_dtypes = map(str, scal.int_types)
discrete_dtypes = map(str, scal.discrete_types)
complex_dtypes = map(str, scal.complex_types) complex_dtypes = map(str, scal.complex_types)
continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types)
int_dtypes = map(str, scal.int_types)
class ShapeError(Exception): class ShapeError(Exception):
......
...@@ -13,7 +13,6 @@ from theano.gof.python25 import all, any ...@@ -13,7 +13,6 @@ from theano.gof.python25 import all, any
config = theano.config config = theano.config
# tensor depends on elemwise to provide definitions for several ops # tensor depends on elemwise to provide definitions for several ops
# but elemwise needs to make TensorType instances, so we have these as # but elemwise needs to make TensorType instances, so we have these as
# placeholders and the tensor module fills them # placeholders and the tensor module fills them
...@@ -29,10 +28,6 @@ def TensorVariable(*inputs, **kwargs): ...@@ -29,10 +28,6 @@ def TensorVariable(*inputs, **kwargs):
def TensorConstant(*inputs, **kwargs): def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
# Define common subsets of dtypes (as strings).
discrete_dtypes = map(str, scalar.discrete_types)
continuous_dtypes = map(str, scalar.continuous_types)
################## ##################
### DimShuffle ### ### DimShuffle ###
...@@ -1374,7 +1369,8 @@ class CAReduceDtype(CAReduce): ...@@ -1374,7 +1369,8 @@ class CAReduceDtype(CAReduce):
uint16='uint64', uint16='uint64',
uint32='uint64', uint32='uint64',
).get(idtype, idtype) ).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes: elif (dtype in theano.tensor.continuous_dtypes and
idtype in theano.tensor.discrete_dtypes):
# Specifying a continuous output for discrete input is OK # Specifying a continuous output for discrete input is OK
return dtype return dtype
else: else:
......
...@@ -534,8 +534,8 @@ class T_sum_dtype(unittest.TestCase): ...@@ -534,8 +534,8 @@ class T_sum_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs. # We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype) upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and (input_dtype in tensor.discrete_dtypes and
output_dtype in continuous_dtypes) output_dtype in tensor.continuous_dtypes)
): ):
sum_var = x.sum(dtype=output_dtype, axis=axis) sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype assert sum_var.dtype == output_dtype
...@@ -559,7 +559,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -559,7 +559,7 @@ class T_mean_dtype(unittest.TestCase):
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)): for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).mean(axis=axis) x = tensor.matrix(dtype=dtype).mean(axis=axis)
if dtype in discrete_dtypes: if dtype in tensor.discrete_dtypes:
assert x.dtype == 'float64' assert x.dtype == 'float64'
else: else:
assert x.dtype == dtype, (x, x.dtype, dtype) assert x.dtype == dtype, (x, x.dtype, dtype)
...@@ -582,7 +582,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -582,7 +582,7 @@ class T_mean_dtype(unittest.TestCase):
pass pass
else: else:
# Executed if no TypeError was raised # Executed if no TypeError was raised
if sum_dtype in discrete_dtypes: if sum_dtype in tensor.discrete_dtypes:
assert mean_var.dtype == 'float64', (mean_var.dtype, sum_dtype) assert mean_var.dtype == 'float64', (mean_var.dtype, sum_dtype)
else: else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype) assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype)
...@@ -635,8 +635,8 @@ class T_prod_dtype(unittest.TestCase): ...@@ -635,8 +635,8 @@ class T_prod_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs. # We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype) upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and (input_dtype in tensor.discrete_dtypes and
output_dtype in continuous_dtypes) output_dtype in tensor.continuous_dtypes)
): ):
prod_var = x.prod(dtype=output_dtype, axis=axis) prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype assert prod_var.dtype == output_dtype
...@@ -684,8 +684,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -684,8 +684,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs. # We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype) upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and (input_dtype in tensor.discrete_dtypes and
output_dtype in continuous_dtypes) output_dtype in tensor.continuous_dtypes)
): ):
prod_woz_var = ProdWithoutZeros( prod_woz_var = ProdWithoutZeros(
axis=axis, dtype=output_dtype)(x) axis=axis, dtype=output_dtype)(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论