Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
461832c5
提交
461832c5
authored
11月 20, 2022
作者:
Brandon T. Willard
提交者:
Ricardo Vieira
11月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Clean up docstrings and errors relating to SharedVariable
上级
b132036d
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
53 行增加
和
90 行删除
+53
-90
shared.rst
doc/library/compile/shared.rst
+3
-4
pfunc.py
pytensor/compile/function/pfunc.py
+16
-12
sharedvalue.py
pytensor/compile/sharedvalue.py
+11
-59
var.py
pytensor/tensor/random/var.py
+1
-1
sharedvar.py
pytensor/tensor/sharedvar.py
+6
-14
test_pfunc.py
tests/compile/function/test_pfunc.py
+16
-0
没有找到文件。
doc/library/compile/shared.rst
浏览文件 @
461832c5
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
.. class:: SharedVariable
.. class:: SharedVariable
Variable with
Storage that is shared between
functions that it appears in.
Variable with
storage that is shared between the compiled
functions that it appears in.
These variables are meant to be created by registered *shared constructors*
These variables are meant to be created by registered *shared constructors*
(see :func:`shared_constructor`).
(see :func:`shared_constructor`).
...
@@ -68,7 +68,6 @@
...
@@ -68,7 +68,6 @@
A container to use for this SharedVariable when it is an implicit function parameter.
A container to use for this SharedVariable when it is an implicit function parameter.
:type: class:`Container`
.. autofunction:: shared
.. autofunction:: shared
...
@@ -76,10 +75,10 @@
...
@@ -76,10 +75,10 @@
Append `ctor` to the list of shared constructors (see :func:`shared`).
Append `ctor` to the list of shared constructors (see :func:`shared`).
Each registered constructor `
`ctor`
` will be called like this:
Each registered constructor `
ctor
` will be called like this:
.. code-block:: python
.. code-block:: python
ctor(value, name=name, strict=strict, **kwargs)
ctor(value, name=name, strict=strict, **kwargs)
If it do not support given value, it must raise a
TypeError
.
If it do not support given value, it must raise a
`TypeError`
.
pytensor/compile/function/pfunc.py
浏览文件 @
461832c5
...
@@ -78,10 +78,12 @@ def rebuild_collect_shared(
...
@@ -78,10 +78,12 @@ def rebuild_collect_shared(
shared_inputs
=
[]
shared_inputs
=
[]
def
clone_v_get_shared_updates
(
v
,
copy_inputs_over
):
def
clone_v_get_shared_updates
(
v
,
copy_inputs_over
):
"""
r"""Clones a variable and its inputs recursively until all are in `clone_d`.
Clones a variable and its inputs recursively until all are in clone_d.
Also appends all shared variables met along the way to shared inputs,
Also, it appends all `SharedVariable`\s met along the way to
and their default_update (if applicable) to update_d and update_expr.
`shared_inputs` and their corresponding
`SharedVariable.default_update`\s (when applicable) to `update_d` and
`update_expr`.
"""
"""
# this co-recurses with clone_a
# this co-recurses with clone_a
...
@@ -419,22 +421,24 @@ def construct_pfunc_ins_and_outs(
...
@@ -419,22 +421,24 @@ def construct_pfunc_ins_and_outs(
givens
=
[]
givens
=
[]
if
not
isinstance
(
params
,
(
list
,
tuple
)):
if
not
isinstance
(
params
,
(
list
,
tuple
)):
raise
Exception
(
"in pfunc() the first argument must be a list or "
"
a tuple"
)
raise
TypeError
(
"The `params` argument must be a list or
a tuple"
)
if
not
isinstance
(
no_default_updates
,
bool
)
and
not
isinstance
(
if
not
isinstance
(
no_default_updates
,
bool
)
and
not
isinstance
(
no_default_updates
,
list
no_default_updates
,
list
):
):
raise
TypeError
(
"
no_default_update should be either a boolean or "
"a
list"
)
raise
TypeError
(
"
The `no_default_update` argument must be a boolean or
list"
)
if
len
(
updates
)
>
0
and
any
(
if
len
(
updates
)
>
0
and
not
all
(
isinstance
(
v
,
Variable
)
for
v
in
iter_over_pairs
(
updates
)
isinstance
(
pair
,
(
tuple
,
list
))
and
len
(
pair
)
==
2
and
isinstance
(
pair
[
0
],
Variable
)
for
pair
in
iter_over_pairs
(
updates
)
):
):
raise
ValueError
(
raise
TypeError
(
"The updates parameter must be an OrderedDict/dict or a list of "
"The `updates` parameter must be an ordered mapping or a list of pairs"
"lists/tuples with 2 elements"
)
)
#
t
ransform params into pytensor.compile.In objects.
#
T
ransform params into pytensor.compile.In objects.
inputs
=
[
inputs
=
[
_pfunc_param_to_in
(
p
,
allow_downcast
=
allow_input_downcast
)
for
p
in
params
_pfunc_param_to_in
(
p
,
allow_downcast
=
allow_input_downcast
)
for
p
in
params
]
]
...
...
pytensor/compile/sharedvalue.py
浏览文件 @
461832c5
"""
"""Provide a simple user friendly API to PyTensor-managed memory."""
Provide a simple user friendly API to PyTensor-managed memory.
"""
import
copy
import
copy
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
...
@@ -30,52 +27,14 @@ def collect_new_shareds():
...
@@ -30,52 +27,14 @@ def collect_new_shareds():
class
SharedVariable
(
Variable
):
class
SharedVariable
(
Variable
):
"""
"""Variable that is shared between compiled functions."""
Variable that is (defaults to being) shared between functions that
it appears in.
Parameters
----------
name : str
The name for this variable (see `Variable`).
type : str
The type for this variable (see `Variable`).
value
A value to associate with this variable (a new container will be
created).
strict
True : assignments to .value will not be cast or copied, so they must
have the correct type.
allow_downcast
Only applies if `strict` is False.
True : allow assigned value to lose precision when cast during
assignment.
False : never allow precision loss.
None : only allow downcasting of a Python float to a scalar floatX.
container
The container to use for this variable. Illegal to pass this as well as
a value.
Notes
-----
For more user-friendly constructor, see `shared`.
"""
container
:
Optional
[
Container
]
=
None
# Container object
container
=
None
"""
"""
A container to use for this SharedVariable when it is an implicit
A container to use for this SharedVariable when it is an implicit
function parameter.
function parameter.
:type: `Container`
"""
"""
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def
__init__
(
self
,
name
,
type
,
value
,
strict
,
allow_downcast
=
None
,
container
=
None
):
def
__init__
(
self
,
name
,
type
,
value
,
strict
,
allow_downcast
=
None
,
container
=
None
):
super
()
.
__init__
(
type
=
type
,
name
=
name
,
owner
=
None
,
index
=
None
)
super
()
.
__init__
(
type
=
type
,
name
=
name
,
owner
=
None
,
index
=
None
)
...
@@ -207,37 +166,30 @@ def shared_constructor(ctor, remove=False):
...
@@ -207,37 +166,30 @@ def shared_constructor(ctor, remove=False):
def
shared
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
**
kwargs
):
def
shared
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
**
kwargs
):
"""Return a SharedVariable Variable, initialized with a copy or
r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
reference of `value`.
This function iterates over constructor functions to find a
This function iterates over constructor functions to find a
suitable
SharedVariable
subclass. The suitable one is the first
suitable
`SharedVariable`
subclass. The suitable one is the first
constructor that accept the given value. See the documentation of
constructor that accept the given value. See the documentation of
:func:`shared_constructor` for the definition of a constructor
:func:`shared_constructor` for the definition of a constructor
function.
function.
This function is meant as a convenient default. If you want to use a
This function is meant as a convenient default. If you want to use a
specific shared variable constructor, consider calling it directly.
specific constructor, consider calling it directly.
``pytensor.shared`` is a shortcut to this function.
.. attribute:: constructors
A list of shared variable constructors that will be tried in reverse
`pytensor.shared` is a shortcut to this function.
order.
Notes
Notes
-----
-----
By passing kwargs, you effectively limit the set of potential constructors
By passing kwargs, you effectively limit the set of potential constructors
to those that can accept those kwargs.
to those that can accept those kwargs.
Some shared variable have `
`borrow`` as extra kwargs
.
Some shared variable have `
borrow` as a kwarg
.
Some shared variable have ``broadcastable`` as extra kwargs
. As shared
`SharedVariable`\s of `TensorType` have `broadcastable` as a kwarg
. As shared
variable shapes can change, all dimensions default to not being
variable shapes can change, all dimensions default to not being
broadcastable, even if ``value`` has a shape of 1 along some dimension.
broadcastable, even if `value` has a shape of 1 along some dimension.
This parameter allows you to create for example a `row` or `column` 2d
This parameter allows one to create for example a row or column tensor.
tensor.
"""
"""
...
...
pytensor/tensor/random/var.py
浏览文件 @
461832c5
...
@@ -22,7 +22,7 @@ class RandomGeneratorSharedVariable(SharedVariable):
...
@@ -22,7 +22,7 @@ class RandomGeneratorSharedVariable(SharedVariable):
def
randomgen_constructor
(
def
randomgen_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
):
):
r"""`SharedVariable`
C
onstructor for NumPy's `Generator` and/or `RandomState`."""
r"""`SharedVariable`
c
onstructor for NumPy's `Generator` and/or `RandomState`."""
if
isinstance
(
value
,
np
.
random
.
RandomState
):
if
isinstance
(
value
,
np
.
random
.
RandomState
):
rng_sv_type
=
RandomStateSharedVariable
rng_sv_type
=
RandomStateSharedVariable
rng_type
=
random_state_type
rng_type
=
random_state_type
...
...
pytensor/tensor/sharedvar.py
浏览文件 @
461832c5
...
@@ -41,8 +41,7 @@ def tensor_constructor(
...
@@ -41,8 +41,7 @@ def tensor_constructor(
target
=
"cpu"
,
target
=
"cpu"
,
broadcastable
=
None
,
broadcastable
=
None
,
):
):
"""
r"""`SharedVariable` constructor for `TensorType`\s.
SharedVariable Constructor for TensorType.
Notes
Notes
-----
-----
...
@@ -64,9 +63,8 @@ def tensor_constructor(
...
@@ -64,9 +63,8 @@ def tensor_constructor(
if
not
isinstance
(
value
,
np
.
ndarray
):
if
not
isinstance
(
value
,
np
.
ndarray
):
raise
TypeError
()
raise
TypeError
()
# if no shape is given, then the default is to assume that
# If no shape is given, then the default is to assume that the value might
# the value might be resized in any dimension in the future.
# be resized in any dimension in the future.
#
if
shape
is
None
:
if
shape
is
None
:
shape
=
(
None
,)
*
len
(
value
.
shape
)
shape
=
(
None
,)
*
len
(
value
.
shape
)
type
=
TensorType
(
value
.
dtype
,
shape
=
shape
)
type
=
TensorType
(
value
.
dtype
,
shape
=
shape
)
...
@@ -79,13 +77,6 @@ def tensor_constructor(
...
@@ -79,13 +77,6 @@ def tensor_constructor(
)
)
# TensorSharedVariable brings in the tensor operators, is not ideal, but works
# as long as we don't do purely scalar-scalar operations
# _tensor_py_operators is first to have its version of __{gt,ge,lt,le}__
#
# N.B. THERE IS ANOTHER CLASS CALLED ScalarSharedVariable in the
# pytensor.scalar.sharedvar file. It is not registered as a shared_constructor,
# this one is.
class
ScalarSharedVariable
(
_tensor_py_operators
,
SharedVariable
):
class
ScalarSharedVariable
(
_tensor_py_operators
,
SharedVariable
):
pass
pass
...
@@ -94,8 +85,9 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
...
@@ -94,8 +85,9 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
def
scalar_constructor
(
def
scalar_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
,
target
=
"cpu"
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
,
target
=
"cpu"
):
):
"""
"""`SharedVariable` constructor for scalar values.
SharedVariable constructor for scalar values. Default: int64 or float64.
Default: int64 or float64.
Notes
Notes
-----
-----
...
...
tests/compile/function/test_pfunc.py
浏览文件 @
461832c5
...
@@ -36,6 +36,22 @@ def data_of(s):
...
@@ -36,6 +36,22 @@ def data_of(s):
class
TestPfunc
:
class
TestPfunc
:
def
test_errors
(
self
):
a
=
lscalar
()
b
=
shared
(
1
)
with
pytest
.
raises
(
TypeError
):
pfunc
({
a
},
a
+
b
)
with
pytest
.
raises
(
TypeError
):
pfunc
([
a
],
a
+
b
,
no_default_updates
=
1
)
with
pytest
.
raises
(
TypeError
):
pfunc
([
a
],
a
+
b
,
updates
=
[{
b
,
a
}])
with
pytest
.
raises
(
TypeError
):
pfunc
([
a
],
a
+
b
,
updates
=
[(
1
,
b
)])
def
test_doc
(
self
):
def
test_doc
(
self
):
# Ensure the code given in pfunc.txt works as expected
# Ensure the code given in pfunc.txt works as expected
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论