Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
63c513e1
提交
63c513e1
authored
9月 28, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
10月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify Repeat Op to only work with specific axis and vector repeats
上级
3dcf1fb6
显示空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
124 行增加
和
169 行删除
+124
-169
basic.py
pytensor/link/numba/dispatch/basic.py
+1
-1
extra_ops.py
pytensor/link/numba/dispatch/extra_ops.py
+11
-33
extra_ops.py
pytensor/tensor/extra_ops.py
+68
-86
test_extra_ops.py
tests/link/numba/test_extra_ops.py
+3
-15
test_extra_ops.py
tests/tensor/test_extra_ops.py
+41
-34
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
63c513e1
...
...
@@ -220,7 +220,7 @@ def numba_typify(data, dtype=None, **kwargs):
return
data
def
generate_fallback_impl
(
op
,
node
=
None
,
storage_map
=
None
,
**
kwargs
):
def
generate_fallback_impl
(
op
,
node
,
storage_map
=
None
,
**
kwargs
):
"""Create a Numba compatible function from a Pytensor `Op`."""
warnings
.
warn
(
...
...
pytensor/link/numba/dispatch/extra_ops.py
浏览文件 @
63c513e1
...
...
@@ -6,7 +6,11 @@ import numpy as np
from
pytensor.graph
import
Apply
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
get_numba_type
,
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
get_numba_type
,
numba_funcify
,
)
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor.extra_ops
import
(
...
...
@@ -200,39 +204,10 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
@numba_funcify.register
(
Repeat
)
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
a
,
_
=
node
.
inputs
use_python
=
False
if
axis
is
not
None
:
use_python
=
True
if
use_python
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning
,
)
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba_basic.numba_njit
def
repeatop
(
x
,
repeats
):
with
numba
.
objmode
(
ret
=
ret_sig
):
ret
=
np
.
repeat
(
x
,
repeats
,
axis
)
return
ret
else
:
repeats_ndim
=
node
.
inputs
[
1
]
.
ndim
if
repeats_ndim
==
0
:
@numba_basic.numba_njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
.
item
())
else
:
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if
axis
==
0
and
a
.
type
.
ndim
==
1
:
@numba_basic.numba_njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
...
...
@@ -240,6 +215,9 @@ def numba_funcify_Repeat(op, node, **kwargs):
return
repeatop
else
:
return
generate_fallback_impl
(
op
,
node
)
@numba_funcify.register
(
Unique
)
def
numba_funcify_Unique
(
op
,
node
,
**
kwargs
):
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
63c513e1
...
...
@@ -660,53 +660,72 @@ class Repeat(Op):
__props__
=
(
"axis"
,)
def
__init__
(
self
,
axis
:
int
|
None
=
None
):
if
axis
is
not
None
:
if
not
isinstance
(
axis
,
int
)
or
axis
<
0
:
def
__init__
(
self
,
axis
:
int
):
if
isinstance
(
axis
,
int
):
if
axis
<
0
:
raise
ValueError
(
f
"Repeat Op only accepts positive integer axis, got {axis}. "
"Use the helper `pt.repeat` to handle negative axis."
)
elif
axis
is
None
:
raise
ValueError
(
f
"Repeat only accepts positive integer axis or None, got {axis}"
"Repeat Op only accepts positive integer axis. "
"Use the helper `pt.repeat` to handle axis=None."
)
else
:
raise
TypeError
(
f
"Invalid type for axis {axis}, expected int got {type(axis)}"
)
self
.
axis
=
axis
def
make_node
(
self
,
x
,
repeats
):
x
=
ptb
.
as_tensor_variable
(
x
)
repeats
=
ptb
.
as_tensor_variable
(
repeats
,
dtype
=
"int64"
)
if
repeats
.
dtype
not
in
integer_dtypes
:
raise
TypeError
(
"repeats.dtype must be an integer."
)
if
repeats
.
type
.
ndim
!=
1
:
if
repeats
.
type
.
ndim
==
0
:
raise
ValueError
(
f
"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
)
else
:
raise
ValueError
(
f
"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
)
if
repeats
.
type
.
dtype
not
in
integer_dtypes
:
raise
TypeError
(
f
"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
)
# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
ptr_bitwidth
=
LOCAL_BITWIDTH
if
ptr_bitwidth
==
64
:
numpy_unsupported_dtypes
=
(
"uint64"
,)
if
ptr_bitwidth
==
32
:
numpy_unsupported_dtypes
=
(
"uint32"
,
"int64"
,
"uint64"
)
if
repeats
.
dtype
in
numpy_unsupported_dtypes
:
numpy_unsupported_dtypes
=
(
(
"uint64"
,)
if
LOCAL_BITWIDTH
==
64
else
(
"uint64"
,
"uint32"
,
"int64"
)
)
if
repeats
.
type
.
dtype
in
numpy_unsupported_dtypes
:
raise
TypeError
(
(
f
"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
"for the 'repeats' parameter, "
),
repeats
.
dtype
,
f
"repeats {repeats} dtype {repeats.type.dtype} are not supported by numpy.repeat"
)
if
self
.
axis
is
None
:
out_shape
=
[
None
]
else
:
shape
=
list
(
x
.
type
.
shape
)
axis_input_dim_length
=
shape
[
self
.
axis
]
axis_output_dim_length
=
None
if
axis_input_dim_length
is
not
None
:
# If we have a static dim and constant repeats we can infer the length of the output dim
# Right now we only support homogenous constant repeats
try
:
const_reps
=
ptb
.
get_scalar_constant_value
(
repeats
)
const_reps
=
ptb
.
get_
underlying_
scalar_constant_value
(
repeats
)
except
NotScalarConstantError
:
const_reps
=
None
if
const_reps
==
1
:
out_shape
=
x
.
type
.
shape
pass
else
:
out_shape
=
list
(
x
.
type
.
shape
)
out_shape
[
self
.
axis
]
=
None
axis_output_dim_length
=
int
(
const_reps
*
axis_input_dim_length
)
shape
[
self
.
axis
]
=
axis_output_dim_length
out_type
=
TensorType
(
x
.
dtype
,
shape
=
out_
shape
)
out_type
=
TensorType
(
x
.
dtype
,
shape
=
shape
)
return
Apply
(
self
,
[
x
,
repeats
],
[
out_type
()])
def
perform
(
self
,
node
,
inputs
,
output_storage
):
...
...
@@ -720,37 +739,20 @@ class Repeat(Op):
(
x
,
repeats
)
=
inputs
(
gz
,)
=
gout
axis
=
self
.
axis
if
repeats
.
ndim
==
0
:
# When axis is a scalar (same number of reps for all elements),
# We can split the repetitions into their own axis with reshape and sum them back
# to the original element location
sum_axis
=
x
.
ndim
if
axis
is
None
else
axis
+
1
shape
=
list
(
x
.
shape
)
shape
.
insert
(
sum_axis
,
repeats
)
gx
=
gz
.
reshape
(
shape
)
.
sum
(
axis
=
sum_axis
)
elif
repeats
.
ndim
==
1
:
# To sum the gradients that belong to the same repeated x,
# We create a repeated eye and dot product it with the gradient.
axis_size
=
x
.
size
if
axis
is
None
else
x
.
shape
[
axis
]
axis_size
=
x
.
shape
[
axis
]
repeated_eye
=
repeat
(
ptb
.
eye
(
axis_size
),
repeats
,
axis
=
0
)
# A sparse repeat would be neat
if
axis
is
None
:
gx
=
gz
@
repeated_eye
# Undo the ravelling when axis=None
gx
=
gx
.
reshape
(
x
.
shape
)
else
:
# Place gradient axis at end for dot product
gx
=
ptb
.
moveaxis
(
gz
,
axis
,
-
1
)
gx
=
gx
@
repeated_eye
# Place gradient back into the correct axis
gx
=
ptb
.
moveaxis
(
gx
,
-
1
,
axis
)
else
:
raise
ValueError
()
return
[
gx
,
DisconnectedType
()()]
def
infer_shape
(
self
,
fgraph
,
node
,
ins_shapes
):
...
...
@@ -763,21 +765,7 @@ class Repeat(Op):
dtype
=
None
if
repeats
.
dtype
in
(
"uint8"
,
"uint16"
,
"uint32"
):
dtype
=
"int64"
if
axis
is
None
:
if
repeats
.
ndim
==
0
:
if
len
(
i0_shapes
)
==
0
:
out_shape
=
[
repeats
]
else
:
res
=
1
for
d
in
i0_shapes
:
res
=
res
*
d
out_shape
=
(
res
*
repeats
,)
else
:
out_shape
=
[
pt_sum
(
repeats
,
dtype
=
dtype
)]
else
:
if
repeats
.
ndim
==
0
:
out_shape
[
axis
]
=
out_shape
[
axis
]
*
repeats
else
:
out_shape
[
axis
]
=
pt_sum
(
repeats
,
dtype
=
dtype
)
return
[
out_shape
]
...
...
@@ -843,7 +831,10 @@ def repeat(
"""
a
=
ptb
.
as_tensor_variable
(
a
)
if
axis
is
not
None
:
if
axis
is
None
:
axis
=
0
a
=
a
.
flatten
()
else
:
axis
=
normalize_axis_index
(
axis
,
a
.
ndim
)
repeats
=
ptb
.
as_tensor_variable
(
repeats
,
dtype
=
np
.
int64
)
...
...
@@ -851,40 +842,31 @@ def repeat(
if
repeats
.
ndim
>
1
:
raise
ValueError
(
"The dimension of repeats should not exceed 1."
)
if
repeats
.
ndim
==
1
and
not
repeats
.
broadcastable
[
0
]:
if
repeats
.
type
.
broadcastable
==
(
True
,):
# This behaves the same as scalar repeat
repeats
=
repeats
.
squeeze
()
if
repeats
.
ndim
==
1
:
# We only use the Repeat Op for vector repeats
return
Repeat
(
axis
=
axis
)(
a
,
repeats
)
else
:
if
repeats
.
ndim
==
1
:
repeats
=
repeats
[
0
]
if
a
.
dtype
==
"uint64"
:
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Which is not valid for the `reshape` operation at the end
raise
TypeError
(
"repeat doesn't support dtype uint64"
)
if
axis
is
None
:
axis
=
0
a
=
a
.
flatten
()
repeat_shape
=
list
(
a
.
shape
)
# Scalar repeat, we implement this with canonical Ops broadcast + reshape
a_shape
=
a
.
shape
# alloc_shape is the shape of the intermediate tensor which has
# an additional dimension comparing to x. We use alloc to
# allocate space for this intermediate tensor to replicate x
# along that additional dimension.
alloc_shape
=
repeat_shape
[:]
alloc_shape
.
insert
(
axis
+
1
,
repeats
)
# Replicate a along a new axis (axis+1) repeats times
broadcast_shape
=
list
(
a_shape
)
broadcast_shape
.
insert
(
axis
+
1
,
repeats
)
broadcast_a
=
broadcast_to
(
ptb
.
expand_dims
(
a
,
axis
+
1
),
broadcast_shape
)
#
repeat_shape is now the shape of output, where shape[axis] becomes
# shape[axis]*repeats.
#
Reshape broadcast_a to the final shape, merging axis and axis+1
repeat_shape
=
list
(
a_shape
)
repeat_shape
[
axis
]
=
repeat_shape
[
axis
]
*
repeats
# After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape
return
ptb
.
alloc
(
ptb
.
expand_dims
(
a
,
axis
+
1
),
*
alloc_shape
)
.
reshape
(
repeat_shape
)
return
broadcast_a
.
reshape
(
repeat_shape
)
class
Bartlett
(
Op
):
...
...
tests/link/numba/test_extra_ops.py
浏览文件 @
63c513e1
...
...
@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
@pytest.mark.parametrize
(
"x, repeats, axis, exc"
,
[
(
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
"int64"
)),
None
,
None
,
),
(
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
None
,
None
,
),
(
(
pt
.
lvector
(),
np
.
arange
(
2
,
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
1
,
1
],
dtype
=
"int64"
)),
None
,
(
pt
.
lvector
(),
np
.
array
([
1
,
3
],
dtype
=
"int64"
)),
0
,
None
,
),
(
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
l
scalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
l
vector
(),
np
.
array
([
1
,
3
]
,
dtype
=
"int64"
)),
0
,
UserWarning
,
),
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
63c513e1
...
...
@@ -530,7 +530,7 @@ class TestRepeat(utt.InferShapeTester):
def
setup_method
(
self
):
super
()
.
setup_method
()
self
.
op_class
=
Repeat
self
.
op
=
Repeat
()
self
.
op
=
Repeat
(
axis
=
0
)
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
if
LOCAL_BITWIDTH
==
64
:
...
...
@@ -595,40 +595,27 @@ class TestRepeat(utt.InferShapeTester):
def
test_infer_shape
(
self
,
ndim
,
dtype
):
rng
=
np
.
random
.
default_rng
(
4282
)
x
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,)
*
ndim
)()
a_var
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,)
*
ndim
)(
"a"
)
r_var
=
vector
(
"r"
,
dtype
=
dtype
)
shp
=
(
np
.
arange
(
ndim
)
+
1
)
*
3
a
=
rng
.
random
(
shp
)
.
astype
(
config
.
floatX
)
for
axis
in
self
.
_possible_axis
(
ndim
):
if
axis
is
not
None
and
axis
<
0
:
# Operator does not support negative axis
continue
r_var
=
scalar
(
dtype
=
dtype
)
r
=
np
.
asarray
(
3
,
dtype
=
dtype
)
if
dtype
in
self
.
numpy_unsupported_dtypes
:
r_var
=
vector
(
dtype
=
dtype
)
with
pytest
.
raises
(
TypeError
):
repeat
(
x
,
r_var
)
else
:
self
.
_compile_and_check
(
[
x
,
r_var
],
[
Repeat
(
axis
=
axis
)(
x
,
r_var
)],
[
a
,
r
],
self
.
op_class
,
)
repeat
(
a_var
,
r_var
,
axis
=
axis
)
continue
if
axis
is
None
or
axis
<
0
:
# Operator Repeat does not support None or negative axis
continue
r_var
=
vector
(
dtype
=
dtype
)
if
axis
is
None
:
r
=
rng
.
integers
(
1
,
6
,
size
=
a
.
size
)
.
astype
(
dtype
)
elif
a
.
size
>
0
:
r
=
rng
.
integers
(
1
,
6
,
size
=
a
.
shape
[
axis
])
.
astype
(
dtype
)
else
:
r
=
rng
.
integers
(
1
,
6
,
size
=
(
10
,))
.
astype
(
dtype
)
self
.
_compile_and_check
(
[
x
,
r_var
],
[
Repeat
(
axis
=
axis
)(
x
,
r_var
)],
[
a_var
,
r_var
],
[
Repeat
(
axis
=
axis
)(
a_var
,
r_var
)],
[
a
,
r
],
self
.
op_class
,
)
...
...
@@ -647,18 +634,38 @@ class TestRepeat(utt.InferShapeTester):
repeats_size
=
(
x_test
.
shape
[
axis
]
if
axis
is
not
None
else
x_test
.
size
,)
repeats
=
rng
.
integers
(
1
,
6
,
size
=
repeats_size
)
utt
.
verify_grad
(
lambda
x
:
Repeat
(
axis
=
axis
)(
x
,
repeat
s
),
lambda
x
:
repeat
(
x
,
repeats
,
axis
=
axi
s
),
[
x_test
],
)
def
test_broadcastable
(
self
):
x
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,
1
,
None
))()
r
=
Repeat
(
axis
=
1
)(
x
,
2
)
assert
r
.
broadcastable
==
(
False
,
False
,
False
)
r
=
Repeat
(
axis
=
1
)(
x
,
1
)
assert
r
.
broadcastable
==
(
False
,
True
,
False
)
r
=
Repeat
(
axis
=
0
)(
x
,
2
)
assert
r
.
broadcastable
==
(
False
,
True
,
False
)
def
test_static_shape
(
self
):
x
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,
1
,
3
))()
symbolic_r
=
scalar
(
dtype
=
"int32"
)
r
=
repeat
(
x
,
2
,
axis
=
0
)
assert
r
.
type
.
shape
==
(
None
,
1
,
3
)
r
=
repeat
(
x
,
2
,
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
2
,
3
)
r
=
repeat
(
x
,
[
2
],
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
2
,
3
)
r
=
repeat
(
x
,
symbolic_r
,
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
None
,
3
)
r
=
repeat
(
x
,
1
,
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
1
,
3
)
r
=
repeat
(
x
,
2
,
axis
=
2
)
assert
r
.
type
.
shape
==
(
None
,
1
,
6
)
r
=
repeat
(
x
,
[
2
,
2
,
2
],
axis
=
2
)
assert
r
.
type
.
shape
==
(
None
,
1
,
6
)
# This case could be implemented in the future
r
=
repeat
(
x
,
[
1
,
2
,
4
],
axis
=
2
)
assert
r
.
type
.
shape
==
(
None
,
1
,
None
)
class
TestBartlett
(
utt
.
InferShapeTester
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论