提交 aeabc8a9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added borrow flag to Param as well, as it should be ( to some degree ?) in

sync with In.
上级 8f856565
...@@ -164,6 +164,11 @@ class In(SymbolicInput): ...@@ -164,6 +164,11 @@ class In(SymbolicInput):
True: permit the compiled function to modify the python object being passed as the input True: permit the compiled function to modify the python object being passed as the input
False: do not permit the compiled function to modify the python object being passed as the input. False: do not permit the compiled function to modify the python object being passed as the input.
borrow: Bool (default: False if update is None, True if update is not None)
True: permit the output of the compiled function to be aliased to the input
False: do not permit any output to be aliased to the input
strict: Bool (default: False) strict: Bool (default: False)
True: means that the value you pass for this input must have exactly the right type True: means that the value you pass for this input must have exactly the right type
False: the value you pass for this input may be cast automatically to the proper type False: the value you pass for this input may be cast automatically to the proper type
......
...@@ -244,7 +244,7 @@ def rebuild_collect_shared( outputs ...@@ -244,7 +244,7 @@ def rebuild_collect_shared( outputs
class Param(object): class Param(object):
def __init__(self, variable, default=None, name=None, mutable=False, def __init__(self, variable, default=None, name=None, mutable=False,
strict=False, allow_downcast=None, implicit=None): strict=False, allow_downcast=None, implicit=None, borrow = False):
""" """
:param variable: A variable in an expression graph to use as a compiled-function parameter :param variable: A variable in an expression graph to use as a compiled-function parameter
...@@ -255,6 +255,11 @@ class Param(object): ...@@ -255,6 +255,11 @@ class Param(object):
:param mutable: True -> function is allowed to modify this argument. :param mutable: True -> function is allowed to modify this argument.
:param borrow: True -> function is allowed to alias some output to
this input
False: do not permit any output to be aliased to the input
:param strict: False -> function arguments may be copied or casted to match the :param strict: False -> function arguments may be copied or casted to match the
type required by the parameter `variable`. True -> function arguments must exactly match the type type required by the parameter `variable`. True -> function arguments must exactly match the type
required by `variable`. required by `variable`.
...@@ -274,6 +279,7 @@ class Param(object): ...@@ -274,6 +279,7 @@ class Param(object):
self.strict = strict self.strict = strict
self.allow_downcast = allow_downcast self.allow_downcast = allow_downcast
self.implicit = implicit self.implicit = implicit
self.borrow = borrow
def pfunc(params, outputs=None, mode=None, updates=[], givens=[], def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
no_default_updates=False, accept_inplace=False, name=None, no_default_updates=False, accept_inplace=False, name=None,
...@@ -396,6 +402,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None): ...@@ -396,6 +402,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
value=param.default, value=param.default,
mutable=param.mutable, mutable=param.mutable,
strict=param.strict, strict=param.strict,
borrow = param.borrow,
allow_downcast=param.allow_downcast, allow_downcast=param.allow_downcast,
implicit = param.implicit) implicit = param.implicit)
raise TypeError('Unknown parameter type: %s' % type(param)) raise TypeError('Unknown parameter type: %s' % type(param))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论