提交 25127eff authored 作者: Frederic's avatar Frederic

Fix the constant cache during compilation.

上级 6c9ad5f6
......@@ -1306,12 +1306,12 @@ def orig_function(inputs, outputs, mode=None, accept_inplace=False,
else:
Maker = getattr(mode, 'function_maker', FunctionMaker)
fn = Maker(inputs,
outputs,
mode,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input).create(
defaults)
outputs,
mode,
accept_inplace=accept_inplace,
profile=profile,
on_unused_input=on_unused_input).create(
defaults)
t2 = time.time()
if profile:
......
......@@ -76,7 +76,13 @@ class Optimizer(object):
opt.apply(fgraph)
"""
self.add_requirements(fgraph)
return self.apply(fgraph, *args, **kwargs)
try:
orig = theano.tensor.basic.constant.enable
theano.tensor.basic.constant.enable = False
ret = self.apply(fgraph, *args, **kwargs)
finally:
theano.tensor.basic.constant.enable = orig
return ret
def __call__(self, fgraph):
"""WRITEME
......
......@@ -411,6 +411,8 @@ def constant(x, name=None, ndim=None, dtype=None):
#But we don't want to cache too much stuff
#So we cache integer with dtype [u]int and float where the value is between -10 and 10
#We want to cache all broadcast pattern for scalar.
if not constant.enable:
return ret
sig = ret.signature()
if (sig not in constant_cache and ret.data.size == 1 and
ret.data <= 10 and ret.data >= -10 and
......@@ -419,6 +421,7 @@ def constant(x, name=None, ndim=None, dtype=None):
constant_cache[sig] = ret
return constant_cache.get(sig, ret)
constant.enable = True
def _obj_is_wrappable_as_tensor(x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论