提交 3179ca7a authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 tensor/basic.py

上级 fd08055e
......@@ -474,7 +474,6 @@ def get_constant_value(v):
code, but I'm not sure where it is.
"""
if isinstance(v, Constant):
if getattr(v.tag, 'unique_value', None) is not None:
data = v.tag.unique_value
......@@ -483,7 +482,7 @@ def get_constant_value(v):
# handle case where data is numpy.array([])
if hasattr(data, 'shape') and len(data.shape) == 0 or \
__builtins__['max'](data.shape) == 0:
assert numpy.all(numpy.array([])==data)
assert numpy.all(numpy.array([]) == data)
return data
try:
numpy.complex(data) # works for all numeric scalars
......@@ -2126,7 +2125,7 @@ class Shape(Op):
# the elements of the tensor variable do not participate
# in the computation of the shape, so they are not really
# part of the graph
return [ DisconnectedType()() ]
return [DisconnectedType()()]
def R_op(self, inputs, eval_points):
    # Shape's output does not participate differentiably in the
    # computation (see connection_pattern above: the input is
    # disconnected), so there is no directional derivative to push
    # forward; [None] signals that no R-op value is available for
    # the single output. NOTE(review): assumes Theano's Op.R_op
    # contract treats None as "not computable" — confirm upstream.
    return [None]
......@@ -2209,7 +2208,7 @@ class SpecifyShape(Op):
return [new_shape]
def connection_pattern(self, node):
return [[True],[False]]
return [[True], [False]]
def grad(self, inp, grads):
x, s = inp
......@@ -3121,7 +3120,6 @@ class Alloc(gof.Op):
def infer_shape(self, node, input_shapes):
    # The output shape of Alloc is given directly by its shape
    # arguments, i.e. every input after the first (the fill value),
    # so return them as the shape of the single output. The
    # `input_shapes` argument is deliberately unused here.
    # NOTE(review): assumes node.inputs is (value, *shape) — confirm
    # against Alloc.make_node, which is outside this excerpt.
    return [node.inputs[1:]]
def connection_pattern(self, node):
rval = [[True]]
......@@ -3810,7 +3808,6 @@ class Subtensor(Op):
else:
return scal.as_scalar(a)
def make_node(self, x, *inputs):
x = as_tensor_variable(x)
inputs = tuple(self.my_as_scalar(a) for a in inputs)
......@@ -3908,7 +3905,7 @@ class Subtensor(Op):
def connection_pattern(self, node):
rval = [ [True] ]
rval = [[True]]
for ipt in node.inputs[1:]:
rval.append([False])
......@@ -4569,7 +4566,7 @@ class IncSubtensor(Op):
def connection_pattern(self, node):
rval = [ [True], [True] ]
rval = [[True], [True]]
for ipt in node.inputs[2:]:
rval.append([False])
......@@ -4979,7 +4976,7 @@ class Join(Op):
gz, = grads
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
rval = [ grad_undefined(self, 0, axis) ]
rval = [grad_undefined(self, 0, axis)]
if 'float' in tensors[0].dtype or 'complex' in tensors[0].dtype:
# assume that this is differentiable
......@@ -4994,7 +4991,7 @@ class Join(Op):
else:
# the output has integer type, so the gradient through it
# is 0
rval = rval + [ tensor.zeros_like() for tensor in tensors ]
rval = rval + [tensor.zeros_like() for tensor in tensors]
return rval
......@@ -5239,6 +5236,7 @@ def vertical_stack(*args):
assert arg.type.ndim == 2
return concatenate(args, axis=0)
class Reshape(Op):
"""Perform a reshape operation of the input x to the new shape shp.
......@@ -5657,7 +5655,7 @@ class ARange(Op):
def connection_pattern(self, node):
return [ [True], [False], [True] ]
return [[True], [False], [True]]
def grad(self, inputs, grads):
start, stop, step = inputs
......@@ -5667,9 +5665,9 @@ class ARange(Op):
# no gradient through them
# stop does not affect the output values,
# just the output shape, so it is disconnected
return [ start.zeros_like(),
return [start.zeros_like(),
DisconnectedType()(),
step.zeros_like() ]
step.zeros_like()]
def R_op(self, inputs, eval_points):
    # ARange's integer-valued output has no meaningful directional
    # derivative with respect to (start, stop, step); return [None]
    # for the single output to indicate the R-op is not provided.
    # NOTE(review): consistent with grad above, which returns zeros /
    # DisconnectedType for these inputs.
    return [None]
......@@ -5903,8 +5901,8 @@ class PermuteRowElements(Op):
# are non-integer, so the gradient with respect to them is
# undefined
return [gx, grad_undefined(self,1,y),
grad_undefined(self,1,inverse)]
return [gx, grad_undefined(self, 1, y),
grad_undefined(self, 1, inverse)]
_permute_row_elements = PermuteRowElements()
......@@ -5970,10 +5968,10 @@ class AdvancedSubtensor1(Op):
def connection_pattern(self, node):
rval = [ [True ] ]
rval = [[True]]
for ipt in node.inputs[1:]:
rval.append([ False ])
rval.append([False])
return rval
......@@ -6081,10 +6079,9 @@ class AdvancedIncSubtensor1(Op):
return self.make_node(eval_points[0], eval_points[1],
*inputs[2:]).outputs
def connection_pattern(self, node):
rval = [[True], [True] ]
rval = [[True], [True]]
for ipt in node.inputs[2:]:
rval.append([False])
......@@ -6099,7 +6096,7 @@ class AdvancedIncSubtensor1(Op):
gx = g_output
gy = advanced_subtensor1(g_output, *idx_list)
return [gx, gy] + [ DisconnectedType()()] * len(idx_list)
return [gx, gy] + [DisconnectedType()()] * len(idx_list)
advanced_inc_subtensor1 = AdvancedIncSubtensor1()
......@@ -6190,7 +6187,7 @@ class AdvancedSubtensor(Op):
def connection_pattern(self, node):
rval = [ [True] ]
rval = [[True]]
for ipt in node.inputs[1:]:
rval.append([False])
......@@ -6290,7 +6287,7 @@ class AdvancedIncSubtensor(Op):
def connection_pattern(self, node):
rval = [ [True], [True] ]
rval = [[True], [True]]
for ipt in node.inputs[2:]:
rval.append([False])
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论