提交 52090854 authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the default of Param(...,borrow=None) the same as pfunc.

Also small update to documentation.
上级 59d86606
...@@ -160,7 +160,7 @@ class In(SymbolicInput): ...@@ -160,7 +160,7 @@ class In(SymbolicInput):
True: permit the compiled function to modify the python object being passed as the input True: permit the compiled function to modify the python object being passed as the input
False: do not permit the compiled function to modify the python object being passed as the input. False: do not permit the compiled function to modify the python object being passed as the input.
borrow: Bool (default: False if update is None, True if update is not None) borrow: Bool (default: False if mutable evaluate to False, True otherwise)
True: permit the output of the compiled function to be aliased to the input True: permit the output of the compiled function to be aliased to the input
False: do not permit any output to be aliased to the input False: do not permit any output to be aliased to the input
......
...@@ -9,6 +9,9 @@ from theano.compile import orig_function, In, Out ...@@ -9,6 +9,9 @@ from theano.compile import orig_function, In, Out
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
from theano import config from theano import config
import logging
_logger=logging.getLogger("theano.compile.pfunc")
def rebuild_collect_shared( outputs def rebuild_collect_shared( outputs
, inputs = None , inputs = None
, replace = None , replace = None
...@@ -251,7 +254,7 @@ def rebuild_collect_shared( outputs ...@@ -251,7 +254,7 @@ def rebuild_collect_shared( outputs
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None, borrow = False): strict=False, allow_downcast=None, implicit=None, borrow = None):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a compiled-function parameter
...@@ -283,10 +286,18 @@ class Param(object): ...@@ -283,10 +286,18 @@ class Param(object):
self.default = default self.default = default
self.name = name self.name = name
self.mutable = mutable self.mutable = mutable
# Mutable implies borrow. You can get borrow = False because of the # mutable implies the output can be both aliased to the input and that the input can be
# default and it is a bit annoying to require the user to set both # destroyed. borrow simply implies the output can be aliased to the input. Thus
# borrow and mutable to True # mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True.
if mutable: if mutable:
if borrow==False:
_logger.warning("Symbolic input for variable %s (name=%s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
"input variable may be both aliased (borrow=True) and "
"over-written. We set borrow=True and continue.",
variable, name)
borrow = True borrow = True
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast self.allow_downcast = allow_downcast
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论