Commit 4f8d207b authored by khaotik, committed by khaotik

PEP8

Parent c0fda9c0
@@ -26,8 +26,6 @@ class OpFromGraphBase(gof.Op):
                     'inputs and outputs must be Variable instances', i)
         if 'updates' in kwargs or 'givens' in kwargs:
             raise TypeError('updates and givens are not allowed here')
-
-
         # To correctly support shared variables the inner fct should
         # not see them. Otherwise there is a problem with the gradient.
         self.shared_inputs = [var for var in gof.graph.inputs(outputs)
@@ -46,7 +44,6 @@ class OpFromGraphBase(gof.Op):
         assert not update_expr
         assert not shared_inputs
-
 
         self.internal_inputs = internal_inputs
         self.internal_outputs = internal_outputs
         self.inputs = inputs
@@ -77,7 +74,7 @@ class OpFromGraphBase(gof.Op):
         grad_ops_l = self.grad_ops
         if isinstance(grad_ops_l, list):
             assert len(grad_ops_l) <= len(self.internal_inputs)
-            if len(grad_ops_l)<len(self.internal_inputs):
+            if len(grad_ops_l) < len(self.internal_inputs):
                 grad_ops_l += [None]*(
                     len(self.internal_inputs) - len(grad_ops_l))
             # It is normal if some inputs are not needed in order
@@ -92,11 +89,13 @@ class OpFromGraphBase(gof.Op):
                         disconnected_inputs='ignore')
                 ), on_unused_input='ignore'
             ) for go, inp in izip(grad_ops_l, self.internal_inputs)]
             # since OpFromGraphBase only accepts input sequence,
             # additional filtering is needed
-            grad_ops = lambda inps,grds:[
-                (go(inps, grds) if ov else go(*(inps+grds)))
-                for go, ov in izip(gs, grad_ops_l)]
+
+            def grad_ops(inps, grds):
+                nonlocal gs, grad_ops_l
+                return [(go(inps, grds) if ov else go(*(inps+grds)))
+                        for go, ov in izip(gs, grad_ops_l)]
         else:
             grad_ops = grad_ops_l
         self.grad_ops = grad_ops
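Both lambda-to-def rewrites in this commit implement pycodestyle rule E731, which flags assigning a lambda expression to a name. A minimal sketch of the rule, using hypothetical names not taken from the patch:

    # flagged by pycodestyle as E731; the bound lambda has no useful
    # __name__ for tracebacks and cannot carry a docstring
    square = lambda v: v * v

    # the equivalent def form that PEP8 prefers
    def square(v):
        return v * v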
@@ -111,10 +110,12 @@ class OpFromGraphBase(gof.Op):
                 if g is None:
                     grad_ops_l.append(lambda *args: None)
                 else:
-                    grad_ops_l.append(type(self)(grad_inps,
-                                                 [g],
-                                                 on_unused_input='ignore'))
-            grad_ops = lambda inps, grds:[go(*(inps+grds)) for go in grad_ops_l]
+                    grad_ops_l.append(type(self)(
+                        grad_inps, [g], on_unused_input='ignore'))
+
+            def grad_ops(inps, grds):
+                nonlocal grad_ops_l
+                return [go(*(inps+grds)) for go in grad_ops_l]
         self.grad_ops = grad_ops
         self.cached_grad_ops = True
         return grad_ops(inputs, output_grads)
@@ -125,9 +126,9 @@ class OpFromGraphBase(gof.Op):
                 raise TypeError("Wrong type, expected %s but got %s" %
                                 (type, input.type))
 
-        apply_node = gof.Apply(self,
-                               list(inputs) + self.shared_inputs,
-                               [type() for type in self.output_types])
+        apply_node = gof.Apply(
+            self, list(inputs) + self.shared_inputs,
+            [type() for type in self.output_types])
         apply_node.internal_inputs = self.internal_inputs
         apply_node.internal_outputs = self.internal_outputs
         return apply_node
@@ -137,7 +138,8 @@ class OpFromGraphBase(gof.Op):
         Return connection pattern of subfgraph defined by inputs and outputs.
         """
-        return io_connection_pattern(self.internal_inputs, self.internal_outputs)
+        return io_connection_pattern(
+            self.internal_inputs, self.internal_outputs)
 
     def infer_shape(self, node, shapes):
         out_shp = theano.scan_module.scan_utils.infer_shape(
@@ -162,9 +164,11 @@ class OpFromGraphBase(gof.Op):
                 used += nb
         return ret
 
     def perform(self, node, inputs, outputs):
         raise NotImplementedError()
 
+
 class OpFromGraphPrecompiled(OpFromGraphBase):
     """
     The Op's inner graph is compiled into a theano function.
@@ -183,12 +187,15 @@ class OpFromGraphPrecompiled(OpFromGraphBase):
             # we wont need this copy anymore
             output[0] = variable.copy()
 
+
 class OpFromGraphInline(OpFromGraphBase):
     """
     The Op's inner graph is expanded into the outer graph at compile time
     """
     def perform(self, node, inputs, outputs):
-        raise RuntimeError(type(self).__name__+' is not supposed to be executed at runtime')
+        raise RuntimeError(
+            type(self).__name__+' is not supposed to be executed at runtime')
 
+
 @gof.local_optimizer([OpFromGraphInline])
 def inline_ofg_expansion(node):
@@ -199,7 +206,7 @@ def inline_ofg_expansion(node):
         return False
     outputs = theano.clone(
         op.internal_outputs, {
-            u:v for u,v in izip(
+            u: v for u, v in izip(
                 node.op.internal_inputs, node.inputs)})
     return outputs
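The replacement dict passed to theano.clone above is the stock Theano substitution mechanism. A self-contained sketch with illustrative names:

    import theano
    import theano.tensor as T

    a, b = T.scalars('ab')
    expr = a + 2 * b
    # theano.clone rebuilds the expression with `a` substituted by `b`;
    # inline_ofg_expansion uses the same call to splice the op's inner
    # graph into the outer graph at its apply node
    expr2 = theano.clone(expr, replace={a: b})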
@@ -218,7 +225,8 @@ OpFromGraph = OpFromGraphPrecompiled
 # API for OpFromGraph*
 def op_from_graph(
-        inputs, outputs, inline=False, grad_overrides=None, **kwargs):
+        inputs, outputs, inline=False, grad_overrides=None, **kwargs
+):
     """
     This creates an `Op` from inputs and outputs lists of variables.
     The signature is similar to theano.function() and the resulting
@@ -270,8 +278,8 @@ def op_from_graph(
       invisible to the user. They can be as input to the node or in the
       inner graph.
     - We support unused inputs. This is needed for the grad.
-    - `inline=True` will cause better runtime optimization at the cost of longer
-      compilation, only works with optimizer "fast_run" or "fast_compile"
+    - `inline=True` will cause better runtime optimization at the cost of
+      longer compilation, only works with optimizer "fast_run" or "fast_compile"
 
     Examples
     --------
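For reference, a minimal usage sketch of op_from_graph as defined in this patch; the variable names are illustrative and follow the tests below:

    import theano
    import theano.tensor as T

    x, y = T.vectors('xy')
    # inline=True selects OpFromGraphInline; the default builds an
    # OpFromGraphPrecompiled (see the cls_opfromgraph dispatch below)
    op_mul = op_from_graph([x, y], [x * y], inline=False)

    xx, yy = T.vector('xx'), T.vector('yy')
    fn = theano.function([xx, yy], op_mul(xx, yy))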
@@ -329,4 +337,3 @@ def op_from_graph(
         cls_opfromgraph = OpFromGraphPrecompiled
     return cls_opfromgraph(
         inputs, outputs, grad_overrides=grad_overrides, **kwargs)
-
@@ -18,7 +18,6 @@ test_params = unittest_tools.parameterized.expand(
-
 
 
 class T_OpFromGraph(unittest_tools.InferShapeTester):
     @test_params
     def test_straightforward(self, cls_ofg):
         x, y, z = T.matrices('xyz')
@@ -122,7 +121,7 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
     @test_params
     def test_grad_override(self, cls_ofg):
-        x,y = T.vectors('xy')
+        x, y = T.vectors('xy')
 
         def go(inps, gs):
             x, y = inps
@@ -132,8 +131,8 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
 
         # single override case
         op_mul = cls_ofg([x, y], [x*y], grad_overrides=go)
-        xx,yy = T.vector('xx'), T.vector('yy')
-        zz = T.sum(op_mul(xx,yy))
+        xx, yy = T.vector('xx'), T.vector('yy')
+        zz = T.sum(op_mul(xx, yy))
         dx, dy = T.grad(zz, [xx, yy])
         fn = function([xx, yy], [dx, dy])
         xv = numpy.random.rand(16).astype(config.floatX)
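The body of the test's override go is truncated at the previous hunk boundary. Per the dispatch in grad() above, it is called as go(inputs, output_grads) and must return one gradient expression per input. A hypothetical sketch of such a callable; the doubled gradients are illustrative only, not the patch's actual override:

    def go(inps, gs):
        x, y = inps
        g, = gs
        # one (rescaled) gradient expression per input
        return [g * y * 2, g * x * 2]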
@@ -247,4 +246,3 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
              np.ones([3, 4], dtype=config.floatX)],
             cls_ofg,
             check_topo=is_compile)
-