提交 25127eff authored 作者: Frederic's avatar Frederic

Fix the constant cache during compilation.

上级 6c9ad5f6
...@@ -76,7 +76,13 @@ class Optimizer(object): ...@@ -76,7 +76,13 @@ class Optimizer(object):
opt.apply(fgraph) opt.apply(fgraph)
""" """
self.add_requirements(fgraph) self.add_requirements(fgraph)
return self.apply(fgraph, *args, **kwargs) try:
orig = theano.tensor.basic.constant.enable
theano.tensor.basic.constant.enable = False
ret = self.apply(fgraph, *args, **kwargs)
finally:
theano.tensor.basic.constant.enable = orig
return ret
def __call__(self, fgraph): def __call__(self, fgraph):
"""WRITEME """WRITEME
......
...@@ -411,6 +411,8 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -411,6 +411,8 @@ def constant(x, name=None, ndim=None, dtype=None):
#But we don't want to cache too much stuff #But we don't want to cache too much stuff
#So we cache integer with dtype [u]int and float where the value is between -10 and 10 #So we cache integer with dtype [u]int and float where the value is between -10 and 10
#We want to cache all broadcast pattern for scalar. #We want to cache all broadcast pattern for scalar.
if not constant.enable:
return ret
sig = ret.signature() sig = ret.signature()
if (sig not in constant_cache and ret.data.size == 1 and if (sig not in constant_cache and ret.data.size == 1 and
ret.data <= 10 and ret.data >= -10 and ret.data <= 10 and ret.data >= -10 and
...@@ -419,6 +421,7 @@ def constant(x, name=None, ndim=None, dtype=None): ...@@ -419,6 +421,7 @@ def constant(x, name=None, ndim=None, dtype=None):
constant_cache[sig] = ret constant_cache[sig] = ret
return constant_cache.get(sig, ret) return constant_cache.get(sig, ret)
constant.enable = True
def _obj_is_wrappable_as_tensor(x): def _obj_is_wrappable_as_tensor(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论