提交 48aa9e84 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

fixed the confusion between python all / any and theano all / any

上级 cf5e084b
...@@ -34,6 +34,8 @@ _logger=logging.getLogger("theano.tensor.basic") ...@@ -34,6 +34,8 @@ _logger=logging.getLogger("theano.tensor.basic")
#This is needed as we will hide it later #This is needed as we will hide it later
python_complex=complex python_complex=complex
python_any = any
python_all = all
# Define common subsets of dtypes (as strings). # Define common subsets of dtypes (as strings).
int_dtypes = map(str, scal.int_types) int_dtypes = map(str, scal.int_types)
...@@ -54,7 +56,7 @@ def check_equal_numpy(x, y): ...@@ -54,7 +56,7 @@ def check_equal_numpy(x, y):
if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray): if isinstance(x, numpy.ndarray) and isinstance(y, numpy.ndarray):
return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10) return x.dtype == y.dtype and x.shape == y.shape and numpy.any(abs(x - y) < 1e-10)
elif isinstance(x, numpy.random.RandomState) and isinstance(y, numpy.random.RandomState): elif isinstance(x, numpy.random.RandomState) and isinstance(y, numpy.random.RandomState):
return all(numpy.all(a==b) for a, b in zip(x.__getstate__(), y.__getstate__())) return python_all(numpy.all(a==b) for a, b in zip(x.__getstate__(), y.__getstate__()))
else: else:
return x == y return x == y
...@@ -142,7 +144,7 @@ def as_tensor_variable(x, name=None, ndim=None): ...@@ -142,7 +144,7 @@ def as_tensor_variable(x, name=None, ndim=None):
return shape_padleft(x, n_ones=(ndim - x.type.ndim)) return shape_padleft(x, n_ones=(ndim - x.type.ndim))
else: else:
return x return x
if isinstance(x, (tuple, list)) and any(isinstance(xi, Variable) for xi in x): if isinstance(x, (tuple, list)) and python_any(isinstance(xi, Variable) for xi in x):
try: try:
return stack(*x) return stack(*x)
except (TypeError, ValueError): except (TypeError, ValueError):
...@@ -465,7 +467,7 @@ def get_constant_value(v): ...@@ -465,7 +467,7 @@ def get_constant_value(v):
# Ensure the Join is joining only scalar variables (so that # Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the one # the constant value can be found at the same index as the one
# used in the sub-tensor). # used in the sub-tensor).
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and python_all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1): len(v.owner.op.idx_list) == 1):
# Note the '+ 1' is because the first argument to Join is the # Note the '+ 1' is because the first argument to Join is the
...@@ -479,7 +481,7 @@ def get_constant_value(v): ...@@ -479,7 +481,7 @@ def get_constant_value(v):
theano.tensor.opt.MakeVector) and theano.tensor.opt.MakeVector) and
# MakeVector normally accept only scalar as input. # MakeVector normally accept only scalar as input.
# We put this check in case there is change in the future # We put this check in case there is change in the future
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and python_all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1): len(v.owner.op.idx_list) == 1):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
...@@ -812,7 +814,7 @@ class TensorType(Type): ...@@ -812,7 +814,7 @@ class TensorType(Type):
if b in named_broadcastable: if b in named_broadcastable:
bcast = named_broadcastable[b] bcast = named_broadcastable[b]
else: else:
if any(b): if python_any(b):
bcast = str(b) bcast = str(b)
else: else:
bcast = '%iD' % len(b) bcast = '%iD' % len(b)
...@@ -3935,7 +3937,7 @@ class Split(Op): ...@@ -3935,7 +3937,7 @@ class Split(Op):
if numpy.sum(splits) != len_along_axis: if numpy.sum(splits) != len_along_axis:
raise ValueError('The splits sum to %s, expected %s' % (numpy.sum(splits), len_along_axis)) raise ValueError('The splits sum to %s, expected %s' % (numpy.sum(splits), len_along_axis))
if not all(splits): if not python_all(splits):
raise ValueError('Cannot have a split of zero.') raise ValueError('Cannot have a split of zero.')
# Checking is done, let's roll the splitting algorithm! # Checking is done, let's roll the splitting algorithm!
...@@ -4108,7 +4110,7 @@ class Join(Op): ...@@ -4108,7 +4110,7 @@ class Join(Op):
def _make_node_internal(self, axis, tensors, def _make_node_internal(self, axis, tensors,
as_tensor_variable_args, output_maker): as_tensor_variable_args, output_maker):
orig = as_tensor_variable_args orig = as_tensor_variable_args
if not all(targs.type.ndim for targs in as_tensor_variable_args): if not python_all(targs.type.ndim for targs in as_tensor_variable_args):
raise TypeError('Join cannot handle arguments of dimension 0. For joining scalar values, see @stack'); raise TypeError('Join cannot handle arguments of dimension 0. For joining scalar values, see @stack');
# Handle single-tensor joins immediately. # Handle single-tensor joins immediately.
if len(as_tensor_variable_args) == 1: if len(as_tensor_variable_args) == 1:
...@@ -4166,7 +4168,7 @@ class Join(Op): ...@@ -4166,7 +4168,7 @@ class Join(Op):
outputs = [output_maker(bcastable)] outputs = [output_maker(bcastable)]
node = Apply(self, inputs, outputs) node = Apply(self, inputs, outputs)
if any(not x.type.broadcastable[0] for x in orig): if python_any(not x.type.broadcastable[0] for x in orig):
node.tag.shape_zero = None node.tag.shape_zero = None
else: else:
node.tag.shape_zero = len(orig) node.tag.shape_zero = len(orig)
...@@ -4759,7 +4761,7 @@ def arange(start, stop=None, step=1, dtype=None): ...@@ -4759,7 +4761,7 @@ def arange(start, stop=None, step=1, dtype=None):
config.floatX == 'float32' and config.floatX == 'float32' and
numpy_dtype == 'float64' and numpy_dtype == 'float64' and
# No explicit float64 in the three arguments? # No explicit float64 in the three arguments?
all(dt != 'float64' python_all(dt != 'float64'
for dt in [s.dtype for s in (start, stop, step)])): for dt in [s.dtype for s in (start, stop, step)])):
# We use float32 instead. # We use float32 instead.
assert dtype != 'float64' assert dtype != 'float64'
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论