提交 3179ca7a authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 tensor/basic.py

上级 fd08055e
...@@ -474,7 +474,6 @@ def get_constant_value(v): ...@@ -474,7 +474,6 @@ def get_constant_value(v):
code, but I'm not sure where it is. code, but I'm not sure where it is.
""" """
if isinstance(v, Constant): if isinstance(v, Constant):
if getattr(v.tag, 'unique_value', None) is not None: if getattr(v.tag, 'unique_value', None) is not None:
data = v.tag.unique_value data = v.tag.unique_value
...@@ -483,7 +482,7 @@ def get_constant_value(v): ...@@ -483,7 +482,7 @@ def get_constant_value(v):
# handle case where data is numpy.array([]) # handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \ if hasattr(data, 'shape') and len(data.shape) == 0 or \
__builtins__['max'](data.shape) == 0: __builtins__['max'](data.shape) == 0:
assert numpy.all(numpy.array([])==data) assert numpy.all(numpy.array([]) == data)
return data return data
try: try:
numpy.complex(data) # works for all numeric scalars numpy.complex(data) # works for all numeric scalars
...@@ -2126,7 +2125,7 @@ class Shape(Op): ...@@ -2126,7 +2125,7 @@ class Shape(Op):
# the elements of the tensor variable do not participate # the elements of the tensor variable do not participate
# in the computation of the shape, so they are not really # in the computation of the shape, so they are not really
# part of the graph # part of the graph
return [ DisconnectedType()() ] return [DisconnectedType()()]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
...@@ -2209,7 +2208,7 @@ class SpecifyShape(Op): ...@@ -2209,7 +2208,7 @@ class SpecifyShape(Op):
return [new_shape] return [new_shape]
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True],[False]] return [[True], [False]]
def grad(self, inp, grads): def grad(self, inp, grads):
x, s = inp x, s = inp
...@@ -3121,7 +3120,6 @@ class Alloc(gof.Op): ...@@ -3121,7 +3120,6 @@ class Alloc(gof.Op):
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [node.inputs[1:]] return [node.inputs[1:]]
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [[True]] rval = [[True]]
...@@ -3810,7 +3808,6 @@ class Subtensor(Op): ...@@ -3810,7 +3808,6 @@ class Subtensor(Op):
else: else:
return scal.as_scalar(a) return scal.as_scalar(a)
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
x = as_tensor_variable(x) x = as_tensor_variable(x)
inputs = tuple(self.my_as_scalar(a) for a in inputs) inputs = tuple(self.my_as_scalar(a) for a in inputs)
...@@ -3908,7 +3905,7 @@ class Subtensor(Op): ...@@ -3908,7 +3905,7 @@ class Subtensor(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [ [True] ] rval = [[True]]
for ipt in node.inputs[1:]: for ipt in node.inputs[1:]:
rval.append([False]) rval.append([False])
...@@ -4569,7 +4566,7 @@ class IncSubtensor(Op): ...@@ -4569,7 +4566,7 @@ class IncSubtensor(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [ [True], [True] ] rval = [[True], [True]]
for ipt in node.inputs[2:]: for ipt in node.inputs[2:]:
rval.append([False]) rval.append([False])
...@@ -4979,7 +4976,7 @@ class Join(Op): ...@@ -4979,7 +4976,7 @@ class Join(Op):
gz, = grads gz, = grads
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
rval = [ grad_undefined(self, 0, axis) ] rval = [grad_undefined(self, 0, axis)]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype: if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype:
# assume that this is differentiable # assume that this is differentiable
...@@ -4994,7 +4991,7 @@ class Join(Op): ...@@ -4994,7 +4991,7 @@ class Join(Op):
else: else:
# the output has integer type, so the gradient through it # the output has integer type, so the gradient through it
# is 0 # is 0
rval = rval + [ tensor.zeros_like() for tensor in tensors ] rval = rval + [tensor.zeros_like() for tensor in tensors]
return rval return rval
...@@ -5239,6 +5236,7 @@ def vertical_stack(*args): ...@@ -5239,6 +5236,7 @@ def vertical_stack(*args):
assert arg.type.ndim == 2 assert arg.type.ndim == 2
return concatenate(args, axis=0) return concatenate(args, axis=0)
class Reshape(Op): class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp. """Perform a reshape operation of the input x to the new shape shp.
...@@ -5657,7 +5655,7 @@ class ARange(Op): ...@@ -5657,7 +5655,7 @@ class ARange(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
return [ [True], [False], [True] ] return [[True], [False], [True]]
def grad(self, inputs, grads): def grad(self, inputs, grads):
start, stop, step = inputs start, stop, step = inputs
...@@ -5667,9 +5665,9 @@ class ARange(Op): ...@@ -5667,9 +5665,9 @@ class ARange(Op):
# no gradient through them # no gradient through them
# stop does not affect the output values, # stop does not affect the output values,
# just the output shape, so it is disconnected # just the output shape, so it is disconnected
return [ start.zeros_like(), return [start.zeros_like(),
DisconnectedType()(), DisconnectedType()(),
step.zeros_like() ] step.zeros_like()]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
return [None] return [None]
...@@ -5903,8 +5901,8 @@ class PermuteRowElements(Op): ...@@ -5903,8 +5901,8 @@ class PermuteRowElements(Op):
# are non-integer, so the gradient with respect to them is # are non-integer, so the gradient with respect to them is
# undefined # undefined
return [gx, grad_undefined(self,1,y), return [gx, grad_undefined(self, 1, y),
grad_undefined(self,1,inverse)] grad_undefined(self, 1, inverse)]
_permute_row_elements = PermuteRowElements() _permute_row_elements = PermuteRowElements()
...@@ -5970,10 +5968,10 @@ class AdvancedSubtensor1(Op): ...@@ -5970,10 +5968,10 @@ class AdvancedSubtensor1(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [ [True ] ] rval = [[True]]
for ipt in node.inputs[1:]: for ipt in node.inputs[1:]:
rval.append([ False ]) rval.append([False])
return rval return rval
...@@ -6081,10 +6079,9 @@ class AdvancedIncSubtensor1(Op): ...@@ -6081,10 +6079,9 @@ class AdvancedIncSubtensor1(Op):
return self.make_node(eval_points[0], eval_points[1], return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs *inputs[2:]).outputs
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [[True], [True] ] rval = [[True], [True]]
for ipt in node.inputs[2:]: for ipt in node.inputs[2:]:
rval.append([False]) rval.append([False])
...@@ -6099,7 +6096,7 @@ class AdvancedIncSubtensor1(Op): ...@@ -6099,7 +6096,7 @@ class AdvancedIncSubtensor1(Op):
gx = g_output gx = g_output
gy = advanced_subtensor1(g_output, *idx_list) gy = advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [ DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1() advanced_inc_subtensor1 = AdvancedIncSubtensor1()
...@@ -6190,7 +6187,7 @@ class AdvancedSubtensor(Op): ...@@ -6190,7 +6187,7 @@ class AdvancedSubtensor(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [ [True] ] rval = [[True]]
for ipt in node.inputs[1:]: for ipt in node.inputs[1:]:
rval.append([False]) rval.append([False])
...@@ -6290,7 +6287,7 @@ class AdvancedIncSubtensor(Op): ...@@ -6290,7 +6287,7 @@ class AdvancedIncSubtensor(Op):
def connection_pattern(self, node): def connection_pattern(self, node):
rval = [ [True], [True] ] rval = [[True], [True]]
for ipt in node.inputs[2:]: for ipt in node.inputs[2:]:
rval.append([False]) rval.append([False])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论