提交 3bbb4456 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug where std_fgraph was populated with Feature subclasses instead

of instances thereof
上级 3672b24b
...@@ -148,7 +148,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False): ...@@ -148,7 +148,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace = False):
return fgraph, map(SymbolicOutput, updates) return fgraph, map(SymbolicOutput, updates)
std_fgraph.features = [gof.toolbox.PreserveNames] std_fgraph.features = [gof.toolbox.PreserveNames()]
class UncomputableFeature(gof.Feature): class UncomputableFeature(gof.Feature):
"""A feature that ensures the graph never contains any """A feature that ensures the graph never contains any
...@@ -165,7 +165,7 @@ class UncomputableFeature(gof.Feature): ...@@ -165,7 +165,7 @@ class UncomputableFeature(gof.Feature):
def on_import(self, fgraph, node): def on_import(self, fgraph, node):
gof.op.raise_if_uncomputable(node) gof.op.raise_if_uncomputable(node)
std_fgraph.features.append(UncomputableFeature) std_fgraph.features.append(UncomputableFeature())
class AliasedMemoryError(Exception): class AliasedMemoryError(Exception):
......
...@@ -440,6 +440,13 @@ class FunctionGraph(utils.object2): ...@@ -440,6 +440,13 @@ class FunctionGraph(utils.object2):
""" """
if feature in self._features: if feature in self._features:
return # the feature is already present return # the feature is already present
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
# raise TypeError("Expected gof.toolbox.Feature instance, got "+\
# str(type(feature)))
attach = getattr(feature, 'on_attach', None) attach = getattr(feature, 'on_attach', None)
if attach is not None: if attach is not None:
try: try:
......
...@@ -317,6 +317,7 @@ class PrintListener(Feature): ...@@ -317,6 +317,7 @@ class PrintListener(Feature):
class PreserveNames(Feature): class PreserveNames(Feature):
def on_change_input(self, fgraph, mode, i, r, new_r, reason=None): def on_change_input(self, fgraph, mode, i, r, new_r, reason=None):
if r.name is not None and new_r.name is None: if r.name is not None and new_r.name is None:
new_r.name = r.name new_r.name = r.name
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论