提交 9532c287 authored 作者: Frederic's avatar Frederic

pep8

上级 aa055460
......@@ -16,7 +16,8 @@ import theano
from theano import gof
from theano.gof.python25 import partial
import theano.compile.mode
from theano.compile.io import In, SymbolicInput, SymbolicInputKit, SymbolicOutput
from theano.compile.io import (
In, SymbolicInput, SymbolicInputKit, SymbolicOutput)
from theano.compile.ops import deep_copy_op, view_op
import logging
......@@ -29,9 +30,11 @@ class UnusedInputError(Exception):
"""
pass
def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps"""
if v.owner is None: return v
if v.owner is None:
return v
vmap = getattr(v.owner.op, 'view_map', {})
dmap = getattr(v.owner.op, 'destroy_map', {})
outpos = v.owner.outputs.index(v)
......@@ -106,10 +109,11 @@ class Supervisor:
return True
for r in self.protected + list(fgraph.outputs):
if fgraph.destroyers(r):
raise gof.InconsistencyError("Trying to destroy a protected Variable.", r)
raise gof.InconsistencyError(
"Trying to destroy a protected Variable.", r)
def std_fgraph(input_specs, output_specs, accept_inplace = False):
def std_fgraph(input_specs, output_specs, accept_inplace=False):
"""
Makes an FunctionGraph corresponding to the input specs and the output
specs. Any SymbolicInput in the input_specs, if its update field
......@@ -134,17 +138,18 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
for node in fgraph.apply_nodes:
if getattr(node.op, 'destroy_map', None):
if not accept_inplace:
raise TypeError("Graph must not contain inplace operations", node, node.op)
raise TypeError("Graph must not contain inplace operations",
node, node.op)
else:
fgraph.attach_feature(gof.DestroyHandler())
break
# We need to protect all immutable inputs from inplace operations.
fgraph.attach_feature(
Supervisor(input
for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and
Supervisor(input
for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input)))))
# If named nodes are replaced, keep the name
......@@ -155,6 +160,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
std_fgraph.features = [gof.toolbox.PreserveNames]
class AliasedMemoryError(Exception):
"""Memory is aliased that should not be"""
pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论