Commit 669b911a, authored by David Warde-Farley

Merge pull request #418 from delallea/minor

Minor stuff
...@@ -72,7 +72,7 @@ Scan fix: ...@@ -72,7 +72,7 @@ Scan fix:
* computing grad of a function of grad of scan(reported by ?, Razvan) * computing grad of a function of grad of scan(reported by ?, Razvan)
before : most of the time crash, but could be wrong value with bad number of dimensions(so a visible bug) before : most of the time crash, but could be wrong value with bad number of dimensions(so a visible bug)
now : do the right thing. now : do the right thing.
* gradient with respect to outputs using multiple taps(Timothy reported, fix by Razvan) * gradient with respect to outputs using multiple taps(reported by Timothy, fix by Razvan)
before : it used to return wrong values before : it used to return wrong values
now : do the right thing. now : do the right thing.
Note: The reported case of this bug was happening in conjunction with the Note: The reported case of this bug was happening in conjunction with the
......
...@@ -37,9 +37,10 @@ python_any = any ...@@ -37,9 +37,10 @@ python_any = any
python_all = all python_all = all
# Define common subsets of dtypes (as strings). # Define common subsets of dtypes (as strings).
int_dtypes = map(str, scal.int_types)
discrete_dtypes = map(str, scal.discrete_types)
complex_dtypes = map(str, scal.complex_types) complex_dtypes = map(str, scal.complex_types)
continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types)
int_dtypes = map(str, scal.int_types)
class ShapeError(Exception): class ShapeError(Exception):
...@@ -2686,8 +2687,10 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -2686,8 +2687,10 @@ def mean(input, axis=None, dtype=None, op=False):
:param dtype: dtype to use for the inner summation. This will not :param dtype: dtype to use for the inner summation. This will not
necessarily be the dtype of the output (in particular necessarily be the dtype of the output (in particular
if it is a discrete (int/uint) dtype, the output will if it is a discrete (int/uint) dtype, the output will
be in a float type) be in a float type).
:type dtype: string If None, then we use float64 for a discrete input, and the
same rules as `sum()` for a continuous input.
:type dtype: None or string
:note: for gpu, if you specify dtype=float32, everything will be done :note: for gpu, if you specify dtype=float32, everything will be done
on the gpu. on the gpu.
...@@ -2712,6 +2715,8 @@ def mean(input, axis=None, dtype=None, op=False): ...@@ -2712,6 +2715,8 @@ def mean(input, axis=None, dtype=None, op=False):
shp = shape(input) shp = shape(input)
# Cast shp into a float type # Cast shp into a float type
# TODO Once we have a consistent casting policy, we could simply
# use true_div.
if s.dtype in ('float32', 'complex64'): if s.dtype in ('float32', 'complex64'):
shp = cast(shp, 'float32') shp = cast(shp, 'float32')
else: else:
......
...@@ -13,7 +13,6 @@ from theano.gof.python25 import all, any ...@@ -13,7 +13,6 @@ from theano.gof.python25 import all, any
config = theano.config config = theano.config
# tensor depends on elemwise to provide definitions for several ops # tensor depends on elemwise to provide definitions for several ops
# but elemwise needs to make TensorType instances, so we have these as # but elemwise needs to make TensorType instances, so we have these as
# placeholders and the tensor module fills them # placeholders and the tensor module fills them
...@@ -29,10 +28,6 @@ def TensorVariable(*inputs, **kwargs): ...@@ -29,10 +28,6 @@ def TensorVariable(*inputs, **kwargs):
def TensorConstant(*inputs, **kwargs): def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
# Define common subsets of dtypes (as strings).
discrete_dtypes = map(str, scalar.discrete_types)
continuous_dtypes = map(str, scalar.continuous_types)
################## ##################
### DimShuffle ### ### DimShuffle ###
...@@ -1337,9 +1332,9 @@ class CAReduceDtype(CAReduce): ...@@ -1337,9 +1332,9 @@ class CAReduceDtype(CAReduce):
:param scalar_op: a binary scalar op with only one output. :param scalar_op: a binary scalar op with only one output.
It must be commutative and associative. It must be commutative and associative.
:axis: - the dimension along which we want to reduce :param axis: - the dimension along which we want to reduce
- list of dimensions that we want to reduce - list of dimensions that we want to reduce
- if None, all dimensions are reduced - if None, all dimensions are reduced
:param dtype: The dtype of the internal accumulator and returned :param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the tensor. If None, then we use the default dtype which is the same as the
...@@ -1365,7 +1360,7 @@ class CAReduceDtype(CAReduce): ...@@ -1365,7 +1360,7 @@ class CAReduceDtype(CAReduce):
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
dtype = self.dtype dtype = self.dtype
if dtype is None: if dtype is None:
# If input has an discrete dtype, upcast it to 64 # If input has a discrete dtype, upcast it to 64
return dict( return dict(
int8='int64', int8='int64',
int16='int64', int16='int64',
...@@ -1374,7 +1369,8 @@ class CAReduceDtype(CAReduce): ...@@ -1374,7 +1369,8 @@ class CAReduceDtype(CAReduce):
uint16='uint64', uint16='uint64',
uint32='uint64', uint32='uint64',
).get(idtype, idtype) ).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes: elif (dtype in theano.tensor.continuous_dtypes and
idtype in theano.tensor.discrete_dtypes):
# Specifying a continuous output for discrete input is OK # Specifying a continuous output for discrete input is OK
return dtype return dtype
else: else:
......
...@@ -13,7 +13,6 @@ from theano.compile.mode import get_default_mode ...@@ -13,7 +13,6 @@ from theano.compile.mode import get_default_mode
from theano.tensor.elemwise import * from theano.tensor.elemwise import *
from theano.tests import unittest_tools from theano.tests import unittest_tools
complex_dtypes = map(str, scalar.complex_types)
def Env(i, o): def Env(i, o):
e = gof.Env(i, o) e = gof.Env(i, o)
...@@ -534,8 +533,8 @@ class T_sum_dtype(unittest.TestCase): ...@@ -534,8 +533,8 @@ class T_sum_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs. # We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype) upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and (input_dtype in tensor.discrete_dtypes and
output_dtype in continuous_dtypes) output_dtype in tensor.continuous_dtypes)
): ):
sum_var = x.sum(dtype=output_dtype, axis=axis) sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype assert sum_var.dtype == output_dtype
...@@ -559,7 +558,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -559,7 +558,7 @@ class T_mean_dtype(unittest.TestCase):
for idx, dtype in enumerate(imap(str, theano.scalar.all_types)): for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
x = tensor.matrix(dtype=dtype).mean(axis=axis) x = tensor.matrix(dtype=dtype).mean(axis=axis)
if dtype in discrete_dtypes: if dtype in tensor.discrete_dtypes:
assert x.dtype == 'float64' assert x.dtype == 'float64'
else: else:
assert x.dtype == dtype, (x, x.dtype, dtype) assert x.dtype == dtype, (x, x.dtype, dtype)
...@@ -582,7 +581,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -582,7 +581,7 @@ class T_mean_dtype(unittest.TestCase):
pass pass
else: else:
# Executed if no TypeError was raised # Executed if no TypeError was raised
if sum_dtype in discrete_dtypes: if sum_dtype in tensor.discrete_dtypes:
assert mean_var.dtype == 'float64', (mean_var.dtype, sum_dtype) assert mean_var.dtype == 'float64', (mean_var.dtype, sum_dtype)
else: else:
assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype) assert mean_var.dtype == sum_dtype, (mean_var.dtype, output_dtype)
...@@ -594,7 +593,7 @@ class T_mean_dtype(unittest.TestCase): ...@@ -594,7 +593,7 @@ class T_mean_dtype(unittest.TestCase):
except NotImplementedError: except NotImplementedError:
# TrueDiv does not seem to have a gradient when # TrueDiv does not seem to have a gradient when
# the numerator is complex. # the numerator is complex.
if mean_var.dtype in complex_dtypes: if mean_var.dtype in tensor.complex_dtypes:
pass pass
else: else:
raise raise
...@@ -635,8 +634,8 @@ class T_prod_dtype(unittest.TestCase): ...@@ -635,8 +634,8 @@ class T_prod_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs. # We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype) upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and (input_dtype in tensor.discrete_dtypes and
output_dtype in continuous_dtypes) output_dtype in tensor.continuous_dtypes)
): ):
prod_var = x.prod(dtype=output_dtype, axis=axis) prod_var = x.prod(dtype=output_dtype, axis=axis)
assert prod_var.dtype == output_dtype assert prod_var.dtype == output_dtype
...@@ -684,8 +683,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase): ...@@ -684,8 +683,8 @@ class T_prod_without_zeros_dtype(unittest.TestCase):
# We always allow int/uint inputs with float/complex outputs. # We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype) upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and (input_dtype in tensor.discrete_dtypes and
output_dtype in continuous_dtypes) output_dtype in tensor.continuous_dtypes)
): ):
prod_woz_var = ProdWithoutZeros( prod_woz_var = ProdWithoutZeros(
axis=axis, dtype=output_dtype)(x) axis=axis, dtype=output_dtype)(x)
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment