提交 f3afab87 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More pep8 / pyflakes

上级 034bb5a3
from theano import gof from theano import gof
from theano import gradient as G from theano import gradient as G
from function_module import orig_function from function_module import orig_function
...@@ -33,16 +32,19 @@ class OpFromGraph(gof.Op): ...@@ -33,16 +32,19 @@ class OpFromGraph(gof.Op):
e2 = op(x, y, z) + op(z, y, x) e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2]) fn = function([x, y, z], [e2])
""" """
def __init__(self, inputs, outputs, grad_depth = 1, **kwargs): def __init__(self, inputs, outputs, grad_depth=1, **kwargs):
if not isinstance(outputs, list): if not isinstance(outputs, list):
raise TypeError('outputs must be list', outputs) raise TypeError('outputs must be list', outputs)
for i in inputs + outputs: for i in inputs + outputs:
if not isinstance(i, gof.Variable): if not isinstance(i, gof.Variable):
raise TypeError('inputs and outputs must be Variable instances', i) raise TypeError(
'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs: if 'updates' in kwargs:
raise TypeError('updates are not allowed in kwargs') raise TypeError('updates are not allowed in kwargs')
# TODO: the graph may have implicit inputs like Value and SharedVariable instances.
# TODO: the graph may have implicit inputs like Value and
# SharedVariable instances.
# what impact to they have on the validity of this Op? # what impact to they have on the validity of this Op?
self.fn = orig_function(inputs, outputs, **kwargs) self.fn = orig_function(inputs, outputs, **kwargs)
self.inputs = inputs self.inputs = inputs
...@@ -52,7 +54,8 @@ class OpFromGraph(gof.Op): ...@@ -52,7 +54,8 @@ class OpFromGraph(gof.Op):
if grad_depth > 0: if grad_depth > 0:
output_grads = [t() for t in self.output_types] output_grads = [t() for t in self.output_types]
gd = G.grad_sources_inputs(zip(self.outputs, output_grads), self.inputs) gd = G.grad_sources_inputs(zip(self.outputs, output_grads),
self.inputs)
gs = map(gd.get, self.inputs) gs = map(gd.get, self.inputs)
self.grad_ops = [] self.grad_ops = []
for g in gs: for g in gs:
...@@ -63,8 +66,9 @@ class OpFromGraph(gof.Op): ...@@ -63,8 +66,9 @@ class OpFromGraph(gof.Op):
# to compute the gradient, so we ignore them. # to compute the gradient, so we ignore them.
self.grad_ops.append(OpFromGraph(inputs + output_grads, self.grad_ops.append(OpFromGraph(inputs + output_grads,
[g], [g],
grad_depth = grad_depth - 1, grad_depth=grad_depth - 1,
on_unused_input='ignore')) on_unused_input='ignore'))
def __eq__(self, other): def __eq__(self, other):
#TODO: recognize a copy #TODO: recognize a copy
return self is other return self is other
...@@ -76,7 +80,8 @@ class OpFromGraph(gof.Op): ...@@ -76,7 +80,8 @@ class OpFromGraph(gof.Op):
def make_node(self, *inputs): def make_node(self, *inputs):
for input, type in zip(inputs, self.input_types): for input, type in zip(inputs, self.input_types):
if not type == input.type: if not type == input.type:
raise TypeError("Wrong type, expected %s but got %s" % (type, input.type)) raise TypeError("Wrong type, expected %s but got %s"
% (type, input.type))
return gof.Apply(self, return gof.Apply(self,
inputs, inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
...@@ -85,8 +90,8 @@ class OpFromGraph(gof.Op): ...@@ -85,8 +90,8 @@ class OpFromGraph(gof.Op):
variables = self.fn(*inputs) variables = self.fn(*inputs)
assert len(variables) == len(outputs) assert len(variables) == len(outputs)
for output, variable in zip(outputs, variables): for output, variable in zip(outputs, variables):
##TODO: when function's output-borrowing semantics are correct, we wont need this ##TODO: when function's output-borrowing semantics are correct,
# copy anymore # we wont need this copy anymore
output[0] = variable.copy() output[0] = variable.copy()
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
...@@ -94,5 +99,3 @@ class OpFromGraph(gof.Op): ...@@ -94,5 +99,3 @@ class OpFromGraph(gof.Op):
return [go(*(inputs + output_grads)) for go in self.grad_ops] return [go(*(inputs + output_grads)) for go in self.grad_ops]
else: else:
raise NotImplementedError raise NotImplementedError
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论