Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4b41e092
提交
4b41e092
authored
11月 19, 2024
作者:
Virgile Andreani
提交者:
Ricardo Vieira
11月 19, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add exceptions for hot loops
上级
54fba943
隐藏空白字符变更
内嵌
并排
正在显示
19 个修改的文件
包含
77 行增加
和
48 行删除
+77
-48
builders.py
pytensor/compile/builders.py
+2
-1
types.py
pytensor/compile/function/types.py
+8
-4
ifelse.py
pytensor/ifelse.py
+4
-2
basic.py
pytensor/link/basic.py
+6
-3
basic.py
pytensor/link/c/basic.py
+6
-5
basic.py
pytensor/link/numba/dispatch/basic.py
+2
-1
shape.py
pytensor/link/pytorch/dispatch/shape.py
+2
-1
utils.py
pytensor/link/utils.py
+4
-2
basic.py
pytensor/scalar/basic.py
+4
-2
loop.py
pytensor/scalar/loop.py
+3
-2
op.py
pytensor/scan/op.py
+2
-1
basic.py
pytensor/tensor/basic.py
+2
-1
blockwise.py
pytensor/tensor/blockwise.py
+6
-4
elemwise.py
pytensor/tensor/elemwise.py
+5
-3
basic.py
pytensor/tensor/random/basic.py
+2
-1
utils.py
pytensor/tensor/random/utils.py
+8
-4
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+1
-1
shape.py
pytensor/tensor/shape.py
+6
-8
type.py
pytensor/tensor/type.py
+4
-2
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
4b41e092
...
@@ -863,5 +863,6 @@ class OpFromGraph(Op, HasInnerGraph):
...
@@ -863,5 +863,6 @@ class OpFromGraph(Op, HasInnerGraph):
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
variables
=
self
.
fn
(
*
inputs
)
variables
=
self
.
fn
(
*
inputs
)
assert
len
(
variables
)
==
len
(
outputs
)
assert
len
(
variables
)
==
len
(
outputs
)
for
output
,
variable
in
zip
(
outputs
,
variables
,
strict
=
True
):
# strict=False because asserted above
for
output
,
variable
in
zip
(
outputs
,
variables
,
strict
=
False
):
output
[
0
]
=
variable
output
[
0
]
=
variable
pytensor/compile/function/types.py
浏览文件 @
4b41e092
...
@@ -1002,8 +1002,9 @@ class Function:
...
@@ -1002,8 +1002,9 @@ class Function:
# if we are allowing garbage collection, remove the
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
# output reference from the internal storage cells
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
# strict=False because we are in a hot loop
for
o_container
,
o_variable
in
zip
(
for
o_container
,
o_variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
Tru
e
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
Fals
e
):
):
if
o_variable
.
owner
is
not
None
:
if
o_variable
.
owner
is
not
None
:
# this node is the variable of computation
# this node is the variable of computation
...
@@ -1012,8 +1013,9 @@ class Function:
...
@@ -1012,8 +1013,9 @@ class Function:
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
# Update the inputs that have an update function
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for
input
,
storage
in
reversed
(
for
input
,
storage
in
reversed
(
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
Tru
e
))
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
Fals
e
))
):
):
if
input
.
update
is
not
None
:
if
input
.
update
is
not
None
:
storage
.
data
=
outputs
.
pop
()
storage
.
data
=
outputs
.
pop
()
...
@@ -1044,7 +1046,8 @@ class Function:
...
@@ -1044,7 +1046,8 @@ class Function:
assert
len
(
self
.
output_keys
)
==
len
(
outputs
)
assert
len
(
self
.
output_keys
)
==
len
(
outputs
)
if
output_subset
is
None
:
if
output_subset
is
None
:
return
dict
(
zip
(
self
.
output_keys
,
outputs
,
strict
=
True
))
# strict=False because we are in a hot loop
return
dict
(
zip
(
self
.
output_keys
,
outputs
,
strict
=
False
))
else
:
else
:
return
{
return
{
self
.
output_keys
[
index
]:
outputs
[
index
]
self
.
output_keys
[
index
]:
outputs
[
index
]
...
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
...
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
ins
=
list
(
f
.
input_storage
)
ins
=
list
(
f
.
input_storage
)
input_storage
=
[]
input_storage
=
[]
# strict=False because we are in a hot loop
for
(
input
,
indices
,
inputs
),
(
required
,
refeed
,
default
)
in
zip
(
for
(
input
,
indices
,
inputs
),
(
required
,
refeed
,
default
)
in
zip
(
f
.
indices
,
f
.
defaults
,
strict
=
Tru
e
f
.
indices
,
f
.
defaults
,
strict
=
Fals
e
):
):
input_storage
.
append
(
ins
[
0
])
input_storage
.
append
(
ins
[
0
])
del
ins
[
0
]
del
ins
[
0
]
...
...
pytensor/ifelse.py
浏览文件 @
4b41e092
...
@@ -305,7 +305,8 @@ class IfElse(_NoPythonOp):
...
@@ -305,7 +305,8 @@ class IfElse(_NoPythonOp):
if
len
(
ls
)
>
0
:
if
len
(
ls
)
>
0
:
return
ls
return
ls
else
:
else
:
for
out
,
t
in
zip
(
outputs
,
input_true_branch
,
strict
=
True
):
# strict=False because we are in a hot loop
for
out
,
t
in
zip
(
outputs
,
input_true_branch
,
strict
=
False
):
compute_map
[
out
][
0
]
=
1
compute_map
[
out
][
0
]
=
1
val
=
storage_map
[
t
][
0
]
val
=
storage_map
[
t
][
0
]
if
self
.
as_view
:
if
self
.
as_view
:
...
@@ -325,7 +326,8 @@ class IfElse(_NoPythonOp):
...
@@ -325,7 +326,8 @@ class IfElse(_NoPythonOp):
if
len
(
ls
)
>
0
:
if
len
(
ls
)
>
0
:
return
ls
return
ls
else
:
else
:
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
,
strict
=
True
):
# strict=False because we are in a hot loop
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
,
strict
=
False
):
compute_map
[
out
][
0
]
=
1
compute_map
[
out
][
0
]
=
1
# can't view both outputs unless destroyhandler
# can't view both outputs unless destroyhandler
# improves
# improves
...
...
pytensor/link/basic.py
浏览文件 @
4b41e092
...
@@ -539,12 +539,14 @@ class WrapLinker(Linker):
...
@@ -539,12 +539,14 @@ class WrapLinker(Linker):
def
f
():
def
f
():
for
inputs
in
input_lists
[
1
:]:
for
inputs
in
input_lists
[
1
:]:
for
input1
,
input2
in
zip
(
inputs0
,
inputs
,
strict
=
True
):
# strict=False because we are in a hot loop
for
input1
,
input2
in
zip
(
inputs0
,
inputs
,
strict
=
False
):
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
x
in
to_reset
:
for
x
in
to_reset
:
x
[
0
]
=
None
x
[
0
]
=
None
pre
(
self
,
[
input
.
data
for
input
in
input_lists
[
0
]],
order
,
thunk_groups
)
pre
(
self
,
[
input
.
data
for
input
in
input_lists
[
0
]],
order
,
thunk_groups
)
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
,
strict
=
True
)):
# strict=False because we are in a hot loop
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
,
strict
=
False
)):
try
:
try
:
wrapper
(
self
.
fgraph
,
i
,
node
,
*
thunks
)
wrapper
(
self
.
fgraph
,
i
,
node
,
*
thunks
)
except
Exception
:
except
Exception
:
...
@@ -666,8 +668,9 @@ class JITLinker(PerformLinker):
...
@@ -666,8 +668,9 @@ class JITLinker(PerformLinker):
):
):
outputs
=
fgraph_jit
(
*
[
self
.
input_filter
(
x
[
0
])
for
x
in
thunk_inputs
])
outputs
=
fgraph_jit
(
*
[
self
.
input_filter
(
x
[
0
])
for
x
in
thunk_inputs
])
# strict=False because we are in a hot loop
for
o_var
,
o_storage
,
o_val
in
zip
(
for
o_var
,
o_storage
,
o_val
in
zip
(
fgraph
.
outputs
,
thunk_outputs
,
outputs
,
strict
=
Tru
e
fgraph
.
outputs
,
thunk_outputs
,
outputs
,
strict
=
Fals
e
):
):
compute_map
[
o_var
][
0
]
=
True
compute_map
[
o_var
][
0
]
=
True
o_storage
[
0
]
=
self
.
output_filter
(
o_var
,
o_val
)
o_storage
[
0
]
=
self
.
output_filter
(
o_var
,
o_val
)
...
...
pytensor/link/c/basic.py
浏览文件 @
4b41e092
...
@@ -1993,25 +1993,26 @@ class DualLinker(Linker):
...
@@ -1993,25 +1993,26 @@ class DualLinker(Linker):
)
)
def
f
():
def
f
():
for
input1
,
input2
in
zip
(
i1
,
i2
,
strict
=
True
):
# strict=False because we are in a hot loop
for
input1
,
input2
in
zip
(
i1
,
i2
,
strict
=
False
):
# Set the inputs to be the same in both branches.
# Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to
# The copy is necessary in order for inplace ops not to
# interfere.
# interfere.
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
thunk1
,
thunk2
,
node1
,
node2
in
zip
(
for
thunk1
,
thunk2
,
node1
,
node2
in
zip
(
thunks1
,
thunks2
,
order1
,
order2
,
strict
=
Tru
e
thunks1
,
thunks2
,
order1
,
order2
,
strict
=
Fals
e
):
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
,
strict
=
Tru
e
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
,
strict
=
Fals
e
):
if
output
in
no_recycling
:
if
output
in
no_recycling
:
storage
[
0
]
=
None
storage
[
0
]
=
None
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
,
strict
=
Tru
e
):
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
,
strict
=
Fals
e
):
if
output
in
no_recycling
:
if
output
in
no_recycling
:
storage
[
0
]
=
None
storage
[
0
]
=
None
try
:
try
:
thunk1
()
thunk1
()
thunk2
()
thunk2
()
for
output1
,
output2
in
zip
(
for
output1
,
output2
in
zip
(
thunk1
.
outputs
,
thunk2
.
outputs
,
strict
=
Tru
e
thunk1
.
outputs
,
thunk2
.
outputs
,
strict
=
Fals
e
):
):
self
.
checker
(
output1
,
output2
)
self
.
checker
(
output1
,
output2
)
except
Exception
:
except
Exception
:
...
...
pytensor/link/numba/dispatch/basic.py
浏览文件 @
4b41e092
...
@@ -401,9 +401,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
...
@@ -401,9 +401,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else
:
else
:
def
py_perform_return
(
inputs
):
def
py_perform_return
(
inputs
):
# strict=False because we are in a hot loop
return
tuple
(
return
tuple
(
out_type
.
filter
(
out
[
0
])
out_type
.
filter
(
out
[
0
])
for
out_type
,
out
in
zip
(
output_types
,
py_perform
(
inputs
),
strict
=
Tru
e
)
for
out_type
,
out
in
zip
(
output_types
,
py_perform
(
inputs
),
strict
=
Fals
e
)
)
)
@numba_njit
@numba_njit
...
...
pytensor/link/pytorch/dispatch/shape.py
浏览文件 @
4b41e092
...
@@ -34,7 +34,8 @@ def pytorch_funcify_Shape_i(op, **kwargs):
...
@@ -34,7 +34,8 @@ def pytorch_funcify_Shape_i(op, **kwargs):
def
pytorch_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
def
pytorch_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
def
specifyshape
(
x
,
*
shape
):
def
specifyshape
(
x
,
*
shape
):
assert
x
.
ndim
==
len
(
shape
)
assert
x
.
ndim
==
len
(
shape
)
for
actual
,
expected
in
zip
(
x
.
shape
,
shape
,
strict
=
True
):
# strict=False because asserted above
for
actual
,
expected
in
zip
(
x
.
shape
,
shape
,
strict
=
False
):
if
expected
is
None
:
if
expected
is
None
:
continue
continue
if
actual
!=
expected
:
if
actual
!=
expected
:
...
...
pytensor/link/utils.py
浏览文件 @
4b41e092
...
@@ -190,8 +190,9 @@ def streamline(
...
@@ -190,8 +190,9 @@ def streamline(
for
x
in
no_recycling
:
for
x
in
no_recycling
:
x
[
0
]
=
None
x
[
0
]
=
None
try
:
try
:
# strict=False because we are in a hot loop
for
thunk
,
node
,
old_storage
in
zip
(
for
thunk
,
node
,
old_storage
in
zip
(
thunks
,
order
,
post_thunk_old_storage
,
strict
=
Tru
e
thunks
,
order
,
post_thunk_old_storage
,
strict
=
Fals
e
):
):
thunk
()
thunk
()
for
old_s
in
old_storage
:
for
old_s
in
old_storage
:
...
@@ -206,7 +207,8 @@ def streamline(
...
@@ -206,7 +207,8 @@ def streamline(
for
x
in
no_recycling
:
for
x
in
no_recycling
:
x
[
0
]
=
None
x
[
0
]
=
None
try
:
try
:
for
thunk
,
node
in
zip
(
thunks
,
order
,
strict
=
True
):
# strict=False because we are in a hot loop
for
thunk
,
node
in
zip
(
thunks
,
order
,
strict
=
False
):
thunk
()
thunk
()
except
Exception
:
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
raise_with_op
(
fgraph
,
node
,
thunk
)
...
...
pytensor/scalar/basic.py
浏览文件 @
4b41e092
...
@@ -1150,8 +1150,9 @@ class ScalarOp(COp):
...
@@ -1150,8 +1150,9 @@ class ScalarOp(COp):
else
:
else
:
variables
=
from_return_values
(
self
.
impl
(
*
inputs
))
variables
=
from_return_values
(
self
.
impl
(
*
inputs
))
assert
len
(
variables
)
==
len
(
output_storage
)
assert
len
(
variables
)
==
len
(
output_storage
)
# strict=False because we are in a hot loop
for
out
,
storage
,
variable
in
zip
(
for
out
,
storage
,
variable
in
zip
(
node
.
outputs
,
output_storage
,
variables
,
strict
=
Tru
e
node
.
outputs
,
output_storage
,
variables
,
strict
=
Fals
e
):
):
dtype
=
out
.
dtype
dtype
=
out
.
dtype
storage
[
0
]
=
self
.
_cast_scalar
(
variable
,
dtype
)
storage
[
0
]
=
self
.
_cast_scalar
(
variable
,
dtype
)
...
@@ -4328,7 +4329,8 @@ class Composite(ScalarInnerGraphOp):
...
@@ -4328,7 +4329,8 @@ class Composite(ScalarInnerGraphOp):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
outputs
=
self
.
py_perform_fn
(
*
inputs
)
outputs
=
self
.
py_perform_fn
(
*
inputs
)
for
storage
,
out_val
in
zip
(
output_storage
,
outputs
,
strict
=
True
):
# strict=False because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
outputs
,
strict
=
False
):
storage
[
0
]
=
out_val
storage
[
0
]
=
out_val
def
grad
(
self
,
inputs
,
output_grads
):
def
grad
(
self
,
inputs
,
output_grads
):
...
...
pytensor/scalar/loop.py
浏览文件 @
4b41e092
...
@@ -93,7 +93,7 @@ class ScalarLoop(ScalarInnerGraphOp):
...
@@ -93,7 +93,7 @@ class ScalarLoop(ScalarInnerGraphOp):
)
)
else
:
else
:
update
=
outputs
update
=
outputs
for
i
,
u
in
zip
(
init
[:
len
(
update
)],
update
,
strict
=
Tru
e
):
for
i
,
u
in
zip
(
init
,
update
,
strict
=
Fals
e
):
if
i
.
type
!=
u
.
type
:
if
i
.
type
!=
u
.
type
:
raise
TypeError
(
raise
TypeError
(
"Init and update types must be the same: "
"Init and update types must be the same: "
...
@@ -207,7 +207,8 @@ class ScalarLoop(ScalarInnerGraphOp):
...
@@ -207,7 +207,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for
i
in
range
(
n_steps
):
for
i
in
range
(
n_steps
):
carry
=
inner_fn
(
*
carry
,
*
constant
)
carry
=
inner_fn
(
*
carry
,
*
constant
)
for
storage
,
out_val
in
zip
(
output_storage
,
carry
,
strict
=
True
):
# strict=False because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
carry
,
strict
=
False
):
storage
[
0
]
=
out_val
storage
[
0
]
=
out_val
@property
@property
...
...
pytensor/scan/op.py
浏览文件 @
4b41e092
...
@@ -1278,8 +1278,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1278,8 +1278,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
len
(
self
.
inner_outputs
)
!=
len
(
other
.
inner_outputs
):
if
len
(
self
.
inner_outputs
)
!=
len
(
other
.
inner_outputs
):
return
False
return
False
# strict=False because length already compared above
for
self_in
,
other_in
in
zip
(
for
self_in
,
other_in
in
zip
(
self
.
inner_inputs
,
other
.
inner_inputs
,
strict
=
Tru
e
self
.
inner_inputs
,
other
.
inner_inputs
,
strict
=
Fals
e
):
):
if
self_in
.
type
!=
other_in
.
type
:
if
self_in
.
type
!=
other_in
.
type
:
return
False
return
False
...
...
pytensor/tensor/basic.py
浏览文件 @
4b41e092
...
@@ -3463,7 +3463,8 @@ class PermuteRowElements(Op):
...
@@ -3463,7 +3463,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough
# Make sure the output is big enough
out_s
=
[]
out_s
=
[]
for
xdim
,
ydim
in
zip
(
x_s
,
y_s
,
strict
=
True
):
# strict=False because we are in a hot loop
for
xdim
,
ydim
in
zip
(
x_s
,
y_s
,
strict
=
False
):
if
xdim
==
ydim
:
if
xdim
==
ydim
:
outdim
=
xdim
outdim
=
xdim
elif
xdim
==
1
:
elif
xdim
==
1
:
...
...
pytensor/tensor/blockwise.py
浏览文件 @
4b41e092
...
@@ -342,16 +342,17 @@ class Blockwise(Op):
...
@@ -342,16 +342,17 @@ class Blockwise(Op):
def
_check_runtime_broadcast
(
self
,
node
,
inputs
):
def
_check_runtime_broadcast
(
self
,
node
,
inputs
):
batch_ndim
=
self
.
batch_ndim
(
node
)
batch_ndim
=
self
.
batch_ndim
(
node
)
# strict=False because we are in a hot loop
for
dims_and_bcast
in
zip
(
for
dims_and_bcast
in
zip
(
*
[
*
[
zip
(
zip
(
input
.
shape
[:
batch_ndim
],
input
.
shape
[:
batch_ndim
],
sinput
.
type
.
broadcastable
[:
batch_ndim
],
sinput
.
type
.
broadcastable
[:
batch_ndim
],
strict
=
Tru
e
,
strict
=
Fals
e
,
)
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Tru
e
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Fals
e
)
],
],
strict
=
Tru
e
,
strict
=
Fals
e
,
):
):
if
any
(
d
!=
1
for
d
,
_
in
dims_and_bcast
)
and
(
1
,
False
)
in
dims_and_bcast
:
if
any
(
d
!=
1
for
d
,
_
in
dims_and_bcast
)
and
(
1
,
False
)
in
dims_and_bcast
:
raise
ValueError
(
raise
ValueError
(
...
@@ -374,8 +375,9 @@ class Blockwise(Op):
...
@@ -374,8 +375,9 @@ class Blockwise(Op):
if
not
isinstance
(
res
,
tuple
):
if
not
isinstance
(
res
,
tuple
):
res
=
(
res
,)
res
=
(
res
,)
# strict=False because we are in a hot loop
for
node_out
,
out_storage
,
r
in
zip
(
for
node_out
,
out_storage
,
r
in
zip
(
node
.
outputs
,
output_storage
,
res
,
strict
=
Tru
e
node
.
outputs
,
output_storage
,
res
,
strict
=
Fals
e
):
):
out_dtype
=
getattr
(
node_out
,
"dtype"
,
None
)
out_dtype
=
getattr
(
node_out
,
"dtype"
,
None
)
if
out_dtype
and
out_dtype
!=
r
.
dtype
:
if
out_dtype
and
out_dtype
!=
r
.
dtype
:
...
...
pytensor/tensor/elemwise.py
浏览文件 @
4b41e092
...
@@ -737,8 +737,9 @@ class Elemwise(OpenMPOp):
...
@@ -737,8 +737,9 @@ class Elemwise(OpenMPOp):
if
nout
==
1
:
if
nout
==
1
:
variables
=
[
variables
]
variables
=
[
variables
]
# strict=False because we are in a hot loop
for
i
,
(
variable
,
storage
,
nout
)
in
enumerate
(
for
i
,
(
variable
,
storage
,
nout
)
in
enumerate
(
zip
(
variables
,
output_storage
,
node
.
outputs
,
strict
=
Tru
e
)
zip
(
variables
,
output_storage
,
node
.
outputs
,
strict
=
Fals
e
)
):
):
storage
[
0
]
=
variable
=
np
.
asarray
(
variable
,
dtype
=
nout
.
dtype
)
storage
[
0
]
=
variable
=
np
.
asarray
(
variable
,
dtype
=
nout
.
dtype
)
...
@@ -753,12 +754,13 @@ class Elemwise(OpenMPOp):
...
@@ -753,12 +754,13 @@ class Elemwise(OpenMPOp):
@staticmethod
@staticmethod
def
_check_runtime_broadcast
(
node
,
inputs
):
def
_check_runtime_broadcast
(
node
,
inputs
):
# strict=False because we are in a hot loop
for
dims_and_bcast
in
zip
(
for
dims_and_bcast
in
zip
(
*
[
*
[
zip
(
input
.
shape
,
sinput
.
type
.
broadcastable
,
strict
=
False
)
zip
(
input
.
shape
,
sinput
.
type
.
broadcastable
,
strict
=
False
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Tru
e
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Fals
e
)
],
],
strict
=
Tru
e
,
strict
=
Fals
e
,
):
):
if
any
(
d
!=
1
for
d
,
_
in
dims_and_bcast
)
and
(
1
,
False
)
in
dims_and_bcast
:
if
any
(
d
!=
1
for
d
,
_
in
dims_and_bcast
)
and
(
1
,
False
)
in
dims_and_bcast
:
raise
ValueError
(
raise
ValueError
(
...
...
pytensor/tensor/random/basic.py
浏览文件 @
4b41e092
...
@@ -1862,7 +1862,8 @@ class CategoricalRV(RandomVariable):
...
@@ -1862,7 +1862,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below.
# to `p.shape[:-1]` in the call to `vsearchsorted` below.
if
len
(
size
)
<
(
p
.
ndim
-
1
):
if
len
(
size
)
<
(
p
.
ndim
-
1
):
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
for
s
,
ps
in
zip
(
reversed
(
size
),
reversed
(
p
.
shape
[:
-
1
]),
strict
=
True
):
# strict=False because we are in a hot loop
for
s
,
ps
in
zip
(
reversed
(
size
),
reversed
(
p
.
shape
[:
-
1
]),
strict
=
False
):
if
s
==
1
and
ps
!=
1
:
if
s
==
1
and
ps
!=
1
:
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
...
...
pytensor/tensor/random/utils.py
浏览文件 @
4b41e092
...
@@ -44,7 +44,8 @@ def params_broadcast_shapes(
...
@@ -44,7 +44,8 @@ def params_broadcast_shapes(
max_fn
=
maximum
if
use_pytensor
else
max
max_fn
=
maximum
if
use_pytensor
else
max
rev_extra_dims
:
list
[
int
]
=
[]
rev_extra_dims
:
list
[
int
]
=
[]
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
True
):
# strict=False because we are in a hot loop
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
False
):
# We need this in order to use `len`
# We need this in order to use `len`
param_shape
=
tuple
(
param_shape
)
param_shape
=
tuple
(
param_shape
)
extras
=
tuple
(
param_shape
[:
(
len
(
param_shape
)
-
ndim_param
)])
extras
=
tuple
(
param_shape
[:
(
len
(
param_shape
)
-
ndim_param
)])
...
@@ -63,11 +64,12 @@ def params_broadcast_shapes(
...
@@ -63,11 +64,12 @@ def params_broadcast_shapes(
extra_dims
=
tuple
(
reversed
(
rev_extra_dims
))
extra_dims
=
tuple
(
reversed
(
rev_extra_dims
))
# strict=False because we are in a hot loop
bcast_shapes
=
[
bcast_shapes
=
[
(
extra_dims
+
tuple
(
param_shape
)[
-
ndim_param
:])
(
extra_dims
+
tuple
(
param_shape
)[
-
ndim_param
:])
if
ndim_param
>
0
if
ndim_param
>
0
else
extra_dims
else
extra_dims
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
Tru
e
)
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
Fals
e
)
]
]
return
bcast_shapes
return
bcast_shapes
...
@@ -110,10 +112,11 @@ def broadcast_params(
...
@@ -110,10 +112,11 @@ def broadcast_params(
use_pytensor
=
False
use_pytensor
=
False
param_shapes
=
[]
param_shapes
=
[]
for
p
in
params
:
for
p
in
params
:
# strict=False because we are in a hot loop
param_shape
=
tuple
(
param_shape
=
tuple
(
1
if
bcast
else
s
1
if
bcast
else
s
for
s
,
bcast
in
zip
(
for
s
,
bcast
in
zip
(
p
.
shape
,
getattr
(
p
,
"broadcastable"
,
(
False
,)
*
p
.
ndim
),
strict
=
Tru
e
p
.
shape
,
getattr
(
p
,
"broadcastable"
,
(
False
,)
*
p
.
ndim
),
strict
=
Fals
e
)
)
)
)
use_pytensor
|=
isinstance
(
p
,
Variable
)
use_pytensor
|=
isinstance
(
p
,
Variable
)
...
@@ -124,9 +127,10 @@ def broadcast_params(
...
@@ -124,9 +127,10 @@ def broadcast_params(
)
)
broadcast_to_fn
=
broadcast_to
if
use_pytensor
else
np
.
broadcast_to
broadcast_to_fn
=
broadcast_to
if
use_pytensor
else
np
.
broadcast_to
# strict=False because we are in a hot loop
bcast_params
=
[
bcast_params
=
[
broadcast_to_fn
(
param
,
shape
)
broadcast_to_fn
(
param
,
shape
)
for
shape
,
param
in
zip
(
shapes
,
params
,
strict
=
Tru
e
)
for
shape
,
param
in
zip
(
shapes
,
params
,
strict
=
Fals
e
)
]
]
return
bcast_params
return
bcast_params
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
4b41e092
...
@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
...
@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
# Slices to take from val
# Slices to take from val
val_slices
=
[]
val_slices
=
[]
for
i
,
(
sl
,
dim
)
in
enumerate
(
zip
(
slices
,
dims
[:
len
(
slices
)],
strict
=
Tru
e
)):
for
i
,
(
sl
,
dim
)
in
enumerate
(
zip
(
slices
,
dims
,
strict
=
Fals
e
)):
# If val was not copied over that dim,
# If val was not copied over that dim,
# we need to take the appropriate subtensor on it.
# we need to take the appropriate subtensor on it.
if
i
>=
n_added_dims
:
if
i
>=
n_added_dims
:
...
...
pytensor/tensor/shape.py
浏览文件 @
4b41e092
...
@@ -448,8 +448,9 @@ class SpecifyShape(COp):
...
@@ -448,8 +448,9 @@ class SpecifyShape(COp):
raise
AssertionError
(
raise
AssertionError
(
f
"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
f
"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
)
# strict=False because we are in a hot loop
if
not
all
(
if
not
all
(
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
,
strict
=
Tru
e
)
if
s
is
not
None
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
,
strict
=
Fals
e
)
if
s
is
not
None
):
):
raise
AssertionError
(
raise
AssertionError
(
f
"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
f
"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
...
@@ -578,15 +579,12 @@ def specify_shape(
...
@@ -578,15 +579,12 @@ def specify_shape(
x
=
ptb
.
as_tensor_variable
(
x
)
# type: ignore[arg-type,unused-ignore]
x
=
ptb
.
as_tensor_variable
(
x
)
# type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info
=
any
(
s
!=
xts
for
(
s
,
xts
)
in
zip
(
shape
,
x
.
type
.
shape
,
strict
=
False
)
if
s
is
not
None
)
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if
len
(
shape
)
!=
x
.
type
.
ndim
:
if
not
new_shape_info
and
len
(
shape
)
==
x
.
type
.
ndim
:
return
_specify_shape
(
x
,
*
shape
)
new_shape_matches
=
all
(
s
==
xts
for
(
s
,
xts
)
in
zip
(
shape
,
x
.
type
.
shape
,
strict
=
True
)
if
s
is
not
None
)
if
new_shape_matches
:
return
x
return
x
return
_specify_shape
(
x
,
*
shape
)
return
_specify_shape
(
x
,
*
shape
)
...
...
pytensor/tensor/type.py
浏览文件 @
4b41e092
...
@@ -248,9 +248,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -248,9 +248,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that."
,
" PyTensor C code does not support that."
,
)
)
# strict=False because we are in a hot loop
if
not
all
(
if
not
all
(
ds
==
ts
if
ts
is
not
None
else
True
ds
==
ts
if
ts
is
not
None
else
True
for
ds
,
ts
in
zip
(
data
.
shape
,
self
.
shape
,
strict
=
Tru
e
)
for
ds
,
ts
in
zip
(
data
.
shape
,
self
.
shape
,
strict
=
Fals
e
)
):
):
raise
TypeError
(
raise
TypeError
(
f
"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
f
"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
...
@@ -319,6 +320,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -319,6 +320,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return
False
return
False
def
is_super
(
self
,
otype
):
def
is_super
(
self
,
otype
):
# strict=False because we are in a hot loop
if
(
if
(
isinstance
(
otype
,
type
(
self
))
isinstance
(
otype
,
type
(
self
))
and
otype
.
dtype
==
self
.
dtype
and
otype
.
dtype
==
self
.
dtype
...
@@ -327,7 +329,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
...
@@ -327,7 +329,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
# but not less
# but not less
and
all
(
and
all
(
sb
==
ob
or
sb
is
None
sb
==
ob
or
sb
is
None
for
sb
,
ob
in
zip
(
self
.
shape
,
otype
.
shape
,
strict
=
Tru
e
)
for
sb
,
ob
in
zip
(
self
.
shape
,
otype
.
shape
,
strict
=
Fals
e
)
)
)
):
):
return
True
return
True
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论