提交 4f8d207b authored 作者: khaotik's avatar khaotik 提交者: khaotik

PEP8

上级 c0fda9c0
...@@ -26,8 +26,6 @@ class OpFromGraphBase(gof.Op): ...@@ -26,8 +26,6 @@ class OpFromGraphBase(gof.Op):
'inputs and outputs must be Variable instances', i) 'inputs and outputs must be Variable instances', i)
if 'updates' in kwargs or 'givens' in kwargs: if 'updates' in kwargs or 'givens' in kwargs:
raise TypeError('updates and givens are not allowed here') raise TypeError('updates and givens are not allowed here')
# To correctly support shared variables the inner fct should # To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [var for var in gof.graph.inputs(outputs) self.shared_inputs = [var for var in gof.graph.inputs(outputs)
...@@ -46,7 +44,6 @@ class OpFromGraphBase(gof.Op): ...@@ -46,7 +44,6 @@ class OpFromGraphBase(gof.Op):
assert not update_expr assert not update_expr
assert not shared_inputs assert not shared_inputs
self.internal_inputs = internal_inputs self.internal_inputs = internal_inputs
self.internal_outputs = internal_outputs self.internal_outputs = internal_outputs
self.inputs = inputs self.inputs = inputs
...@@ -77,7 +74,7 @@ class OpFromGraphBase(gof.Op): ...@@ -77,7 +74,7 @@ class OpFromGraphBase(gof.Op):
grad_ops_l = self.grad_ops grad_ops_l = self.grad_ops
if isinstance(grad_ops_l, list): if isinstance(grad_ops_l, list):
assert len(grad_ops_l) <= len(self.internal_inputs) assert len(grad_ops_l) <= len(self.internal_inputs)
if len(grad_ops_l)<len(self.internal_inputs): if len(grad_ops_l) < len(self.internal_inputs):
grad_ops_l += [None]*( grad_ops_l += [None]*(
len(self.internal_inputs) - len(grad_ops_l)) len(self.internal_inputs) - len(grad_ops_l))
# It is normal if some inputs are not needed in order # It is normal if some inputs are not needed in order
...@@ -92,11 +89,13 @@ class OpFromGraphBase(gof.Op): ...@@ -92,11 +89,13 @@ class OpFromGraphBase(gof.Op):
disconnected_inputs='ignore') disconnected_inputs='ignore')
), on_unused_input='ignore' ), on_unused_input='ignore'
) for go, inp in izip(grad_ops_l, self.internal_inputs)] ) for go, inp in izip(grad_ops_l, self.internal_inputs)]
# since OpFromGraphBase only accepts input sequence, # since OpFromGraphBase only accepts input sequence,
# additional filtering is needed # additional filtering is needed
grad_ops = lambda inps,grds:[ def grad_ops(inps, grds):
(go(inps, grds) if ov else go(*(inps+grds))) nonlocal gs, grad_ops_l
for go, ov in izip(gs, grad_ops_l)] return [(go(inps, grds) if ov else go(*(inps+grds)))
for go, ov in izip(gs, grad_ops_l)]
else: else:
grad_ops = grad_ops_l grad_ops = grad_ops_l
self.grad_ops = grad_ops self.grad_ops = grad_ops
...@@ -111,10 +110,12 @@ class OpFromGraphBase(gof.Op): ...@@ -111,10 +110,12 @@ class OpFromGraphBase(gof.Op):
if g is None: if g is None:
grad_ops_l.append(lambda *args: None) grad_ops_l.append(lambda *args: None)
else: else:
grad_ops_l.append(type(self)(grad_inps, grad_ops_l.append(type(self)(
[g], grad_inps, [g], on_unused_input='ignore'))
on_unused_input='ignore'))
grad_ops = lambda inps, grds:[go(*(inps+grds)) for go in grad_ops_l] def grad_ops(inps, grds):
nonlocal grad_ops_l
return [go(*(inps+grds)) for go in grad_ops_l]
self.grad_ops = grad_ops self.grad_ops = grad_ops
self.cached_grad_ops = True self.cached_grad_ops = True
return grad_ops(inputs, output_grads) return grad_ops(inputs, output_grads)
...@@ -125,9 +126,9 @@ class OpFromGraphBase(gof.Op): ...@@ -125,9 +126,9 @@ class OpFromGraphBase(gof.Op):
raise TypeError("Wrong type, expected %s but got %s" % raise TypeError("Wrong type, expected %s but got %s" %
(type, input.type)) (type, input.type))
apply_node = gof.Apply(self, apply_node = gof.Apply(
list(inputs) + self.shared_inputs, self, list(inputs) + self.shared_inputs,
[type() for type in self.output_types]) [type() for type in self.output_types])
apply_node.internal_inputs = self.internal_inputs apply_node.internal_inputs = self.internal_inputs
apply_node.internal_outputs = self.internal_outputs apply_node.internal_outputs = self.internal_outputs
return apply_node return apply_node
...@@ -137,7 +138,8 @@ class OpFromGraphBase(gof.Op): ...@@ -137,7 +138,8 @@ class OpFromGraphBase(gof.Op):
Return connection pattern of subfgraph defined by inputs and outputs. Return connection pattern of subfgraph defined by inputs and outputs.
""" """
return io_connection_pattern(self.internal_inputs, self.internal_outputs) return io_connection_pattern(
self.internal_inputs, self.internal_outputs)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
out_shp = theano.scan_module.scan_utils.infer_shape( out_shp = theano.scan_module.scan_utils.infer_shape(
...@@ -162,9 +164,11 @@ class OpFromGraphBase(gof.Op): ...@@ -162,9 +164,11 @@ class OpFromGraphBase(gof.Op):
used += nb used += nb
return ret return ret
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
raise NotImplementedError() raise NotImplementedError()
class OpFromGraphPrecompiled(OpFromGraphBase): class OpFromGraphPrecompiled(OpFromGraphBase):
""" """
The Op's inner graph is compiled into a theano function. The Op's inner graph is compiled into a theano function.
...@@ -183,12 +187,15 @@ class OpFromGraphPrecompiled(OpFromGraphBase): ...@@ -183,12 +187,15 @@ class OpFromGraphPrecompiled(OpFromGraphBase):
# we wont need this copy anymore # we wont need this copy anymore
output[0] = variable.copy() output[0] = variable.copy()
class OpFromGraphInline(OpFromGraphBase): class OpFromGraphInline(OpFromGraphBase):
""" """
The Op's inner graph is expanded into the outer graph at compile time The Op's inner graph is expanded into the outer graph at compile time
""" """
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
raise RuntimeError(type(self).__name__+' is not supposed to be executed at runtime') raise RuntimeError(
type(self).__name__+' is not supposed to be executed at runtime')
@gof.local_optimizer([OpFromGraphInline]) @gof.local_optimizer([OpFromGraphInline])
def inline_ofg_expansion(node): def inline_ofg_expansion(node):
...@@ -199,7 +206,7 @@ def inline_ofg_expansion(node): ...@@ -199,7 +206,7 @@ def inline_ofg_expansion(node):
return False return False
outputs = theano.clone( outputs = theano.clone(
op.internal_outputs, { op.internal_outputs, {
u:v for u,v in izip( u: v for u, v in izip(
node.op.internal_inputs, node.inputs)}) node.op.internal_inputs, node.inputs)})
return outputs return outputs
...@@ -218,7 +225,8 @@ OpFromGraph = OpFromGraphPrecompiled ...@@ -218,7 +225,8 @@ OpFromGraph = OpFromGraphPrecompiled
# API for OpFromGraph* # API for OpFromGraph*
def op_from_graph( def op_from_graph(
inputs, outputs, inline=False, grad_overrides=None, **kwargs): inputs, outputs, inline=False, grad_overrides=None, **kwargs
):
""" """
This creates an `Op` from inputs and outputs lists of variables. This creates an `Op` from inputs and outputs lists of variables.
The signature is similar to theano.function() and the resulting The signature is similar to theano.function() and the resulting
...@@ -270,8 +278,8 @@ def op_from_graph( ...@@ -270,8 +278,8 @@ def op_from_graph(
invisible to the user. They can be as input to the node or in the invisible to the user. They can be as input to the node or in the
inner graph. inner graph.
- We support unused inputs. This is needed for the grad. - We support unused inputs. This is needed for the grad.
- `inline=True` will cause better runtime optimization at the cost of longer - `inline=True` will cause better runtime optimization at the cost of
compilation, only works with optimizer "fast_run" or "fast_compile" longer compilation, only works with optimizer "fast_run" or "fast_compile"
Examples Examples
-------- --------
...@@ -329,4 +337,3 @@ def op_from_graph( ...@@ -329,4 +337,3 @@ def op_from_graph(
cls_opfromgraph = OpFromGraphPrecompiled cls_opfromgraph = OpFromGraphPrecompiled
return cls_opfromgraph( return cls_opfromgraph(
inputs, outputs, grad_overrides=grad_overrides, **kwargs) inputs, outputs, grad_overrides=grad_overrides, **kwargs)
...@@ -18,7 +18,6 @@ test_params = unittest_tools.parameterized.expand( ...@@ -18,7 +18,6 @@ test_params = unittest_tools.parameterized.expand(
class T_OpFromGraph(unittest_tools.InferShapeTester): class T_OpFromGraph(unittest_tools.InferShapeTester):
@test_params @test_params
def test_straightforward(self, cls_ofg): def test_straightforward(self, cls_ofg):
x, y, z = T.matrices('xyz') x, y, z = T.matrices('xyz')
...@@ -122,7 +121,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -122,7 +121,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
@test_params @test_params
def test_grad_override(self, cls_ofg): def test_grad_override(self, cls_ofg):
x,y = T.vectors('xy') x, y = T.vectors('xy')
def go(inps, gs): def go(inps, gs):
x, y = inps x, y = inps
...@@ -132,8 +131,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -132,8 +131,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
# single override case # single override case
op_mul = cls_ofg([x, y], [x*y], grad_overrides=go) op_mul = cls_ofg([x, y], [x*y], grad_overrides=go)
xx,yy = T.vector('xx'), T.vector('yy') xx, yy = T.vector('xx'), T.vector('yy')
zz = T.sum(op_mul(xx,yy)) zz = T.sum(op_mul(xx, yy))
dx, dy = T.grad(zz, [xx, yy]) dx, dy = T.grad(zz, [xx, yy])
fn = function([xx, yy], [dx, dy]) fn = function([xx, yy], [dx, dy])
xv = numpy.random.rand(16).astype(config.floatX) xv = numpy.random.rand(16).astype(config.floatX)
...@@ -247,4 +246,3 @@ class T_OpFromGraph(unittest_tools.InferShapeTester): ...@@ -247,4 +246,3 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
np.ones([3, 4], dtype=config.floatX)], np.ones([3, 4], dtype=config.floatX)],
cls_ofg, cls_ofg,
check_topo=is_compile) check_topo=is_compile)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论