Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
4b41e092
提交
4b41e092
authored
11月 19, 2024
作者:
Virgile Andreani
提交者:
Ricardo Vieira
11月 19, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add exceptions for hot loops
上级
54fba943
显示空白字符变更
内嵌
并排
正在显示
19 个修改的文件
包含
77 行增加
和
48 行删除
+77
-48
builders.py
pytensor/compile/builders.py
+2
-1
types.py
pytensor/compile/function/types.py
+8
-4
ifelse.py
pytensor/ifelse.py
+4
-2
basic.py
pytensor/link/basic.py
+6
-3
basic.py
pytensor/link/c/basic.py
+6
-5
basic.py
pytensor/link/numba/dispatch/basic.py
+2
-1
shape.py
pytensor/link/pytorch/dispatch/shape.py
+2
-1
utils.py
pytensor/link/utils.py
+4
-2
basic.py
pytensor/scalar/basic.py
+4
-2
loop.py
pytensor/scalar/loop.py
+3
-2
op.py
pytensor/scan/op.py
+2
-1
basic.py
pytensor/tensor/basic.py
+2
-1
blockwise.py
pytensor/tensor/blockwise.py
+6
-4
elemwise.py
pytensor/tensor/elemwise.py
+5
-3
basic.py
pytensor/tensor/random/basic.py
+2
-1
utils.py
pytensor/tensor/random/utils.py
+8
-4
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+1
-1
shape.py
pytensor/tensor/shape.py
+6
-8
type.py
pytensor/tensor/type.py
+4
-2
没有找到文件。
pytensor/compile/builders.py
浏览文件 @
4b41e092
...
...
@@ -863,5 +863,6 @@ class OpFromGraph(Op, HasInnerGraph):
def
perform
(
self
,
node
,
inputs
,
outputs
):
variables
=
self
.
fn
(
*
inputs
)
assert
len
(
variables
)
==
len
(
outputs
)
for
output
,
variable
in
zip
(
outputs
,
variables
,
strict
=
True
):
# strict=False because asserted above
for
output
,
variable
in
zip
(
outputs
,
variables
,
strict
=
False
):
output
[
0
]
=
variable
pytensor/compile/function/types.py
浏览文件 @
4b41e092
...
...
@@ -1002,8 +1002,9 @@ class Function:
# if we are allowing garbage collection, remove the
# output reference from the internal storage cells
if
getattr
(
self
.
vm
,
"allow_gc"
,
False
):
# strict=False because we are in a hot loop
for
o_container
,
o_variable
in
zip
(
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
Tru
e
self
.
output_storage
,
self
.
maker
.
fgraph
.
outputs
,
strict
=
Fals
e
):
if
o_variable
.
owner
is
not
None
:
# this node is the variable of computation
...
...
@@ -1012,8 +1013,9 @@ class Function:
if
getattr
(
self
.
vm
,
"need_update_inputs"
,
True
):
# Update the inputs that have an update function
# strict=False because we are in a hot loop
for
input
,
storage
in
reversed
(
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
Tru
e
))
list
(
zip
(
self
.
maker
.
expanded_inputs
,
input_storage
,
strict
=
Fals
e
))
):
if
input
.
update
is
not
None
:
storage
.
data
=
outputs
.
pop
()
...
...
@@ -1044,7 +1046,8 @@ class Function:
assert
len
(
self
.
output_keys
)
==
len
(
outputs
)
if
output_subset
is
None
:
return
dict
(
zip
(
self
.
output_keys
,
outputs
,
strict
=
True
))
# strict=False because we are in a hot loop
return
dict
(
zip
(
self
.
output_keys
,
outputs
,
strict
=
False
))
else
:
return
{
self
.
output_keys
[
index
]:
outputs
[
index
]
...
...
@@ -1111,8 +1114,9 @@ def _pickle_Function(f):
ins
=
list
(
f
.
input_storage
)
input_storage
=
[]
# strict=False because we are in a hot loop
for
(
input
,
indices
,
inputs
),
(
required
,
refeed
,
default
)
in
zip
(
f
.
indices
,
f
.
defaults
,
strict
=
Tru
e
f
.
indices
,
f
.
defaults
,
strict
=
Fals
e
):
input_storage
.
append
(
ins
[
0
])
del
ins
[
0
]
...
...
pytensor/ifelse.py
浏览文件 @
4b41e092
...
...
@@ -305,7 +305,8 @@ class IfElse(_NoPythonOp):
if
len
(
ls
)
>
0
:
return
ls
else
:
for
out
,
t
in
zip
(
outputs
,
input_true_branch
,
strict
=
True
):
# strict=False because we are in a hot loop
for
out
,
t
in
zip
(
outputs
,
input_true_branch
,
strict
=
False
):
compute_map
[
out
][
0
]
=
1
val
=
storage_map
[
t
][
0
]
if
self
.
as_view
:
...
...
@@ -325,7 +326,8 @@ class IfElse(_NoPythonOp):
if
len
(
ls
)
>
0
:
return
ls
else
:
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
,
strict
=
True
):
# strict=False because we are in a hot loop
for
out
,
f
in
zip
(
outputs
,
inputs_false_branch
,
strict
=
False
):
compute_map
[
out
][
0
]
=
1
# can't view both outputs unless destroyhandler
# improves
...
...
pytensor/link/basic.py
浏览文件 @
4b41e092
...
...
@@ -539,12 +539,14 @@ class WrapLinker(Linker):
def
f
():
for
inputs
in
input_lists
[
1
:]:
for
input1
,
input2
in
zip
(
inputs0
,
inputs
,
strict
=
True
):
# strict=False because we are in a hot loop
for
input1
,
input2
in
zip
(
inputs0
,
inputs
,
strict
=
False
):
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
x
in
to_reset
:
x
[
0
]
=
None
pre
(
self
,
[
input
.
data
for
input
in
input_lists
[
0
]],
order
,
thunk_groups
)
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
,
strict
=
True
)):
# strict=False because we are in a hot loop
for
i
,
(
thunks
,
node
)
in
enumerate
(
zip
(
thunk_groups
,
order
,
strict
=
False
)):
try
:
wrapper
(
self
.
fgraph
,
i
,
node
,
*
thunks
)
except
Exception
:
...
...
@@ -666,8 +668,9 @@ class JITLinker(PerformLinker):
):
outputs
=
fgraph_jit
(
*
[
self
.
input_filter
(
x
[
0
])
for
x
in
thunk_inputs
])
# strict=False because we are in a hot loop
for
o_var
,
o_storage
,
o_val
in
zip
(
fgraph
.
outputs
,
thunk_outputs
,
outputs
,
strict
=
Tru
e
fgraph
.
outputs
,
thunk_outputs
,
outputs
,
strict
=
Fals
e
):
compute_map
[
o_var
][
0
]
=
True
o_storage
[
0
]
=
self
.
output_filter
(
o_var
,
o_val
)
...
...
pytensor/link/c/basic.py
浏览文件 @
4b41e092
...
...
@@ -1993,25 +1993,26 @@ class DualLinker(Linker):
)
def
f
():
for
input1
,
input2
in
zip
(
i1
,
i2
,
strict
=
True
):
# strict=False because we are in a hot loop
for
input1
,
input2
in
zip
(
i1
,
i2
,
strict
=
False
):
# Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to
# interfere.
input2
.
storage
[
0
]
=
copy
(
input1
.
storage
[
0
])
for
thunk1
,
thunk2
,
node1
,
node2
in
zip
(
thunks1
,
thunks2
,
order1
,
order2
,
strict
=
Tru
e
thunks1
,
thunks2
,
order1
,
order2
,
strict
=
Fals
e
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
,
strict
=
Tru
e
):
for
output
,
storage
in
zip
(
node1
.
outputs
,
thunk1
.
outputs
,
strict
=
Fals
e
):
if
output
in
no_recycling
:
storage
[
0
]
=
None
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
,
strict
=
Tru
e
):
for
output
,
storage
in
zip
(
node2
.
outputs
,
thunk2
.
outputs
,
strict
=
Fals
e
):
if
output
in
no_recycling
:
storage
[
0
]
=
None
try
:
thunk1
()
thunk2
()
for
output1
,
output2
in
zip
(
thunk1
.
outputs
,
thunk2
.
outputs
,
strict
=
Tru
e
thunk1
.
outputs
,
thunk2
.
outputs
,
strict
=
Fals
e
):
self
.
checker
(
output1
,
output2
)
except
Exception
:
...
...
pytensor/link/numba/dispatch/basic.py
浏览文件 @
4b41e092
...
...
@@ -401,9 +401,10 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
else
:
def
py_perform_return
(
inputs
):
# strict=False because we are in a hot loop
return
tuple
(
out_type
.
filter
(
out
[
0
])
for
out_type
,
out
in
zip
(
output_types
,
py_perform
(
inputs
),
strict
=
Tru
e
)
for
out_type
,
out
in
zip
(
output_types
,
py_perform
(
inputs
),
strict
=
Fals
e
)
)
@numba_njit
...
...
pytensor/link/pytorch/dispatch/shape.py
浏览文件 @
4b41e092
...
...
@@ -34,7 +34,8 @@ def pytorch_funcify_Shape_i(op, **kwargs):
def
pytorch_funcify_SpecifyShape
(
op
,
node
,
**
kwargs
):
def
specifyshape
(
x
,
*
shape
):
assert
x
.
ndim
==
len
(
shape
)
for
actual
,
expected
in
zip
(
x
.
shape
,
shape
,
strict
=
True
):
# strict=False because asserted above
for
actual
,
expected
in
zip
(
x
.
shape
,
shape
,
strict
=
False
):
if
expected
is
None
:
continue
if
actual
!=
expected
:
...
...
pytensor/link/utils.py
浏览文件 @
4b41e092
...
...
@@ -190,8 +190,9 @@ def streamline(
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
# strict=False because we are in a hot loop
for
thunk
,
node
,
old_storage
in
zip
(
thunks
,
order
,
post_thunk_old_storage
,
strict
=
Tru
e
thunks
,
order
,
post_thunk_old_storage
,
strict
=
Fals
e
):
thunk
()
for
old_s
in
old_storage
:
...
...
@@ -206,7 +207,8 @@ def streamline(
for
x
in
no_recycling
:
x
[
0
]
=
None
try
:
for
thunk
,
node
in
zip
(
thunks
,
order
,
strict
=
True
):
# strict=False because we are in a hot loop
for
thunk
,
node
in
zip
(
thunks
,
order
,
strict
=
False
):
thunk
()
except
Exception
:
raise_with_op
(
fgraph
,
node
,
thunk
)
...
...
pytensor/scalar/basic.py
浏览文件 @
4b41e092
...
...
@@ -1150,8 +1150,9 @@ class ScalarOp(COp):
else
:
variables
=
from_return_values
(
self
.
impl
(
*
inputs
))
assert
len
(
variables
)
==
len
(
output_storage
)
# strict=False because we are in a hot loop
for
out
,
storage
,
variable
in
zip
(
node
.
outputs
,
output_storage
,
variables
,
strict
=
Tru
e
node
.
outputs
,
output_storage
,
variables
,
strict
=
Fals
e
):
dtype
=
out
.
dtype
storage
[
0
]
=
self
.
_cast_scalar
(
variable
,
dtype
)
...
...
@@ -4328,7 +4329,8 @@ class Composite(ScalarInnerGraphOp):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
outputs
=
self
.
py_perform_fn
(
*
inputs
)
for
storage
,
out_val
in
zip
(
output_storage
,
outputs
,
strict
=
True
):
# strict=False because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
outputs
,
strict
=
False
):
storage
[
0
]
=
out_val
def
grad
(
self
,
inputs
,
output_grads
):
...
...
pytensor/scalar/loop.py
浏览文件 @
4b41e092
...
...
@@ -93,7 +93,7 @@ class ScalarLoop(ScalarInnerGraphOp):
)
else
:
update
=
outputs
for
i
,
u
in
zip
(
init
[:
len
(
update
)],
update
,
strict
=
Tru
e
):
for
i
,
u
in
zip
(
init
,
update
,
strict
=
Fals
e
):
if
i
.
type
!=
u
.
type
:
raise
TypeError
(
"Init and update types must be the same: "
...
...
@@ -207,7 +207,8 @@ class ScalarLoop(ScalarInnerGraphOp):
for
i
in
range
(
n_steps
):
carry
=
inner_fn
(
*
carry
,
*
constant
)
for
storage
,
out_val
in
zip
(
output_storage
,
carry
,
strict
=
True
):
# strict=False because we are in a hot loop
for
storage
,
out_val
in
zip
(
output_storage
,
carry
,
strict
=
False
):
storage
[
0
]
=
out_val
@property
...
...
pytensor/scan/op.py
浏览文件 @
4b41e092
...
...
@@ -1278,8 +1278,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
len
(
self
.
inner_outputs
)
!=
len
(
other
.
inner_outputs
):
return
False
# strict=False because length already compared above
for
self_in
,
other_in
in
zip
(
self
.
inner_inputs
,
other
.
inner_inputs
,
strict
=
Tru
e
self
.
inner_inputs
,
other
.
inner_inputs
,
strict
=
Fals
e
):
if
self_in
.
type
!=
other_in
.
type
:
return
False
...
...
pytensor/tensor/basic.py
浏览文件 @
4b41e092
...
...
@@ -3463,7 +3463,8 @@ class PermuteRowElements(Op):
# Make sure the output is big enough
out_s
=
[]
for
xdim
,
ydim
in
zip
(
x_s
,
y_s
,
strict
=
True
):
# strict=False because we are in a hot loop
for
xdim
,
ydim
in
zip
(
x_s
,
y_s
,
strict
=
False
):
if
xdim
==
ydim
:
outdim
=
xdim
elif
xdim
==
1
:
...
...
pytensor/tensor/blockwise.py
浏览文件 @
4b41e092
...
...
@@ -342,16 +342,17 @@ class Blockwise(Op):
def
_check_runtime_broadcast
(
self
,
node
,
inputs
):
batch_ndim
=
self
.
batch_ndim
(
node
)
# strict=False because we are in a hot loop
for
dims_and_bcast
in
zip
(
*
[
zip
(
input
.
shape
[:
batch_ndim
],
sinput
.
type
.
broadcastable
[:
batch_ndim
],
strict
=
Tru
e
,
strict
=
Fals
e
,
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Tru
e
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Fals
e
)
],
strict
=
Tru
e
,
strict
=
Fals
e
,
):
if
any
(
d
!=
1
for
d
,
_
in
dims_and_bcast
)
and
(
1
,
False
)
in
dims_and_bcast
:
raise
ValueError
(
...
...
@@ -374,8 +375,9 @@ class Blockwise(Op):
if
not
isinstance
(
res
,
tuple
):
res
=
(
res
,)
# strict=False because we are in a hot loop
for
node_out
,
out_storage
,
r
in
zip
(
node
.
outputs
,
output_storage
,
res
,
strict
=
Tru
e
node
.
outputs
,
output_storage
,
res
,
strict
=
Fals
e
):
out_dtype
=
getattr
(
node_out
,
"dtype"
,
None
)
if
out_dtype
and
out_dtype
!=
r
.
dtype
:
...
...
pytensor/tensor/elemwise.py
浏览文件 @
4b41e092
...
...
@@ -737,8 +737,9 @@ class Elemwise(OpenMPOp):
if
nout
==
1
:
variables
=
[
variables
]
# strict=False because we are in a hot loop
for
i
,
(
variable
,
storage
,
nout
)
in
enumerate
(
zip
(
variables
,
output_storage
,
node
.
outputs
,
strict
=
Tru
e
)
zip
(
variables
,
output_storage
,
node
.
outputs
,
strict
=
Fals
e
)
):
storage
[
0
]
=
variable
=
np
.
asarray
(
variable
,
dtype
=
nout
.
dtype
)
...
...
@@ -753,12 +754,13 @@ class Elemwise(OpenMPOp):
@staticmethod
def
_check_runtime_broadcast
(
node
,
inputs
):
# strict=False because we are in a hot loop
for
dims_and_bcast
in
zip
(
*
[
zip
(
input
.
shape
,
sinput
.
type
.
broadcastable
,
strict
=
False
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Tru
e
)
for
input
,
sinput
in
zip
(
inputs
,
node
.
inputs
,
strict
=
Fals
e
)
],
strict
=
Tru
e
,
strict
=
Fals
e
,
):
if
any
(
d
!=
1
for
d
,
_
in
dims_and_bcast
)
and
(
1
,
False
)
in
dims_and_bcast
:
raise
ValueError
(
...
...
pytensor/tensor/random/basic.py
浏览文件 @
4b41e092
...
...
@@ -1862,7 +1862,8 @@ class CategoricalRV(RandomVariable):
# to `p.shape[:-1]` in the call to `vsearchsorted` below.
if
len
(
size
)
<
(
p
.
ndim
-
1
):
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
for
s
,
ps
in
zip
(
reversed
(
size
),
reversed
(
p
.
shape
[:
-
1
]),
strict
=
True
):
# strict=False because we are in a hot loop
for
s
,
ps
in
zip
(
reversed
(
size
),
reversed
(
p
.
shape
[:
-
1
]),
strict
=
False
):
if
s
==
1
and
ps
!=
1
:
raise
ValueError
(
"`size` is incompatible with the shape of `p`"
)
...
...
pytensor/tensor/random/utils.py
浏览文件 @
4b41e092
...
...
@@ -44,7 +44,8 @@ def params_broadcast_shapes(
max_fn
=
maximum
if
use_pytensor
else
max
rev_extra_dims
:
list
[
int
]
=
[]
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
True
):
# strict=False because we are in a hot loop
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
False
):
# We need this in order to use `len`
param_shape
=
tuple
(
param_shape
)
extras
=
tuple
(
param_shape
[:
(
len
(
param_shape
)
-
ndim_param
)])
...
...
@@ -63,11 +64,12 @@ def params_broadcast_shapes(
extra_dims
=
tuple
(
reversed
(
rev_extra_dims
))
# strict=False because we are in a hot loop
bcast_shapes
=
[
(
extra_dims
+
tuple
(
param_shape
)[
-
ndim_param
:])
if
ndim_param
>
0
else
extra_dims
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
Tru
e
)
for
ndim_param
,
param_shape
in
zip
(
ndims_params
,
param_shapes
,
strict
=
Fals
e
)
]
return
bcast_shapes
...
...
@@ -110,10 +112,11 @@ def broadcast_params(
use_pytensor
=
False
param_shapes
=
[]
for
p
in
params
:
# strict=False because we are in a hot loop
param_shape
=
tuple
(
1
if
bcast
else
s
for
s
,
bcast
in
zip
(
p
.
shape
,
getattr
(
p
,
"broadcastable"
,
(
False
,)
*
p
.
ndim
),
strict
=
Tru
e
p
.
shape
,
getattr
(
p
,
"broadcastable"
,
(
False
,)
*
p
.
ndim
),
strict
=
Fals
e
)
)
use_pytensor
|=
isinstance
(
p
,
Variable
)
...
...
@@ -124,9 +127,10 @@ def broadcast_params(
)
broadcast_to_fn
=
broadcast_to
if
use_pytensor
else
np
.
broadcast_to
# strict=False because we are in a hot loop
bcast_params
=
[
broadcast_to_fn
(
param
,
shape
)
for
shape
,
param
in
zip
(
shapes
,
params
,
strict
=
Tru
e
)
for
shape
,
param
in
zip
(
shapes
,
params
,
strict
=
Fals
e
)
]
return
bcast_params
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
4b41e092
...
...
@@ -683,7 +683,7 @@ def local_subtensor_of_alloc(fgraph, node):
# Slices to take from val
val_slices
=
[]
for
i
,
(
sl
,
dim
)
in
enumerate
(
zip
(
slices
,
dims
[:
len
(
slices
)],
strict
=
Tru
e
)):
for
i
,
(
sl
,
dim
)
in
enumerate
(
zip
(
slices
,
dims
,
strict
=
Fals
e
)):
# If val was not copied over that dim,
# we need to take the appropriate subtensor on it.
if
i
>=
n_added_dims
:
...
...
pytensor/tensor/shape.py
浏览文件 @
4b41e092
...
...
@@ -448,8 +448,9 @@ class SpecifyShape(COp):
raise
AssertionError
(
f
"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
)
# strict=False because we are in a hot loop
if
not
all
(
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
,
strict
=
Tru
e
)
if
s
is
not
None
xs
==
s
for
xs
,
s
in
zip
(
x
.
shape
,
shape
,
strict
=
Fals
e
)
if
s
is
not
None
):
raise
AssertionError
(
f
"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
...
...
@@ -578,15 +579,12 @@ def specify_shape(
x
=
ptb
.
as_tensor_variable
(
x
)
# type: ignore[arg-type,unused-ignore]
# The above is a type error in Python 3.9 but not 3.12.
# Thus we need to ignore unused-ignore on 3.12.
new_shape_info
=
any
(
s
!=
xts
for
(
s
,
xts
)
in
zip
(
shape
,
x
.
type
.
shape
,
strict
=
False
)
if
s
is
not
None
)
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
if
len
(
shape
)
!=
x
.
type
.
ndim
:
return
_specify_shape
(
x
,
*
shape
)
new_shape_matches
=
all
(
s
==
xts
for
(
s
,
xts
)
in
zip
(
shape
,
x
.
type
.
shape
,
strict
=
True
)
if
s
is
not
None
)
if
new_shape_matches
:
if
not
new_shape_info
and
len
(
shape
)
==
x
.
type
.
ndim
:
return
x
return
_specify_shape
(
x
,
*
shape
)
...
...
pytensor/tensor/type.py
浏览文件 @
4b41e092
...
...
@@ -248,9 +248,10 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
" PyTensor C code does not support that."
,
)
# strict=False because we are in a hot loop
if
not
all
(
ds
==
ts
if
ts
is
not
None
else
True
for
ds
,
ts
in
zip
(
data
.
shape
,
self
.
shape
,
strict
=
Tru
e
)
for
ds
,
ts
in
zip
(
data
.
shape
,
self
.
shape
,
strict
=
Fals
e
)
):
raise
TypeError
(
f
"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
...
...
@@ -319,6 +320,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
return
False
def
is_super
(
self
,
otype
):
# strict=False because we are in a hot loop
if
(
isinstance
(
otype
,
type
(
self
))
and
otype
.
dtype
==
self
.
dtype
...
...
@@ -327,7 +329,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
# but not less
and
all
(
sb
==
ob
or
sb
is
None
for
sb
,
ob
in
zip
(
self
.
shape
,
otype
.
shape
,
strict
=
Tru
e
)
for
sb
,
ob
in
zip
(
self
.
shape
,
otype
.
shape
,
strict
=
Fals
e
)
)
):
return
True
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论