Commit 28ea8a17 authored by Olivier Delalleau

PEP8 fixes

Parent: f7bb0730

@@ -3985,7 +3985,6 @@ class IncSubtensor(Op):
         else:
             return ()
 
-
     def infer_shape(self, node, shapes):
         return [shapes[0]]
@@ -4004,13 +4003,13 @@ class IncSubtensor(Op):
         if self.set_instead_of_inc:
             gx = set_subtensor(
-                Subtensor(idx_list=self.idx_list)(g_output,*idx_list),
+                Subtensor(idx_list=self.idx_list)(g_output, *idx_list),
                 zeros_like(y))
         else:
             gx = g_output
-        gy = Subtensor(idx_list = self.idx_list)(g_output, *idx_list)
-        return [gx, gy] + [None]*len(idx_list)
+        gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
+        return [gx, gy] + [None] * len(idx_list)
 
 
 def split(x, splits_size, n_splits, axis=0):
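
For reference, the grad logic touched in this hunk: with set_instead_of_inc, the gradient w.r.t. x is the output gradient with the written slice zeroed out, and the gradient w.r.t. y is that slice of the output gradient. A minimal sketch using the public set_subtensor API (the numeric values are illustrative only):

    import numpy
    import theano
    import theano.tensor as T

    x = T.vector('x')
    y = T.vector('y')
    z = T.set_subtensor(x[1:3], y)          # IncSubtensor with set_instead_of_inc=True
    gx, gy = theano.grad(z.sum(), [x, y])   # gx is zeroed on x[1:3]; gy is all ones
    f = theano.function([x, y], [gx, gy])
    print(f(numpy.arange(5.0), numpy.array([10.0, 20.0])))
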
@@ -4149,7 +4148,6 @@ class Rebroadcast(Op):
         items.sort()  # no ambiguity because each item key is unique
         return hash(type(self)) ^ hash(tuple(items))
 
-
     def __str__(self):
         if len(self.axis) == 0:
             broadcast_pattern = []
@@ -4215,6 +4213,7 @@ def addbroadcast(x, *axes):
     rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
     return theano.tensor.opt.apply_rebroadcast_opt(rval)
 
+
 def unbroadcast(x, *axes):
     """
     Make the input impossible to broadcast in the specified axes.
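
For reference, addbroadcast and unbroadcast only rewrite the static broadcastable pattern of a variable (through Rebroadcast); addbroadcast additionally checks at runtime that the axis really has length 1. A small sketch:

    import theano.tensor as T

    x = T.matrix('x')           # broadcastable == (False, False)
    r = T.addbroadcast(x, 0)    # axis 0 now marked broadcastable
    print(r.broadcastable)      # (True, False)
    u = T.unbroadcast(r, 0)     # back to the original pattern
    print(u.broadcastable)      # (False, False)
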
@@ -4230,9 +4229,11 @@ def patternbroadcast(x, broadcastable):
     """
     Make the input adopt a specific broadcasting pattern.
 
-    We apply the opt here not to pollute the graph especially during the gpu optimization
+    We apply the opt here not to pollute the graph especially during the gpu
+    optimization.
     """
-    rval = Rebroadcast(*[(i,broadcastable[i]) for i in xrange(len(broadcastable))])(x)
+    rval = Rebroadcast(*[(i, broadcastable[i])
+                         for i in xrange(len(broadcastable))])(x)
     return theano.tensor.opt.apply_rebroadcast_opt(rval)
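
Unlike addbroadcast, patternbroadcast sets the whole pattern in a single Rebroadcast call. An illustrative use:

    import theano.tensor as T

    x = T.matrix('x')                         # (False, False)
    p = T.patternbroadcast(x, (True, False))  # adopt the full pattern at once
    print(p.broadcastable)                    # (True, False)
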
@@ -4331,8 +4332,8 @@ class Join(Op):
             # be broadcastable for the output.
             for x in as_tensor_variable_args:
                 for current_axis, bflag in enumerate(x.type.broadcastable):
-                        # Not sure if this Op supports/supported/will support
-                        # negative indices, but just to be sure...
+                    # Not sure if this Op supports/supported/will support
+                    # negative indices, but just to be sure...
                     if current_axis == axis % ndim:
                         continue
                     if bflag:
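
For reference, this loop computes the output's broadcastable pattern: the join axis is never broadcastable in the result, and any other axis stays broadcastable only if it is broadcastable in every input. A sketch of my reading of that bookkeeping (not part of the diff):

    import theano.tensor as T

    a = T.row('a')        # broadcastable (True, False)
    b = T.matrix('b')     # broadcastable (False, False)
    j = T.join(0, a, b)   # join along axis 0
    print(j.broadcastable)   # (False, False)
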
@@ -4554,14 +4555,15 @@ def stack(*tensors):
     # This should be an optimization!
     # Doing it here make the graph less canonicalized
     # (more type need to be understood by all optimization)
-    # And DebugMode can't detect error in this code as it is not in an optimization.
+    # And DebugMode can't detect error in this code as it is not in an
+    # optimization.
     # See ticket #660
     if numpy.all([
             # in case there is direct int in tensors.
             isinstance(t, (numpy.number, float, int, python_complex)) or
             (isinstance(t, Variable) and
              isinstance(t.type, TensorType) and
-             t.ndim==0)
+             t.ndim == 0)
             for t in tensors]):
         # in case there is direct int
         tensors = map(as_tensor_variable, tensors)
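
For reference, this branch special-cases stack() when every input is 0-d (or a plain Python number), building the result directly as a vector instead of padding each input and joining. Illustrative use:

    import theano.tensor as T

    a = T.scalar('a')
    b = T.scalar('b')
    v = T.stack(a, b, 3)   # all inputs are 0-d, so this branch applies
    print(v.ndim)          # 1
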
@@ -4650,7 +4652,9 @@ def vertical_stack(*args):
     return concatenate(args, axis=0)
 
 
-if 0: #vertical and horizontal stacking are deprecated. Better to use stack() and join().
+# Vertical and horizontal stacking are deprecated. Better to use stack() and
+# join().
+if 0:
     class VerticalStack(Op):
         """
         Vertically stack two L{TensorType}s.
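
As the new comment says, the VerticalStack/HorizontalStack Ops are dead code; the live helpers simply delegate to concatenate():

    import theano.tensor as T

    a = T.matrix('a')
    b = T.matrix('b')
    v = T.concatenate([a, b], axis=0)   # what vertical_stack(a, b) returns
    h = T.concatenate([a, b], axis=1)   # the horizontal counterpart
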
@@ -4788,8 +4792,8 @@ class Reshape(Op):
         # because it tries to replace the Shape_i node by the switch
         # statement, which depends on Shape_i.
         #return [tuple([switch(eq(node.inputs[1][i], -1),
-        #              theano.tensor.opt.Shape_i(i)(node.outputs[0]),
-        #              node.inputs[1][i])
+        #               theano.tensor.opt.Shape_i(i)(node.outputs[0]),
+        #               node.inputs[1][i])
         #        for i in xrange(self.ndim)]
         #        )]
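
For context, the commented-out infer_shape would resolve a -1 entry in the target shape symbolically; at runtime Reshape already infers that dimension from the total size. A quick illustration:

    import numpy
    import theano
    import theano.tensor as T

    x = T.matrix('x')
    r = x.reshape((-1, 2))             # -1 is inferred from x's size
    f = theano.function([x], r.shape)
    print(f(numpy.zeros((3, 4))))      # [6 2]
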
@@ -5168,17 +5172,17 @@ class PermuteRowElements(Op):
             if xs0 == ys0:
                 for i in xrange(xs0):
                     self._rec_perform(node, x[i], y[i], inverse, out[i],
-                                      curdim+1)
+                                      curdim + 1)
             elif ys0 == 1 and node.inputs[1].type.broadcastable[curdim]:
                 # Broadcast y
                 for i in xrange(xs0):
                     self._rec_perform(node, x[i], y[0], inverse, out[i],
-                                      curdim+1)
+                                      curdim + 1)
             elif xs0 == 1 and node.inputs[0].type.broadcastable[curdim]:
                 # Broadcast x
                 for i in xrange(ys0):
                     self._rec_perform(node, x[0], y[i], inverse, out[i],
-                                      curdim+1)
+                                      curdim + 1)
             else:
                 raise ValueError('Dimension mismatch: %s, %s' % (xs0, ys0))
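
For reference, _rec_perform recurses over the leading dimensions, broadcasting x or y along any length-1 broadcastable dimension; the public entry point is permute_row_elements, which permutes the elements of each row of x according to y. A sketch (output values shown are my expectation, not from the diff):

    import numpy
    import theano
    import theano.tensor as T

    x = T.matrix('x')
    p = T.lvector('p')
    out = T.permute_row_elements(x, p)   # out[i, j] == x[i, p[j]]
    f = theano.function([x, p], out)
    print(f(numpy.arange(6.).reshape(2, 3), numpy.array([2, 0, 1])))
    # [[ 2.  0.  1.]
    #  [ 5.  3.  4.]]
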
...