提交 da45efb4 authored 作者: James Bergstra's avatar James Bergstra

added config variable shared.value_borrows to control the copying behaviour of…

added config variable shared.value_borrows to control the copying behaviour of the shared.value property
上级 f9b46251
"""Provide a simple user friendly API """ """Provide a simple user friendly API to Theano-managed memory"""
__docformat__ = 'restructuredtext en' __docformat__ = 'restructuredtext en'
import traceback import traceback
...@@ -14,6 +14,15 @@ def warn(*msg): _logger.warn(' '.join(str(m) for m in msg)) ...@@ -14,6 +14,15 @@ def warn(*msg): _logger.warn(' '.join(str(m) for m in msg))
def warning(*msg): _logger.warning(' '.join(str(m) for m in msg)) def warning(*msg): _logger.warning(' '.join(str(m) for m in msg))
def error(*msg): _logger.error(' '.join(str(m) for m in msg)) def error(*msg): _logger.error(' '.join(str(m) for m in msg))
from theano.configparser import TheanoConfigParser, AddConfigVar, EnumStr, StrParam, IntParam, FloatParam, BoolParam
from theano import config
AddConfigVar('shared.value_borrows',
("False: shared variables 'value' property is guaranteed to not"
" alias theano-managed memory. True: no guarantee, but faster."
" For more control consider using shared.get_value() instead."),
BoolParam(True))
class SharedVariable(Variable): class SharedVariable(Variable):
""" """
Variable that is (defaults to being) shared between functions that it appears in. Variable that is (defaults to being) shared between functions that it appears in.
...@@ -106,17 +115,19 @@ class SharedVariable(Variable): ...@@ -106,17 +115,19 @@ class SharedVariable(Variable):
cp.tag = copy.copy(self.tag) cp.tag = copy.copy(self.tag)
return cp return cp
def get_borrowed_value(self): def _value_get(self):
return self.get_value(borrow=True) return self.get_value(borrow=config.shared.value_borrows)
def set_borrowed_value(self, new_value): def _value_set(self, new_value):
return self.set_value(new_value, borrow=True) return self.set_value(new_value, borrow=config.shared.value_borrows)
#TODO: USE A CONFIG VARIABLE TO set these get/set methods to the non-borrowing versions #TODO: USE A CONFIG VARIABLE TO set these get/set methods to the non-borrowing versions
# Semantically things are clearer when using non-borrow versions. That should be the # Semantically things are clearer when using non-borrow versions. That should be the
# default. The default support transparently (if slowly) when the 'raw' value is in a # default. The default support transparently (if slowly) when the 'raw' value is in a
# different memory space (e.g. GPU or other machine). # different memory space (e.g. GPU or other machine).
value = property(get_borrowed_value, set_borrowed_value, value = property(_value_get, _value_set,
doc="shortcut for self.get_borrowed_value() and self.set_borrowed_value() which COPIES data") doc=("shortcut for self.get_value() and self.set_value()."
"The `borrow` argument to these methods is read from "
"`theano.config.shared.value_borrows`"))
def filter_update(self, update): def filter_update(self, update):
......
...@@ -511,7 +511,7 @@ class Test_aliasing_rules(unittest.TestCase): ...@@ -511,7 +511,7 @@ class Test_aliasing_rules(unittest.TestCase):
assert not numpy.may_share_memory(orig_a, data_of(A)) assert not numpy.may_share_memory(orig_a, data_of(A))
# rule #2 reading back from theano-managed memory # rule #2 reading back from theano-managed memory
assert not numpy.may_share_memory(A.value, data_of(A)) assert not numpy.may_share_memory(A.get_value(borrow=False), data_of(A))
def test_potential_output_aliasing_induced_by_updates(self): def test_potential_output_aliasing_induced_by_updates(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论