Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9665120e
提交
9665120e
authored
7月 20, 2022
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
7月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enable type checking for NumPy types
上级
bb40791b
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
36 行增加
和
33 行删除
+36
-33
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-0
basic.py
aesara/graph/basic.py
+3
-3
op.py
aesara/link/c/op.py
+9
-15
printing.py
aesara/printing.py
+7
-7
__init__.py
aesara/tensor/__init__.py
+14
-7
type.py
aesara/tensor/random/type.py
+1
-1
setup.cfg
setup.cfg
+1
-0
没有找到文件。
.pre-commit-config.yaml
浏览文件 @
9665120e
...
@@ -51,5 +51,6 @@ repos:
...
@@ -51,5 +51,6 @@ repos:
hooks
:
hooks
:
-
id
:
mypy
-
id
:
mypy
additional_dependencies
:
additional_dependencies
:
-
numpy>=1.20
-
types-filelock
-
types-filelock
-
types-setuptools
-
types-setuptools
aesara/graph/basic.py
浏览文件 @
9665120e
...
@@ -1671,14 +1671,14 @@ def equal_computations(
...
@@ -1671,14 +1671,14 @@ def equal_computations(
for
x
,
y
in
zip
(
xs
,
ys
):
for
x
,
y
in
zip
(
xs
,
ys
):
if
not
isinstance
(
x
,
Variable
)
and
not
isinstance
(
y
,
Variable
):
if
not
isinstance
(
x
,
Variable
)
and
not
isinstance
(
y
,
Variable
):
return
cast
(
bool
,
np
.
array_equal
(
x
,
y
)
)
return
np
.
array_equal
(
x
,
y
)
if
not
isinstance
(
x
,
Variable
):
if
not
isinstance
(
x
,
Variable
):
if
isinstance
(
y
,
Constant
):
if
isinstance
(
y
,
Constant
):
return
cast
(
bool
,
np
.
array_equal
(
y
.
data
,
x
)
)
return
np
.
array_equal
(
y
.
data
,
x
)
return
False
return
False
if
not
isinstance
(
y
,
Variable
):
if
not
isinstance
(
y
,
Variable
):
if
isinstance
(
x
,
Constant
):
if
isinstance
(
x
,
Constant
):
return
cast
(
bool
,
np
.
array_equal
(
x
.
data
,
y
)
)
return
np
.
array_equal
(
x
.
data
,
y
)
return
False
return
False
if
x
.
owner
and
not
y
.
owner
:
if
x
.
owner
and
not
y
.
owner
:
return
False
return
False
...
...
aesara/link/c/op.py
浏览文件 @
9665120e
...
@@ -544,25 +544,19 @@ class ExternalCOp(COp):
...
@@ -544,25 +544,19 @@ class ExternalCOp(COp):
vname
=
variable_names
[
i
]
vname
=
variable_names
[
i
]
macro_name
=
"DTYPE_"
+
vname
macro_items
=
(
f
"DTYPE_{vname}"
,
f
"npy_{v.type.dtype}"
)
macro_value
=
"npy_"
+
v
.
type
.
dtype
define_macros
.
append
(
define_template
%
macro_items
)
undef_macros
.
append
(
undef_template
%
macro_items
[
0
])
define_macros
.
append
(
define_template
%
(
macro_name
,
macro_value
))
undef_macros
.
append
(
undef_template
%
macro_name
)
d
=
np
.
dtype
(
v
.
type
.
dtype
)
d
=
np
.
dtype
(
v
.
type
.
dtype
)
macro_name
=
"TYPENUM_"
+
vname
macro_items_2
=
(
f
"TYPENUM_{vname}"
,
d
.
num
)
macro_value
=
d
.
num
define_macros
.
append
(
define_template
%
macro_items_2
)
undef_macros
.
append
(
undef_template
%
macro_items_2
[
0
])
define_macros
.
append
(
define_template
%
(
macro_name
,
macro_value
))
undef_macros
.
append
(
undef_template
%
macro_name
)
macro_name
=
"ITEMSIZE_"
+
vname
macro_value
=
d
.
itemsize
define_macros
.
append
(
define_template
%
(
macro_name
,
macro_value
))
macro_items_3
=
(
f
"ITEMSIZE_{vname}"
,
d
.
itemsize
)
undef_macros
.
append
(
undef_template
%
macro_name
)
define_macros
.
append
(
define_template
%
macro_items_3
)
undef_macros
.
append
(
undef_template
%
macro_items_3
[
0
])
# Generate a macro to mark code as being apply-specific
# Generate a macro to mark code as being apply-specific
define_macros
.
append
(
define_template
%
(
"APPLY_SPECIFIC(str)"
,
f
"str##_{name}"
))
define_macros
.
append
(
define_template
%
(
"APPLY_SPECIFIC(str)"
,
f
"str##_{name}"
))
...
...
aesara/printing.py
浏览文件 @
9665120e
...
@@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]
...
@@ -104,7 +104,7 @@ def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]
def
debugprint
(
def
debugprint
(
obj
:
Union
[
graph_like
:
Union
[
Union
[
Variable
,
Apply
,
Function
,
FunctionGraph
],
Union
[
Variable
,
Apply
,
Function
,
FunctionGraph
],
Sequence
[
Union
[
Variable
,
Apply
,
Function
,
FunctionGraph
]],
Sequence
[
Union
[
Variable
,
Apply
,
Function
,
FunctionGraph
]],
],
],
...
@@ -139,7 +139,7 @@ def debugprint(
...
@@ -139,7 +139,7 @@ def debugprint(
Parameters
Parameters
----------
----------
obj
graph_like
The object(s) to be printed.
The object(s) to be printed.
depth
depth
Print graph to this depth (``-1`` for unlimited).
Print graph to this depth (``-1`` for unlimited).
...
@@ -149,7 +149,7 @@ def debugprint(
...
@@ -149,7 +149,7 @@ def debugprint(
When `file` extends `TextIO`, print to it; when `file` is
When `file` extends `TextIO`, print to it; when `file` is
equal to ``"str"``, return a string; when `file` is ``None``, print to
equal to ``"str"``, return a string; when `file` is ``None``, print to
`sys.stdout`.
`sys.stdout`.
id
s
id
_type
Determines the type of identifier used for `Variable`\s:
Determines the type of identifier used for `Variable`\s:
- ``"id"``: print the python id value,
- ``"id"``: print the python id value,
- ``"int"``: print integer character,
- ``"int"``: print integer character,
...
@@ -213,12 +213,12 @@ def debugprint(
...
@@ -213,12 +213,12 @@ def debugprint(
topo_orders
:
List
[
Optional
[
List
[
Apply
]]]
=
[]
topo_orders
:
List
[
Optional
[
List
[
Apply
]]]
=
[]
storage_maps
:
List
[
Optional
[
StorageMapType
]]
=
[]
storage_maps
:
List
[
Optional
[
StorageMapType
]]
=
[]
if
isinstance
(
obj
,
(
list
,
tuple
,
set
)):
if
isinstance
(
graph_like
,
(
list
,
tuple
,
set
)):
lobj
=
obj
graphs
=
graph_like
else
:
else
:
lobj
=
[
obj
]
graphs
=
(
graph_like
,)
for
obj
in
lobj
:
for
obj
in
graphs
:
if
isinstance
(
obj
,
Variable
):
if
isinstance
(
obj
,
Variable
):
outputs_to_print
.
append
(
obj
)
outputs_to_print
.
append
(
obj
)
profile_list
.
append
(
None
)
profile_list
.
append
(
None
)
...
...
aesara/tensor/__init__.py
浏览文件 @
9665120e
"""Symbolic tensor types and constructor functions."""
"""Symbolic tensor types and constructor functions."""
from
functools
import
singledispatch
from
functools
import
singledispatch
from
typing
import
Any
,
Callable
,
NoReturn
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
NoReturn
,
Optional
,
Sequence
,
Union
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.basic
import
Constant
,
Variable
from
aesara.graph.op
import
Op
from
aesara.graph.op
import
Op
if
TYPE_CHECKING
:
from
numpy.typing
import
ArrayLike
,
NDArray
TensorLike
=
Union
[
Variable
,
Sequence
[
Variable
],
"ArrayLike"
]
def
as_tensor_variable
(
def
as_tensor_variable
(
x
:
Any
,
name
:
Optional
[
str
]
=
None
,
ndim
:
Optional
[
int
]
=
None
,
**
kwargs
x
:
TensorLike
,
name
:
Optional
[
str
]
=
None
,
ndim
:
Optional
[
int
]
=
None
,
**
kwargs
)
->
"TensorVariable"
:
)
->
"TensorVariable"
:
"""Convert `x` into an equivalent `TensorVariable`.
"""Convert `x` into an equivalent `TensorVariable`.
...
@@ -44,12 +51,12 @@ def as_tensor_variable(
...
@@ -44,12 +51,12 @@ def as_tensor_variable(
@singledispatch
@singledispatch
def
_as_tensor_variable
(
def
_as_tensor_variable
(
x
,
name
:
Optional
[
str
],
ndim
:
Optional
[
int
],
**
kwargs
x
:
TensorLike
,
name
:
Optional
[
str
],
ndim
:
Optional
[
int
],
**
kwargs
)
->
"TensorVariable"
:
)
->
"TensorVariable"
:
raise
NotImplementedError
(
f
"Cannot convert {x} to a tensor variable."
)
raise
NotImplementedError
(
f
"Cannot convert {x
!r
} to a tensor variable."
)
def
get_vector_length
(
v
:
Any
)
:
def
get_vector_length
(
v
:
TensorLike
)
->
int
:
"""Return the run-time length of a symbolic vector, when possible.
"""Return the run-time length of a symbolic vector, when possible.
Parameters
Parameters
...
@@ -80,13 +87,13 @@ def get_vector_length(v: Any):
...
@@ -80,13 +87,13 @@ def get_vector_length(v: Any):
@singledispatch
@singledispatch
def
_get_vector_length
(
op
:
Union
[
Op
,
Variable
],
var
:
Variable
):
def
_get_vector_length
(
op
:
Union
[
Op
,
Variable
],
var
:
Variable
)
->
int
:
"""`Op`-based dispatch for `get_vector_length`."""
"""`Op`-based dispatch for `get_vector_length`."""
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
raise
ValueError
(
f
"Length of {var} cannot be determined"
)
@_get_vector_length.register
(
Constant
)
@_get_vector_length.register
(
Constant
)
def
_get_vector_length_Constant
(
var_inst
,
var
)
:
def
_get_vector_length_Constant
(
op
:
Union
[
Op
,
Variable
],
var
:
Constant
)
->
int
:
return
len
(
var
.
data
)
return
len
(
var
.
data
)
...
...
aesara/tensor/random/type.py
浏览文件 @
9665120e
...
@@ -28,7 +28,7 @@ class RandomType(Type[T]):
...
@@ -28,7 +28,7 @@ class RandomType(Type[T]):
@staticmethod
@staticmethod
def
may_share_memory
(
a
:
T
,
b
:
T
):
def
may_share_memory
(
a
:
T
,
b
:
T
):
return
a
.
_bit_generator
is
b
.
_bit_generator
return
a
.
_bit_generator
is
b
.
_bit_generator
# type: ignore[attr-defined]
class
RandomStateType
(
RandomType
[
np
.
random
.
RandomState
]):
class
RandomStateType
(
RandomType
[
np
.
random
.
RandomState
]):
...
...
setup.cfg
浏览文件 @
9665120e
...
@@ -83,6 +83,7 @@ warn_unreachable = True
...
@@ -83,6 +83,7 @@ warn_unreachable = True
show_error_codes = True
show_error_codes = True
allow_redefinition = False
allow_redefinition = False
files = aesara,tests
files = aesara,tests
plugins = numpy.typing.mypy_plugin
[mypy-versioneer]
[mypy-versioneer]
check_untyped_defs = False
check_untyped_defs = False
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论