Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
08eaf0e3
提交
08eaf0e3
authored
11月 20, 2022
作者:
Brandon T. Willard
提交者:
Ricardo Vieira
11月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use singledispatch to register SharedVariable type constructors
上级
461832c5
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
41 行增加
和
84 行删除
+41
-84
sharedvalue.py
pytensor/compile/sharedvalue.py
+15
-49
sharedvar.py
pytensor/sparse/sharedvar.py
+4
-7
var.py
pytensor/tensor/random/var.py
+2
-3
sharedvar.py
pytensor/tensor/sharedvar.py
+20
-25
没有找到文件。
pytensor/compile/sharedvalue.py
浏览文件 @
08eaf0e3
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
import
copy
import
copy
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
functools
import
singledispatch
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.basic
import
Variable
...
@@ -157,14 +158,6 @@ class SharedVariable(Variable):
...
@@ -157,14 +158,6 @@ class SharedVariable(Variable):
self
.
_default_update
=
value
self
.
_default_update
=
value
def
shared_constructor
(
ctor
,
remove
=
False
):
if
remove
:
shared
.
constructors
.
remove
(
ctor
)
else
:
shared
.
constructors
.
append
(
ctor
)
return
ctor
def
shared
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
**
kwargs
):
def
shared
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
**
kwargs
):
r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
r"""Create a `SharedVariable` initialized with a copy or reference of `value`.
...
@@ -193,53 +186,26 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
...
@@ -193,53 +186,26 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
"""
"""
try
:
if
isinstance
(
value
,
Variable
):
if
isinstance
(
value
,
Variable
):
raise
TypeError
(
"Shared variable values can not be symbolic."
)
raise
TypeError
(
"Shared variable constructor needs numeric "
"values and not symbolic variables."
)
for
ctor
in
reversed
(
shared
.
constructors
):
try
:
var
=
ctor
(
value
,
name
=
name
,
strict
=
strict
,
allow_downcast
=
allow_downcast
,
**
kwargs
,
)
add_tag_trace
(
var
)
return
var
except
TypeError
:
continue
# This may happen when kwargs were supplied
# if kwargs were given, the generic_constructor won't be callable.
#
# This was done on purpose, the rationale being that if kwargs
# were supplied, the user didn't want them to be ignored.
try
:
var
=
shared_constructor
(
value
,
name
=
name
,
strict
=
strict
,
allow_downcast
=
allow_downcast
,
**
kwargs
,
)
add_tag_trace
(
var
)
return
var
except
MemoryError
as
e
:
except
MemoryError
as
e
:
e
.
args
=
e
.
args
+
(
"Consider using `pytensor.shared(..., borrow=True)`"
,)
e
.
args
=
e
.
args
+
(
"Consider using `pytensor.shared(..., borrow=True)`"
,)
raise
raise
raise
TypeError
(
"No suitable SharedVariable constructor could be found."
" Are you sure all kwargs are supported?"
" We do not support the parameter dtype or type."
f
' value="{value}". parameters="{kwargs}"'
)
shared
.
constructors
=
[]
@singledispatch
@shared_constructor
def
shared_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
**
kwargs
):
def
generic_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
):
"""
SharedVariable Constructor.
"""
return
SharedVariable
(
return
SharedVariable
(
type
=
generic
,
type
=
generic
,
value
=
value
,
value
=
value
,
...
...
pytensor/sparse/sharedvar.py
浏览文件 @
08eaf0e3
...
@@ -11,21 +11,18 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
...
@@ -11,21 +11,18 @@ class SparseTensorSharedVariable(_sparse_py_operators, SharedVariable):
format
=
property
(
lambda
self
:
self
.
type
.
format
)
format
=
property
(
lambda
self
:
self
.
type
.
format
)
@shared_constructor
@shared_constructor
.register
(
scipy
.
sparse
.
spmatrix
)
def
sparse_constructor
(
def
sparse_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
,
format
=
None
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
,
format
=
None
):
):
if
not
isinstance
(
value
,
scipy
.
sparse
.
spmatrix
):
raise
TypeError
(
"Expected a sparse matrix in the sparse shared variable constructor. Received: "
,
value
.
__class__
,
)
if
format
is
None
:
if
format
is
None
:
format
=
value
.
format
format
=
value
.
format
type
=
SparseTensorType
(
format
=
format
,
dtype
=
value
.
dtype
)
type
=
SparseTensorType
(
format
=
format
,
dtype
=
value
.
dtype
)
if
not
borrow
:
if
not
borrow
:
value
=
copy
.
deepcopy
(
value
)
value
=
copy
.
deepcopy
(
value
)
return
SparseTensorSharedVariable
(
return
SparseTensorSharedVariable
(
type
=
type
,
value
=
value
,
name
=
name
,
strict
=
strict
,
allow_downcast
=
allow_downcast
type
=
type
,
value
=
value
,
name
=
name
,
strict
=
strict
,
allow_downcast
=
allow_downcast
)
)
pytensor/tensor/random/var.py
浏览文件 @
08eaf0e3
...
@@ -18,7 +18,8 @@ class RandomGeneratorSharedVariable(SharedVariable):
...
@@ -18,7 +18,8 @@ class RandomGeneratorSharedVariable(SharedVariable):
)
)
@shared_constructor
@shared_constructor.register
(
np
.
random
.
RandomState
)
@shared_constructor.register
(
np
.
random
.
Generator
)
def
randomgen_constructor
(
def
randomgen_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
):
):
...
@@ -29,8 +30,6 @@ def randomgen_constructor(
...
@@ -29,8 +30,6 @@ def randomgen_constructor(
elif
isinstance
(
value
,
np
.
random
.
Generator
):
elif
isinstance
(
value
,
np
.
random
.
Generator
):
rng_sv_type
=
RandomGeneratorSharedVariable
rng_sv_type
=
RandomGeneratorSharedVariable
rng_type
=
random_generator_type
rng_type
=
random_generator_type
else
:
raise
TypeError
()
if
not
borrow
:
if
not
borrow
:
value
=
copy
.
deepcopy
(
value
)
value
=
copy
.
deepcopy
(
value
)
...
...
pytensor/tensor/sharedvar.py
浏览文件 @
08eaf0e3
import
traceback
import
warnings
import
warnings
import
numpy
as
np
import
numpy
as
np
...
@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var):
...
@@ -30,7 +29,7 @@ def _get_vector_length_TensorSharedVariable(var_inst, var):
return
len
(
var
.
get_value
(
borrow
=
True
))
return
len
(
var
.
get_value
(
borrow
=
True
))
@shared_constructor
@shared_constructor
.register
(
np
.
ndarray
)
def
tensor_constructor
(
def
tensor_constructor
(
value
,
value
,
name
=
None
,
name
=
None
,
...
@@ -60,14 +59,13 @@ def tensor_constructor(
...
@@ -60,14 +59,13 @@ def tensor_constructor(
if
target
!=
"cpu"
:
if
target
!=
"cpu"
:
raise
TypeError
(
"not for cpu"
)
raise
TypeError
(
"not for cpu"
)
if
not
isinstance
(
value
,
np
.
ndarray
):
raise
TypeError
()
# If no shape is given, then the default is to assume that the value might
# If no shape is given, then the default is to assume that the value might
# be resized in any dimension in the future.
# be resized in any dimension in the future.
if
shape
is
None
:
if
shape
is
None
:
shape
=
(
None
,)
*
len
(
value
.
shape
)
shape
=
(
None
,)
*
value
.
ndim
type
=
TensorType
(
value
.
dtype
,
shape
=
shape
)
type
=
TensorType
(
value
.
dtype
,
shape
=
shape
)
return
TensorSharedVariable
(
return
TensorSharedVariable
(
type
=
type
,
type
=
type
,
value
=
np
.
array
(
value
,
copy
=
(
not
borrow
)),
value
=
np
.
array
(
value
,
copy
=
(
not
borrow
)),
...
@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
...
@@ -81,7 +79,10 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
pass
pass
@shared_constructor
@shared_constructor.register
(
np
.
number
)
@shared_constructor.register
(
float
)
@shared_constructor.register
(
int
)
@shared_constructor.register
(
complex
)
def
scalar_constructor
(
def
scalar_constructor
(
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
,
target
=
"cpu"
value
,
name
=
None
,
strict
=
False
,
allow_downcast
=
None
,
borrow
=
False
,
target
=
"cpu"
):
):
...
@@ -101,28 +102,22 @@ def scalar_constructor(
...
@@ -101,28 +102,22 @@ def scalar_constructor(
if
target
!=
"cpu"
:
if
target
!=
"cpu"
:
raise
TypeError
(
"not for cpu"
)
raise
TypeError
(
"not for cpu"
)
if
not
isinstance
(
value
,
(
np
.
number
,
float
,
int
,
complex
)):
raise
TypeError
()
try
:
try
:
dtype
=
value
.
dtype
dtype
=
value
.
dtype
except
Exception
:
except
AttributeError
:
dtype
=
np
.
asarray
(
value
)
.
dtype
dtype
=
np
.
asarray
(
value
)
.
dtype
dtype
=
str
(
dtype
)
dtype
=
str
(
dtype
)
value
=
_asarray
(
value
,
dtype
=
dtype
)
value
=
_asarray
(
value
,
dtype
=
dtype
)
tensor_type
=
TensorType
(
dtype
=
str
(
value
.
dtype
),
shape
=
[]
)
tensor_type
=
TensorType
(
dtype
=
str
(
value
.
dtype
),
shape
=
()
)
try
:
# Do not pass the dtype to asarray because we want this to fail if
# Do not pass the dtype to asarray because we want this to fail if
# strict is True and the types do not match.
# strict is True and the types do not match.
rval
=
ScalarSharedVariable
(
rval
=
ScalarSharedVariable
(
type
=
tensor_type
,
type
=
tensor_type
,
value
=
np
.
array
(
value
,
copy
=
True
),
value
=
np
.
array
(
value
,
copy
=
True
),
name
=
name
,
name
=
name
,
strict
=
strict
,
strict
=
strict
,
allow_downcast
=
allow_downcast
,
allow_downcast
=
allow_downcast
,
)
)
return
rval
return
rval
except
Exception
:
traceback
.
print_exc
()
raise
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论