提交 9532c287 authored 作者: Frederic's avatar Frederic

pep8

上级 aa055460
...@@ -16,7 +16,8 @@ import theano ...@@ -16,7 +16,8 @@ import theano
from theano import gof from theano import gof
from theano.gof.python25 import partial from theano.gof.python25 import partial
import theano.compile.mode import theano.compile.mode
from theano.compile.io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op from theano.compile.ops import deep_copy_op, view_op
import logging import logging
...@@ -29,9 +30,11 @@ class UnusedInputError(Exception): ...@@ -29,9 +30,11 @@ class UnusedInputError(Exception):
""" """
pass pass
def alias_root(v): def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps""" """Return the variable to which v is aliased by view_maps and destroy_maps"""
if v.owner is None: return v if v.owner is None:
return v
vmap = getattr(v.owner.op, 'view_map', {}) vmap = getattr(v.owner.op, 'view_map', {})
dmap = getattr(v.owner.op, 'destroy_map', {}) dmap = getattr(v.owner.op, 'destroy_map', {})
outpos = v.owner.outputs.index(v) outpos = v.owner.outputs.index(v)
...@@ -106,10 +109,11 @@ class Supervisor: ...@@ -106,10 +109,11 @@ class Supervisor:
return True return True
for r in self.protected + list(fgraph.outputs): for r in self.protected + list(fgraph.outputs):
if fgraph.destroyers(r): if fgraph.destroyers(r):
raise gof.InconsistencyError("Trying to destroy a protected Variable.", r) raise gof.InconsistencyError(
"Trying to destroy a protected Variable.", r)
def std_fgraph(input_specs, output_specs, accept_inplace = False): def std_fgraph(input_specs, output_specs, accept_inplace=False):
""" """
Makes an FunctionGraph corresponding to the input specs and the output Makes an FunctionGraph corresponding to the input specs and the output
specs. Any SymbolicInput in the input_specs, if its update field specs. Any SymbolicInput in the input_specs, if its update field
...@@ -134,17 +138,18 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -134,17 +138,18 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
for node in fgraph.apply_nodes: for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None): if getattr(node.op, 'destroy_map', None):
if not accept_inplace: if not accept_inplace:
raise TypeError("Graph must not contain inplace operations", node, node.op) raise TypeError("Graph must not contain inplace operations",
node, node.op)
else: else:
fgraph.attach_feature(gof.DestroyHandler()) fgraph.attach_feature(gof.DestroyHandler())
break break
# We need to protect all immutable inputs from inplace operations. # We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature( fgraph.attach_feature(
Supervisor(input Supervisor(input
for spec, input in zip(input_specs, fgraph.inputs) for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input))))) fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
...@@ -155,6 +160,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -155,6 +160,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
std_fgraph.features = [gof.toolbox.PreserveNames] std_fgraph.features = [gof.toolbox.PreserveNames]
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
"""Memory is aliased that should not be""" """Memory is aliased that should not be"""
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论