提交 f3afab87 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More pep8 / pyflakes

上级 034bb5a3
from theano import gof
from theano import gradient as G
from function_module import orig_function
......@@ -33,16 +32,19 @@ class OpFromGraph(gof.Op):
e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2])
"""
def __init__(self, inputs, outputs, grad_depth = 1, **kwargs):
def __init__(self, inputs, outputs, grad_depth=1, **kwargs):
if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs)
for i in inputs + outputs:
if not isinstance(i, gof.Variable):
raise TypeError('inputs and outputs must be Variable instances', i)
raise TypeError(
'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs')
# TODO: the graph may have implicit inputs like Value and SharedVariable instances.
# TODO: the graph may have implicit inputs like Value and
# SharedVariable instances.
# what impact to they have on the validity of this Op?
self.fn = orig_function(inputs, outputs, **kwargs)
self.inputs = inputs
......@@ -52,7 +54,8 @@ class OpFromGraph(gof.Op):
if grad_depth > 0:
output_grads = [t() for t in self.output_types]
gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs)
gd = G.grad_sources_inputs(zip(self.outputs, output_grads),
self.inputs)
gs = map(gd.get, self.inputs)
self.grad_ops = []
for g in gs:
......@@ -63,8 +66,9 @@ class OpFromGraph(gof.Op):
# to compute the gradient, so we ignore them.
self.grad_ops.append(OpFromGraph(inputs + output_grads,
[g],
grad_depth = grad_depth - 1,
grad_depth=grad_depth - 1,
on_unused_input='ignore'))
def __eq__(self, other):
#TODO: recognize a copy
return self is other
......@@ -76,7 +80,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types):
if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" % (type, input.type))
raise TypeError("Wrong type, expected %s but got %s"
% (type, input.type))
return gof.Apply(self,
inputs,
[type() for type in self.output_types])
......@@ -85,8 +90,8 @@ class OpFromGraph(gof.Op):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables):
##TODO: when function's output-borrowing semantics are correct, we wont need this
# copy anymore
##TODO: when function's output-borrowing semantics are correct,
# we wont need this copy anymore
output[0] = variable.copy()
def grad(self, inputs, output_grads):
......@@ -94,5 +99,3 @@ class OpFromGraph(gof.Op):
return [go(*(inputs + output_grads)) for go in self.grad_ops]
else:
raise NotImplementedError
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论