提交 42b8b66e authored 作者: Ian Goodfellow's avatar Ian Goodfellow 提交者: Frederic

use izip where appropriate

上级 8054c46b
...@@ -65,7 +65,7 @@ def check_equal_numpy(x, y): ...@@ -65,7 +65,7 @@ def check_equal_numpy(x, y):
elif (isinstance(x, numpy.random.RandomState) and elif (isinstance(x, numpy.random.RandomState) and
isinstance(y, numpy.random.RandomState)): isinstance(y, numpy.random.RandomState)):
return python_all(numpy.all(a == b) for a, b in return python_all(numpy.all(a == b) for a, b in
zip(x.__getstate__(), y.__getstate__())) izip(x.__getstate__(), y.__getstate__()))
else: else:
return x == y return x == y
...@@ -3823,7 +3823,7 @@ class Subtensor(Op): ...@@ -3823,7 +3823,7 @@ class Subtensor(Op):
# infer the broadcasting pattern # infer the broadcasting pattern
padded = (idx_list padded = (idx_list
+ [slice(None, None, None)] * (x.type.ndim - len(idx_list))) + [slice(None, None, None)] * (x.type.ndim - len(idx_list)))
broadcastable = [bc for p, bc in zip(padded, x.type.broadcastable) broadcastable = [bc for p, bc in izip(padded, x.type.broadcastable)
if isinstance(p, slice)] if isinstance(p, slice)]
input_types = Subtensor.collapse(idx_list, input_types = Subtensor.collapse(idx_list,
...@@ -3832,7 +3832,7 @@ class Subtensor(Op): ...@@ -3832,7 +3832,7 @@ class Subtensor(Op):
raise IndexError( raise IndexError(
"Not enough inputs to fill in the Subtensor template.", "Not enough inputs to fill in the Subtensor template.",
inputs, idx_list) inputs, idx_list)
for input, expected_type in zip(inputs, input_types): for input, expected_type in izip(inputs, input_types):
if input.type != expected_type: if input.type != expected_type:
raise TypeError( raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s." "Wrong type for Subtensor template. Expected %s, got %s."
...@@ -4458,7 +4458,7 @@ class IncSubtensor(Op): ...@@ -4458,7 +4458,7 @@ class IncSubtensor(Op):
raise IndexError( raise IndexError(
"Not enough inputs to fill in the Subtensor template.", "Not enough inputs to fill in the Subtensor template.",
inputs, idx_list) inputs, idx_list)
for input, expected_type in zip(inputs, input_types): for input, expected_type in izip(inputs, input_types):
if input.type != expected_type: if input.type != expected_type:
raise TypeError( raise TypeError(
"Wrong type for Subtensor template. Expected %s, got %s." "Wrong type for Subtensor template. Expected %s, got %s."
...@@ -5830,7 +5830,7 @@ class PermuteRowElements(Op): ...@@ -5830,7 +5830,7 @@ class PermuteRowElements(Op):
# Compute the broadcastable pattern of the output # Compute the broadcastable pattern of the output
out_broadcastable = [xb and yb for xb, yb in out_broadcastable = [xb and yb for xb, yb in
zip(x.type.broadcastable, y.type.broadcastable)] izip(x.type.broadcastable, y.type.broadcastable)]
out_type = tensor(dtype=x.type.dtype, broadcastable=out_broadcastable) out_type = tensor(dtype=x.type.dtype, broadcastable=out_broadcastable)
inputlist = [x, y, inverse] inputlist = [x, y, inverse]
...@@ -5897,7 +5897,7 @@ class PermuteRowElements(Op): ...@@ -5897,7 +5897,7 @@ class PermuteRowElements(Op):
# Make sure the output is big enough # Make sure the output is big enough
out_s = [] out_s = []
for xdim, ydim in zip(x_s, y_s): for xdim, ydim in izip(x_s, y_s):
if xdim == ydim: if xdim == ydim:
outdim = xdim outdim = xdim
elif xdim == 1: elif xdim == 1:
......
...@@ -539,14 +539,14 @@ class Elemwise(Op): ...@@ -539,14 +539,14 @@ class Elemwise(Op):
# it is multiplied by nout because Elemwise supports multiple outputs # it is multiplied by nout because Elemwise supports multiple outputs
# (nout of them) # (nout of them)
out_broadcastables = [[all(bcast) out_broadcastables = [[all(bcast)
for bcast in zip(*[input.type.broadcastable for bcast in izip(*[input.type.broadcastable
for input in inputs])]] * shadow.nout for input in inputs])]] * shadow.nout
#inplace_pattern maps output idx -> input idx #inplace_pattern maps output idx -> input idx
inplace_pattern = self.inplace_pattern inplace_pattern = self.inplace_pattern
if inplace_pattern: if inplace_pattern:
for overwriter, overwritten in inplace_pattern.items(): for overwriter, overwritten in inplace_pattern.items():
for ob, ib in zip(out_broadcastables[overwriter], for ob, ib in izip(out_broadcastables[overwriter],
inputs[overwritten].type.broadcastable): inputs[overwritten].type.broadcastable):
if ib and not ob: if ib and not ob:
raise ValueError(( raise ValueError((
...@@ -561,7 +561,7 @@ class Elemwise(Op): ...@@ -561,7 +561,7 @@ class Elemwise(Op):
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern))) ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern)))
outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)() outputs = [TensorType(dtype=dtype, broadcastable=broadcastable)()
for dtype, broadcastable in zip(out_dtypes, out_broadcastables) for dtype, broadcastable in izip(out_dtypes, out_broadcastables)
] ]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
...@@ -609,7 +609,7 @@ class Elemwise(Op): ...@@ -609,7 +609,7 @@ class Elemwise(Op):
bgrads = self._bgrad(inputs, ograds) bgrads = self._bgrad(inputs, ograds)
rop_out = None rop_out = None
for jdx, (inp, eval_point) in enumerate(zip(inputs, for jdx, (inp, eval_point) in enumerate(izip(inputs,
eval_points)): eval_points)):
# if None, then we can just ignore this branch .. # if None, then we can just ignore this branch ..
# what we do is to assume that for any non-differentiable # what we do is to assume that for any non-differentiable
...@@ -663,7 +663,7 @@ class Elemwise(Op): ...@@ -663,7 +663,7 @@ class Elemwise(Op):
# can tell this op did # can tell this op did
# the right thing. # the right thing.
new_rval = [] new_rval = []
for elem, ipt in zip(rval, inputs): for elem, ipt in izip(rval, inputs):
if isinstance(elem.type, (NullType, DisconnectedType)): if isinstance(elem.type, (NullType, DisconnectedType)):
new_rval.append(elem) new_rval.append(elem)
else: else:
...@@ -758,7 +758,7 @@ class Elemwise(Op): ...@@ -758,7 +758,7 @@ class Elemwise(Op):
*[transform(ipt) for ipt in node.inputs]) *[transform(ipt) for ipt in node.inputs])
return new_r return new_r
ret = [] ret = []
for scalar_igrad, ipt in zip(scalar_igrads, inputs): for scalar_igrad, ipt in izip(scalar_igrads, inputs):
if scalar_igrad is None: if scalar_igrad is None:
# undefined gradient # undefined gradient
ret.append(None) ret.append(None)
...@@ -769,7 +769,7 @@ class Elemwise(Op): ...@@ -769,7 +769,7 @@ class Elemwise(Op):
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
maxsize = max(len(input.shape) for input in inputs) maxsize = max(len(input.shape) for input in inputs)
for dims in zip(*[([(1, True)] * (maxsize - len(input.shape)) for dims in izip(*[([(1, True)] * (maxsize - len(input.shape))
+ zip(input.shape, sinput.type.broadcastable)) + zip(input.shape, sinput.type.broadcastable))
for input, sinput in zip(inputs, node.inputs)]): for input, sinput in zip(inputs, node.inputs)]):
if max(d for d, b in dims) != 1 and (1, False) in dims: if max(d for d, b in dims) != 1 and (1, False) in dims:
...@@ -801,7 +801,7 @@ class Elemwise(Op): ...@@ -801,7 +801,7 @@ class Elemwise(Op):
# Determine the shape of outputs # Determine the shape of outputs
out_shape = [] out_shape = []
for values in zip(*[input.shape for input in inputs]): for values in izip(*[input.shape for input in inputs]):
if numpy.prod(values) == 0: if numpy.prod(values) == 0:
# All non-broadcasted dimensions should be zero # All non-broadcasted dimensions should be zero
assert max(values) <= 1 assert max(values) <= 1
...@@ -811,7 +811,7 @@ class Elemwise(Op): ...@@ -811,7 +811,7 @@ class Elemwise(Op):
out_shape = tuple(out_shape) out_shape = tuple(out_shape)
if not self.inplace_pattern: if not self.inplace_pattern:
for output, storage in zip(node.outputs, output_storage): for output, storage in izip(node.outputs, output_storage):
odat = storage[0] odat = storage[0]
if odat is not None: if odat is not None:
if odat.shape != out_shape: if odat.shape != out_shape:
...@@ -823,7 +823,7 @@ class Elemwise(Op): ...@@ -823,7 +823,7 @@ class Elemwise(Op):
storage[0] = odat storage[0] = odat
else: else:
for i, (output, storage) in enumerate( for i, (output, storage) in enumerate(
zip(node.outputs, output_storage)): izip(node.outputs, output_storage)):
#i is an output idx #i is an output idx
if i in self.inplace_pattern: if i in self.inplace_pattern:
odat = inputs[self.inplace_pattern[i]] odat = inputs[self.inplace_pattern[i]]
...@@ -917,7 +917,7 @@ class Elemwise(Op): ...@@ -917,7 +917,7 @@ class Elemwise(Op):
else: else:
# there must be some input that is not broadcastable in # there must be some input that is not broadcastable in
# dimension 'dim' # dimension 'dim'
for ishp, i in zip(i_shapes, node.inputs): for ishp, i in izip(i_shapes, node.inputs):
if isinstance(i.type, theano.scalar.Scalar): if isinstance(i.type, theano.scalar.Scalar):
continue # we skip scalar continue # we skip scalar
if not i.type.broadcastable[dim]: if not i.type.broadcastable[dim]:
...@@ -960,7 +960,7 @@ class Elemwise(Op): ...@@ -960,7 +960,7 @@ class Elemwise(Op):
# These are the outputs that we will need to allocate # These are the outputs that we will need to allocate
# (output, name, name of the c type), transposed # (output, name, name of the c type), transposed
real = zip(*[(r, s, r.type.dtype_specs()[1]) real = zip(*[(r, s, r.type.dtype_specs()[1])
for r, s in zip(node.outputs, onames) if r not in dmap]) for r, s in izip(node.outputs, onames) if r not in dmap])
if real: if real:
real_outputs, real_onames, real_odtypes = real real_outputs, real_onames, real_odtypes = real
else: else:
...@@ -970,7 +970,7 @@ class Elemwise(Op): ...@@ -970,7 +970,7 @@ class Elemwise(Op):
# (output, name), transposed (c type name not needed since we don't # (output, name), transposed (c type name not needed since we don't
# need to allocate. # need to allocate.
aliased = zip(*[(r, s) aliased = zip(*[(r, s)
for (r, s) in zip(node.outputs, onames) if r in dmap]) for (r, s) in izip(node.outputs, onames) if r in dmap])
if aliased: if aliased:
aliased_outputs, aliased_onames = aliased aliased_outputs, aliased_onames = aliased
else: else:
...@@ -986,7 +986,7 @@ class Elemwise(Op): ...@@ -986,7 +986,7 @@ class Elemwise(Op):
# dimensionality) # dimensionality)
nnested = len(orders[0]) nnested = len(orders[0])
sub = dict(sub) sub = dict(sub)
for i, (input, iname) in enumerate(zip(inputs, inames)): for i, (input, iname) in enumerate(izip(inputs, inames)):
# the c generators will substitute the input names for # the c generators will substitute the input names for
# references to loop variables lv0, lv1, ... # references to loop variables lv0, lv1, ...
sub['lv%i' % i] = iname sub['lv%i' % i] = iname
...@@ -998,7 +998,7 @@ class Elemwise(Op): ...@@ -998,7 +998,7 @@ class Elemwise(Op):
# We loop over the "real" outputs, i.e., those that are not # We loop over the "real" outputs, i.e., those that are not
# inplace (must be allocated) and we declare/allocate/check # inplace (must be allocated) and we declare/allocate/check
# them # them
for output, oname, odtype in zip( for output, oname, odtype in izip(
real_outputs, real_onames, real_odtypes): real_outputs, real_onames, real_odtypes):
i += 1 # before this loop, i = number of inputs i += 1 # before this loop, i = number of inputs
sub['lv%i' % i] = oname sub['lv%i' % i] = oname
...@@ -1014,7 +1014,7 @@ class Elemwise(Op): ...@@ -1014,7 +1014,7 @@ class Elemwise(Op):
# inplace (overwrite the contents of one of the inputs) and # inplace (overwrite the contents of one of the inputs) and
# make the output pointers point to their corresponding input # make the output pointers point to their corresponding input
# pointers. # pointers.
for output, oname in zip(aliased_outputs, aliased_onames): for output, oname in izip(aliased_outputs, aliased_onames):
olv_index = inputs.index(dmap[output][0]) olv_index = inputs.index(dmap[output][0])
iname = inames[olv_index] iname = inames[olv_index]
# We make the output point to the corresponding input and # We make the output point to the corresponding input and
...@@ -1040,7 +1040,7 @@ class Elemwise(Op): ...@@ -1040,7 +1040,7 @@ class Elemwise(Op):
# not be declared, as they are #defined in defines # not be declared, as they are #defined in defines
task_decl = "".join([ task_decl = "".join([
"%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals() "%(dtype)s& %(name)s_i = *%(name)s_iter;\n" % locals()
for name, dtype in zip(inames + list(real_onames), for name, dtype in izip(inames + list(real_onames),
idtypes + list(real_odtypes))]) idtypes + list(real_odtypes))])
# We generate the C code of the inner loop using the scalar op # We generate the C code of the inner loop using the scalar op
...@@ -1339,7 +1339,7 @@ class CAReduce(Op): ...@@ -1339,7 +1339,7 @@ class CAReduce(Op):
nnested = len(order1) nnested = len(order1)
sub = dict(sub) sub = dict(sub)
for i, (input, iname) in enumerate(zip(node.inputs, inames)): for i, (input, iname) in enumerate(izip(node.inputs, inames)):
sub['lv%i' % i] = iname sub['lv%i' % i] = iname
decl = cgen.make_declare([order], [idtype], sub) decl = cgen.make_declare([order], [idtype], sub)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论