提交 e798bce3 authored 作者: Frederic's avatar Frederic

Raise a good error when we create a FunctionGraph with Cached Constant.

上级 25127eff
...@@ -43,7 +43,7 @@ from theano.gof.compiledir import \ ...@@ -43,7 +43,7 @@ from theano.gof.compiledir import \
local_bitwidth, python_int_bitwidth local_bitwidth, python_int_bitwidth
from theano.gof.fg import \ from theano.gof.fg import \
InconsistencyError, MissingInputError, FunctionGraph CachedConstantError, InconsistencyError, MissingInputError, FunctionGraph
from theano.gof.destroyhandler import \ from theano.gof.destroyhandler import \
DestroyHandler DestroyHandler
......
...@@ -20,6 +20,14 @@ from theano.gof.python25 import OrderedDict ...@@ -20,6 +20,14 @@ from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet from theano.misc.ordered_set import OrderedSet
class CachedConstantError(Exception):
"""An exception thrown when we put in a FunctionGraph a Constant
that is cached. This should not happen as the user can reuse this
cached constant in other FunctionGraph.
"""
pass
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
This exception should be thrown by listeners to FunctionGraph when the This exception should be thrown by listeners to FunctionGraph when the
...@@ -122,6 +130,11 @@ class FunctionGraph(utils.object2): ...@@ -122,6 +130,11 @@ class FunctionGraph(utils.object2):
### Setup a Variable ### ### Setup a Variable ###
def __setup_r__(self, r): def __setup_r__(self, r):
# sets up r so it belongs to this fgraph # sets up r so it belongs to this fgraph
if getattr(r, 'cached', False):
raise CachedConstantError(
"You manually constructed a FunctionGraph, but you passed it a"
" graph that have cached constant. This should happen."
" Clone the graph before building the FunctionGraph")
if (hasattr(r, 'fgraph') and if (hasattr(r, 'fgraph') and
r.fgraph is not None and r.fgraph is not None and
r.fgraph is not self): r.fgraph is not self):
......
...@@ -419,6 +419,8 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -419,6 +419,8 @@ def constant(x, name=None, ndim=None, dtype=None):
(ret.dtype in int_dtypes or ret.dtype in uint_dtypes or (ret.dtype in int_dtypes or ret.dtype in uint_dtypes or
(ret.dtype in float_dtypes and int(ret.data) == ret.data))): (ret.dtype in float_dtypes and int(ret.data) == ret.data))):
constant_cache[sig] = ret constant_cache[sig] = ret
# This is needed to raise a good error to the user.
ret.cached = True
return constant_cache.get(sig, ret) return constant_cache.get(sig, ret)
constant.enable = True constant.enable = True
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论