提交 823cffaa authored 作者: gdesjardins's avatar gdesjardins

Re: ticket #573

Implemented unittests which test that the contract for In is respected: we test that aliasing the input is impossible with borrow=False.
上级 260f7392
......@@ -4,6 +4,13 @@ __docformat__ = 'restructuredtext en'
from theano import gof
from sharedvalue import SharedVariable
import logging
_logger=logging.getLogger("theano.compile.io")
_logger.setLevel(logging.WARNING)
def warning(*args):
_logger.warning("WARNING: "+' '.join(str(a) for a in args))
class SymbolicInput(object):
"""
Represents a symbolic input for use with function or FunctionMaker.
......@@ -184,7 +191,26 @@ class In(SymbolicInput):
# try to keep it synchronized.
def __init__(self, variable, name=None, value=None, update=None,
mutable=None, strict=False, allow_downcast=False, autoname=True,
implicit=None, borrow=False):
implicit=None, borrow=None):
# mutable implies the output can be both aliased to the input and that the input can be
# destroyed. borrow simply implies the output can be aliased to the input. Thus
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True.
if mutable:
if borrow==False:
warning("Symbolic input for variable %s (name=%s) has flags "\
"mutable=True, borrow=False. This combination is "\
"incompatible since mutable=True implies that the input "\
"variable may be both aliased (borrow=True) and over-"\
"written. We set borrow=True and continue." % (variable, name))
borrow = True
# borrow=None basically means False. We can't set default value to False because of the
# above business with mutable.
if borrow is None:
borrow = False
if implicit is None:
implicit = (isinstance(value, gof.Container) or
isinstance(value, SharedVariable))
......
......@@ -257,9 +257,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
for sv in shared_inputs:
if sv in update_d:
si = In(variable=sv, value=sv.container, mutable=True,
update=update_d[sv])
borrow=True, update=update_d[sv])
else:
si = In(variable=sv, value=sv.container, mutable=False)
si = In(variable=sv, value=sv.container,
mutable=False, borrow=True)
inputs.append(si)
return orig_function(inputs, cloned_outputs, mode,
......
......@@ -281,6 +281,26 @@ class T_function(unittest.TestCase):
self.failUnless(dec[s] == -1)
def test_borrow_input(self):
"""
Tests that the contract for io.In is respected. When borrow=False, it should be
impossible for outputs to be aliased to the input variables provided by the user,
either through a view-map or a destroy map. New tests should be added in the future
when borrow=True is implemented.
"""
a = T.dmatrix()
aval = numpy.random.rand(3,3)
# when borrow=False, test that a destroy map cannot alias output to input
f = theano.function([In(a, borrow=False)], Out(a+1, borrow=True))
assert numpy.all(f(aval) == aval+1)
assert not numpy.may_share_memory(aval, f(aval))
# when borrow=False, test that a viewmap cannot alias output to input
f = theano.function([In(a, borrow=False)], Out(a[0,:], borrow=True))
assert numpy.all(f(aval) == aval[0,:])
assert not numpy.may_share_memory(aval, f(aval))
def test_borrow_output(self):
a = T.dmatrix()
f = function([a], Out(a, borrow=False))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论