提交 0ef2abb8 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Re-using lists of types already defined in tensor

Similar lists were added recently in elemwise.py: it is better to re-use existing ones.
上级 fda86e28
......@@ -37,9 +37,10 @@ python_any = any
python_all = all
# Define common subsets of dtypes (as strings).
int_dtypes = map(str, scal.int_types)
discrete_dtypes = map(str, scal.discrete_types)
complex_dtypes = map(str, scal.complex_types)
continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types)
int_dtypes = map(str, scal.int_types)
class ShapeError(Exception):
......
......@@ -13,7 +13,6 @@ from theano.gof.python25 import all, any
config = theano.config
# tensor depends on elemwise to provide definitions for several ops
# but elemwise needs to make TensorType instances, so we have these as
# placeholders and the tensor module fills them
......@@ -29,10 +28,6 @@ def TensorVariable(*inputs, **kwargs):
def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
# Define common subsets of dtypes (as strings).
discrete_dtypes = map(str, scalar.discrete_types)
continuous_dtypes = map(str, scalar.continuous_types)
##################
### DimShuffle ###
......@@ -1374,7 +1369,8 @@ class CAReduceDtype(CAReduce):
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes:
elif (dtype in theano.tensor.continuous_dtypes and
idtype in theano.tensor.discrete_dtypes):
# Specifying a continuous output for discrete input is OK
return dtype
else:
......
......@@ -534,8 +534,8 @@ class T_sum_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
(input_dtype in tensor.discrete_dtypes and
output_dtype in tensor.continuous_dtypes)
):
sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype
......@@ -559,7 +559,7 @@ class T_mean_dtype(unittest.TestCase):
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).mean(axis=axis)
if dtype in discrete_dtypes:
if dtype in tensor.discrete_dtypes:
assert x.dtype == 'float64'
else:
assert x.dtype == dtype, (x, x.dtype, dtype)
......@@ -582,7 +582,7 @@ class T_mean_dtype(unittest.TestCase):
pass
else:
# Executed if no TypeError was raised
if sum_dtype in discrete_dtypes:
if sum_dtype in tensor.discrete_dtypes:
assert mean_var.dtype == 'float64', (mean_var.dtype, sum_dtype)
else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype)
......@@ -635,8 +635,8 @@ class T_prod_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
(input_dtype in tensor.discrete_dtypes and
output_dtype in tensor.continuous_dtypes)
):
prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype
......@@ -684,8 +684,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
(input_dtype in tensor.discrete_dtypes and
output_dtype in tensor.continuous_dtypes)
):
prod_woz_var = ProdWithoutZeros(
axis=axis, dtype=output_dtype)(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论