提交 631e5dc9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix Theano function overhead increase due to gh-6275. Shared variable was not…

Fix Theano function overhead increase due to gh-6275. Shared variable was not triggering the input alias check and that was right.
上级 f4ef0469
......@@ -385,8 +385,12 @@ class Function(object):
# TODO: this only need to be set if there is more then 1 input
self._check_for_aliased_inputs = False
for i in maker.inputs:
if (isinstance(i, In) and ((hasattr(i, 'borrow') and i.borrow) or
(hasattr(i, 'mutable') and i.mutable))):
# If the input is a shared variable, the memory region is
# under Theano control and so we don't need to check if it
# is aliased as we never do that.
if (isinstance(i, In) and not i.shared and
(getattr(i, 'borrow', False) or
getattr(i, 'mutable', False))):
self._check_for_aliased_inputs = True
break
......
......@@ -582,6 +582,35 @@ class T_function(unittest.TestCase):
except TypeError:
assert(func(first=1) == x)
def test_check_for_aliased_inputs(self):
b = np.random.rand(5, 4)
s1 = theano.shared(b)
s2 = theano.shared(b)
x1 = theano.tensor.vector()
# Assert cases we should not check for aliased inputs
for d in [dict(outputs=[s1 + 1]),
dict(outputs=[s1 + 1, s2 + 3]),
dict(outputs=[s1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[x1], outputs=[x1 + 1], updates=[(s2, s2 + 3)])]:
if "inputs" not in d:
d["inputs"] = []
f = theano.function(**d)
assert not f._check_for_aliased_inputs, d
# Assert cases we should check for aliased inputs
for d in [dict(inputs=[theano.In(x1, borrow=True)],
outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[theano.In(x1, borrow=True, mutable=True)],
outputs=[x1 + 1], updates=[(s2, s2 + 3)]),
dict(inputs=[theano.In(x1, mutable=True)],
outputs=[x1 + 1], updates=[(s2, s2 + 3)])]:
if "inputs" not in d:
d["inputs"] = []
f = theano.function(**d)
assert f._check_for_aliased_inputs, d
class T_picklefunction(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论