Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
63c513e1
提交
63c513e1
authored
9月 28, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
10月 20, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify Repeat Op to only work with specific axis and vector repeats
上级
3dcf1fb6
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
142 行增加
和
187 行删除
+142
-187
basic.py
pytensor/link/numba/dispatch/basic.py
+1
-1
extra_ops.py
pytensor/link/numba/dispatch/extra_ops.py
+13
-35
extra_ops.py
pytensor/tensor/extra_ops.py
+79
-97
test_extra_ops.py
tests/link/numba/test_extra_ops.py
+3
-15
test_extra_ops.py
tests/tensor/test_extra_ops.py
+46
-39
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
63c513e1
...
@@ -220,7 +220,7 @@ def numba_typify(data, dtype=None, **kwargs):
...
@@ -220,7 +220,7 @@ def numba_typify(data, dtype=None, **kwargs):
return
data
return
data
def
generate_fallback_impl
(
op
,
node
=
None
,
storage_map
=
None
,
**
kwargs
):
def
generate_fallback_impl
(
op
,
node
,
storage_map
=
None
,
**
kwargs
):
"""Create a Numba compatible function from a Pytensor `Op`."""
"""Create a Numba compatible function from a Pytensor `Op`."""
warnings
.
warn
(
warnings
.
warn
(
...
...
pytensor/link/numba/dispatch/extra_ops.py
浏览文件 @
63c513e1
...
@@ -6,7 +6,11 @@ import numpy as np
...
@@ -6,7 +6,11 @@ import numpy as np
from
pytensor.graph
import
Apply
from
pytensor.graph
import
Apply
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch
import
basic
as
numba_basic
from
pytensor.link.numba.dispatch.basic
import
get_numba_type
,
numba_funcify
from
pytensor.link.numba.dispatch.basic
import
(
generate_fallback_impl
,
get_numba_type
,
numba_funcify
,
)
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.raise_op
import
CheckAndRaise
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor
import
TensorVariable
from
pytensor.tensor.extra_ops
import
(
from
pytensor.tensor.extra_ops
import
(
...
@@ -200,45 +204,19 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
...
@@ -200,45 +204,19 @@ def numba_funcify_RavelMultiIndex(op, node, **kwargs):
@numba_funcify.register
(
Repeat
)
@numba_funcify.register
(
Repeat
)
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
def
numba_funcify_Repeat
(
op
,
node
,
**
kwargs
):
axis
=
op
.
axis
axis
=
op
.
axis
a
,
_
=
node
.
inputs
use_python
=
False
# Numba only supports axis=None, which in our case is when axis is 0 and the input is a vector
if
axis
==
0
and
a
.
type
.
ndim
==
1
:
if
axis
is
not
None
:
use_python
=
True
if
use_python
:
warnings
.
warn
(
(
"Numba will use object mode to allow the "
"`axis` argument to `numpy.repeat`."
),
UserWarning
,
)
ret_sig
=
get_numba_type
(
node
.
outputs
[
0
]
.
type
)
@numba_basic.numba_njit
@numba_basic.numba_njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
def
repeatop
(
x
,
repeats
):
with
numba
.
objmode
(
ret
=
ret_sig
):
return
np
.
repeat
(
x
,
repeats
)
ret
=
np
.
repeat
(
x
,
repeats
,
axis
)
return
ret
else
:
return
repeatop
repeats_ndim
=
node
.
inputs
[
1
]
.
ndim
if
repeats_ndim
==
0
:
else
:
return
generate_fallback_impl
(
op
,
node
)
@numba_basic.numba_njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
.
item
())
else
:
@numba_basic.numba_njit
(
inline
=
"always"
)
def
repeatop
(
x
,
repeats
):
return
np
.
repeat
(
x
,
repeats
)
return
repeatop
@numba_funcify.register
(
Unique
)
@numba_funcify.register
(
Unique
)
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
63c513e1
...
@@ -660,53 +660,72 @@ class Repeat(Op):
...
@@ -660,53 +660,72 @@ class Repeat(Op):
__props__
=
(
"axis"
,)
__props__
=
(
"axis"
,)
def
__init__
(
self
,
axis
:
int
|
None
=
None
):
def
__init__
(
self
,
axis
:
int
):
if
axis
is
not
None
:
if
isinstance
(
axis
,
int
)
:
if
not
isinstance
(
axis
,
int
)
or
axis
<
0
:
if
axis
<
0
:
raise
ValueError
(
raise
ValueError
(
f
"Repeat only accepts positive integer axis or None, got {axis}"
f
"Repeat Op only accepts positive integer axis, got {axis}. "
"Use the helper `pt.repeat` to handle negative axis."
)
)
elif
axis
is
None
:
raise
ValueError
(
"Repeat Op only accepts positive integer axis. "
"Use the helper `pt.repeat` to handle axis=None."
)
else
:
raise
TypeError
(
f
"Invalid type for axis {axis}, expected int got {type(axis)}"
)
self
.
axis
=
axis
self
.
axis
=
axis
def
make_node
(
self
,
x
,
repeats
):
def
make_node
(
self
,
x
,
repeats
):
x
=
ptb
.
as_tensor_variable
(
x
)
x
=
ptb
.
as_tensor_variable
(
x
)
repeats
=
ptb
.
as_tensor_variable
(
repeats
,
dtype
=
"int64"
)
repeats
=
ptb
.
as_tensor_variable
(
repeats
,
dtype
=
"int64"
)
if
repeats
.
dtype
not
in
integer_dtypes
:
if
repeats
.
type
.
ndim
!=
1
:
raise
TypeError
(
"repeats.dtype must be an integer."
)
if
repeats
.
type
.
ndim
==
0
:
raise
ValueError
(
f
"repeats {repeats} must have 1 dimension, got 0. Use the helper `pt.repeat` to handle scalar repeats."
)
else
:
raise
ValueError
(
f
"repeats {repeats} must have 1 dimension, got {repeats.type.ndim}"
)
if
repeats
.
type
.
dtype
not
in
integer_dtypes
:
raise
TypeError
(
f
"repeats {repeats} dtype must be an integer, got {repeats.type.dtype}."
)
# Some dtypes are not supported by numpy's implementation of repeat.
# Some dtypes are not supported by numpy's implementation of repeat.
# Until another one is available, we should fail at graph construction
# Until another one is available, we should fail at graph construction
# time, not wait for execution.
# time, not wait for execution.
ptr_bitwidth
=
LOCAL_BITWIDTH
numpy_unsupported_dtypes
=
(
if
ptr_bitwidth
==
64
:
(
"uint64"
,)
if
LOCAL_BITWIDTH
==
64
else
(
"uint64"
,
"uint32"
,
"int64"
)
numpy_unsupported_dtypes
=
(
"uint64"
,)
)
if
ptr_bitwidth
==
32
:
if
repeats
.
type
.
dtype
in
numpy_unsupported_dtypes
:
numpy_unsupported_dtypes
=
(
"uint32"
,
"int64"
,
"uint64"
)
if
repeats
.
dtype
in
numpy_unsupported_dtypes
:
raise
TypeError
(
raise
TypeError
(
(
f
"repeats {repeats} dtype {repeats.type.dtype} are not supported by numpy.repeat"
f
"dtypes {numpy_unsupported_dtypes!s} are not supported by numpy.repeat "
"for the 'repeats' parameter, "
),
repeats
.
dtype
,
)
)
if
self
.
axis
is
None
:
shape
=
list
(
x
.
type
.
shape
)
out_shape
=
[
None
]
axis_input_dim_length
=
shape
[
self
.
axis
]
else
:
axis_output_dim_length
=
None
if
axis_input_dim_length
is
not
None
:
# If we have a static dim and constant repeats we can infer the length of the output dim
# Right now we only support homogenous constant repeats
try
:
try
:
const_reps
=
ptb
.
get_scalar_constant_value
(
repeats
)
const_reps
=
ptb
.
get_
underlying_
scalar_constant_value
(
repeats
)
except
NotScalarConstantError
:
except
NotScalarConstantError
:
const_reps
=
None
pass
if
const_reps
==
1
:
out_shape
=
x
.
type
.
shape
else
:
else
:
out_shape
=
list
(
x
.
type
.
shape
)
axis_output_dim_length
=
int
(
const_reps
*
axis_input_dim_length
)
out_shape
[
self
.
axis
]
=
None
shape
[
self
.
axis
]
=
axis_output_dim_length
out_type
=
TensorType
(
x
.
dtype
,
shape
=
out_
shape
)
out_type
=
TensorType
(
x
.
dtype
,
shape
=
shape
)
return
Apply
(
self
,
[
x
,
repeats
],
[
out_type
()])
return
Apply
(
self
,
[
x
,
repeats
],
[
out_type
()])
def
perform
(
self
,
node
,
inputs
,
output_storage
):
def
perform
(
self
,
node
,
inputs
,
output_storage
):
...
@@ -720,36 +739,19 @@ class Repeat(Op):
...
@@ -720,36 +739,19 @@ class Repeat(Op):
(
x
,
repeats
)
=
inputs
(
x
,
repeats
)
=
inputs
(
gz
,)
=
gout
(
gz
,)
=
gout
axis
=
self
.
axis
axis
=
self
.
axis
if
repeats
.
ndim
==
0
:
# When axis is a scalar (same number of reps for all elements),
# We can split the repetitions into their own axis with reshape and sum them back
# to the original element location
sum_axis
=
x
.
ndim
if
axis
is
None
else
axis
+
1
shape
=
list
(
x
.
shape
)
shape
.
insert
(
sum_axis
,
repeats
)
gx
=
gz
.
reshape
(
shape
)
.
sum
(
axis
=
sum_axis
)
elif
repeats
.
ndim
==
1
:
# To sum the gradients that belong to the same repeated x,
# We create a repeated eye and dot product it with the gradient.
axis_size
=
x
.
size
if
axis
is
None
else
x
.
shape
[
axis
]
repeated_eye
=
repeat
(
ptb
.
eye
(
axis_size
),
repeats
,
axis
=
0
)
# A sparse repeat would be neat
if
axis
is
None
:
gx
=
gz
@
repeated_eye
# Undo the ravelling when axis=None
gx
=
gx
.
reshape
(
x
.
shape
)
else
:
# Place gradient axis at end for dot product
gx
=
ptb
.
moveaxis
(
gz
,
axis
,
-
1
)
gx
=
gx
@
repeated_eye
# Place gradient back into the correct axis
gx
=
ptb
.
moveaxis
(
gx
,
-
1
,
axis
)
else
:
# To sum the gradients that belong to the same repeated x,
raise
ValueError
()
# We create a repeated eye and dot product it with the gradient.
axis_size
=
x
.
shape
[
axis
]
repeated_eye
=
repeat
(
ptb
.
eye
(
axis_size
),
repeats
,
axis
=
0
)
# A sparse repeat would be neat
# Place gradient axis at end for dot product
gx
=
ptb
.
moveaxis
(
gz
,
axis
,
-
1
)
gx
=
gx
@
repeated_eye
# Place gradient back into the correct axis
gx
=
ptb
.
moveaxis
(
gx
,
-
1
,
axis
)
return
[
gx
,
DisconnectedType
()()]
return
[
gx
,
DisconnectedType
()()]
...
@@ -763,22 +765,8 @@ class Repeat(Op):
...
@@ -763,22 +765,8 @@ class Repeat(Op):
dtype
=
None
dtype
=
None
if
repeats
.
dtype
in
(
"uint8"
,
"uint16"
,
"uint32"
):
if
repeats
.
dtype
in
(
"uint8"
,
"uint16"
,
"uint32"
):
dtype
=
"int64"
dtype
=
"int64"
if
axis
is
None
:
if
repeats
.
ndim
==
0
:
out_shape
[
axis
]
=
pt_sum
(
repeats
,
dtype
=
dtype
)
if
len
(
i0_shapes
)
==
0
:
out_shape
=
[
repeats
]
else
:
res
=
1
for
d
in
i0_shapes
:
res
=
res
*
d
out_shape
=
(
res
*
repeats
,)
else
:
out_shape
=
[
pt_sum
(
repeats
,
dtype
=
dtype
)]
else
:
if
repeats
.
ndim
==
0
:
out_shape
[
axis
]
=
out_shape
[
axis
]
*
repeats
else
:
out_shape
[
axis
]
=
pt_sum
(
repeats
,
dtype
=
dtype
)
return
[
out_shape
]
return
[
out_shape
]
...
@@ -843,7 +831,10 @@ def repeat(
...
@@ -843,7 +831,10 @@ def repeat(
"""
"""
a
=
ptb
.
as_tensor_variable
(
a
)
a
=
ptb
.
as_tensor_variable
(
a
)
if
axis
is
not
None
:
if
axis
is
None
:
axis
=
0
a
=
a
.
flatten
()
else
:
axis
=
normalize_axis_index
(
axis
,
a
.
ndim
)
axis
=
normalize_axis_index
(
axis
,
a
.
ndim
)
repeats
=
ptb
.
as_tensor_variable
(
repeats
,
dtype
=
np
.
int64
)
repeats
=
ptb
.
as_tensor_variable
(
repeats
,
dtype
=
np
.
int64
)
...
@@ -851,40 +842,31 @@ def repeat(
...
@@ -851,40 +842,31 @@ def repeat(
if
repeats
.
ndim
>
1
:
if
repeats
.
ndim
>
1
:
raise
ValueError
(
"The dimension of repeats should not exceed 1."
)
raise
ValueError
(
"The dimension of repeats should not exceed 1."
)
if
repeats
.
ndim
==
1
and
not
repeats
.
broadcastable
[
0
]:
if
repeats
.
type
.
broadcastable
==
(
True
,):
# This behaves the same as scalar repeat
repeats
=
repeats
.
squeeze
()
if
repeats
.
ndim
==
1
:
# We only use the Repeat Op for vector repeats
# We only use the Repeat Op for vector repeats
return
Repeat
(
axis
=
axis
)(
a
,
repeats
)
return
Repeat
(
axis
=
axis
)(
a
,
repeats
)
else
:
else
:
if
repeats
.
ndim
==
1
:
repeats
=
repeats
[
0
]
if
a
.
dtype
==
"uint64"
:
if
a
.
dtype
==
"uint64"
:
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Multiplying int64 (shape) by uint64 (repeats) yields a float64
# Which is not valid for the `reshape` operation at the end
# Which is not valid for the `reshape` operation at the end
raise
TypeError
(
"repeat doesn't support dtype uint64"
)
raise
TypeError
(
"repeat doesn't support dtype uint64"
)
if
axis
is
None
:
# Scalar repeat, we implement this with canonical Ops broadcast + reshape
axis
=
0
a_shape
=
a
.
shape
a
=
a
.
flatten
()
repeat_shape
=
list
(
a
.
shape
)
# alloc_shape is the shape of the intermediate tensor which has
# Replicate a along a new axis (axis+1) repeats times
# an additional dimension comparing to x. We use alloc to
broadcast_shape
=
list
(
a_shape
)
# allocate space for this intermediate tensor to replicate x
broadcast_shape
.
insert
(
axis
+
1
,
repeats
)
# along that additional dimension.
broadcast_a
=
broadcast_to
(
ptb
.
expand_dims
(
a
,
axis
+
1
),
broadcast_shape
)
alloc_shape
=
repeat_shape
[:]
alloc_shape
.
insert
(
axis
+
1
,
repeats
)
#
repeat_shape is now the shape of output, where shape[axis] becomes
#
Reshape broadcast_a to the final shape, merging axis and axis+1
# shape[axis]*repeats.
repeat_shape
=
list
(
a_shape
)
repeat_shape
[
axis
]
=
repeat_shape
[
axis
]
*
repeats
repeat_shape
[
axis
]
=
repeat_shape
[
axis
]
*
repeats
return
broadcast_a
.
reshape
(
repeat_shape
)
# After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape
return
ptb
.
alloc
(
ptb
.
expand_dims
(
a
,
axis
+
1
),
*
alloc_shape
)
.
reshape
(
repeat_shape
)
class
Bartlett
(
Op
):
class
Bartlett
(
Op
):
...
...
tests/link/numba/test_extra_ops.py
浏览文件 @
63c513e1
...
@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
...
@@ -212,27 +212,15 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"x, repeats, axis, exc"
,
"x, repeats, axis, exc"
,
[
[
(
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
"int64"
)),
None
,
None
,
),
(
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
None
,
None
,
),
(
(
(
pt
.
lvector
(),
np
.
arange
(
2
,
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
arange
(
2
,
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
1
,
1
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
1
,
3
],
dtype
=
"int64"
)),
None
,
0
,
None
,
None
,
),
),
(
(
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
l
scalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
l
vector
(),
np
.
array
([
1
,
3
]
,
dtype
=
"int64"
)),
0
,
0
,
UserWarning
,
UserWarning
,
),
),
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
63c513e1
...
@@ -530,7 +530,7 @@ class TestRepeat(utt.InferShapeTester):
...
@@ -530,7 +530,7 @@ class TestRepeat(utt.InferShapeTester):
def
setup_method
(
self
):
def
setup_method
(
self
):
super
()
.
setup_method
()
super
()
.
setup_method
()
self
.
op_class
=
Repeat
self
.
op_class
=
Repeat
self
.
op
=
Repeat
()
self
.
op
=
Repeat
(
axis
=
0
)
# uint64 always fails
# uint64 always fails
# int64 and uint32 also fail if python int are 32-bit
# int64 and uint32 also fail if python int are 32-bit
if
LOCAL_BITWIDTH
==
64
:
if
LOCAL_BITWIDTH
==
64
:
...
@@ -595,43 +595,30 @@ class TestRepeat(utt.InferShapeTester):
...
@@ -595,43 +595,30 @@ class TestRepeat(utt.InferShapeTester):
def
test_infer_shape
(
self
,
ndim
,
dtype
):
def
test_infer_shape
(
self
,
ndim
,
dtype
):
rng
=
np
.
random
.
default_rng
(
4282
)
rng
=
np
.
random
.
default_rng
(
4282
)
x
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,)
*
ndim
)()
a_var
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,)
*
ndim
)(
"a"
)
r_var
=
vector
(
"r"
,
dtype
=
dtype
)
shp
=
(
np
.
arange
(
ndim
)
+
1
)
*
3
shp
=
(
np
.
arange
(
ndim
)
+
1
)
*
3
a
=
rng
.
random
(
shp
)
.
astype
(
config
.
floatX
)
a
=
rng
.
random
(
shp
)
.
astype
(
config
.
floatX
)
for
axis
in
self
.
_possible_axis
(
ndim
):
for
axis
in
self
.
_possible_axis
(
ndim
):
if
axis
is
not
None
and
axis
<
0
:
# Operator does not support negative axis
continue
r_var
=
scalar
(
dtype
=
dtype
)
r
=
np
.
asarray
(
3
,
dtype
=
dtype
)
if
dtype
in
self
.
numpy_unsupported_dtypes
:
if
dtype
in
self
.
numpy_unsupported_dtypes
:
r_var
=
vector
(
dtype
=
dtype
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
repeat
(
x
,
r_var
)
repeat
(
a_var
,
r_var
,
axis
=
axis
)
else
:
continue
self
.
_compile_and_check
(
[
x
,
r_var
],
[
Repeat
(
axis
=
axis
)(
x
,
r_var
)],
[
a
,
r
],
self
.
op_class
,
)
r_var
=
vector
(
dtype
=
dtype
)
if
axis
is
None
or
axis
<
0
:
if
axis
is
None
:
# Operator Repeat does not support None or negative axis
r
=
rng
.
integers
(
1
,
6
,
size
=
a
.
size
)
.
astype
(
dtype
)
continue
elif
a
.
size
>
0
:
r
=
rng
.
integers
(
1
,
6
,
size
=
a
.
shape
[
axis
])
.
astype
(
dtype
)
else
:
r
=
rng
.
integers
(
1
,
6
,
size
=
(
10
,))
.
astype
(
dtype
)
self
.
_compile_and_check
(
r
=
rng
.
integers
(
1
,
6
,
size
=
a
.
shape
[
axis
])
.
astype
(
dtype
)
[
x
,
r_var
],
[
Repeat
(
axis
=
axis
)(
x
,
r_var
)],
self
.
_compile_and_check
(
[
a
,
r
],
[
a_var
,
r_var
],
self
.
op_class
,
[
Repeat
(
axis
=
axis
)(
a_var
,
r_var
)],
)
[
a
,
r
],
self
.
op_class
,
)
@pytest.mark.parametrize
(
"x_ndim"
,
[
2
,
3
],
ids
=
lambda
x
:
f
"x_ndim={x}"
)
@pytest.mark.parametrize
(
"x_ndim"
,
[
2
,
3
],
ids
=
lambda
x
:
f
"x_ndim={x}"
)
@pytest.mark.parametrize
(
"repeats_ndim"
,
[
0
,
1
],
ids
=
lambda
r
:
f
"repeats_ndim={r}"
)
@pytest.mark.parametrize
(
"repeats_ndim"
,
[
0
,
1
],
ids
=
lambda
r
:
f
"repeats_ndim={r}"
)
...
@@ -647,18 +634,38 @@ class TestRepeat(utt.InferShapeTester):
...
@@ -647,18 +634,38 @@ class TestRepeat(utt.InferShapeTester):
repeats_size
=
(
x_test
.
shape
[
axis
]
if
axis
is
not
None
else
x_test
.
size
,)
repeats_size
=
(
x_test
.
shape
[
axis
]
if
axis
is
not
None
else
x_test
.
size
,)
repeats
=
rng
.
integers
(
1
,
6
,
size
=
repeats_size
)
repeats
=
rng
.
integers
(
1
,
6
,
size
=
repeats_size
)
utt
.
verify_grad
(
utt
.
verify_grad
(
lambda
x
:
Repeat
(
axis
=
axis
)(
x
,
repeat
s
),
lambda
x
:
repeat
(
x
,
repeats
,
axis
=
axi
s
),
[
x_test
],
[
x_test
],
)
)
def
test_broadcastable
(
self
):
def
test_static_shape
(
self
):
x
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,
1
,
None
))()
x
=
TensorType
(
config
.
floatX
,
shape
=
(
None
,
1
,
3
))()
r
=
Repeat
(
axis
=
1
)(
x
,
2
)
symbolic_r
=
scalar
(
dtype
=
"int32"
)
assert
r
.
broadcastable
==
(
False
,
False
,
False
)
r
=
Repeat
(
axis
=
1
)(
x
,
1
)
r
=
repeat
(
x
,
2
,
axis
=
0
)
assert
r
.
broadcastable
==
(
False
,
True
,
False
)
assert
r
.
type
.
shape
==
(
None
,
1
,
3
)
r
=
Repeat
(
axis
=
0
)(
x
,
2
)
assert
r
.
broadcastable
==
(
False
,
True
,
False
)
r
=
repeat
(
x
,
2
,
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
2
,
3
)
r
=
repeat
(
x
,
[
2
],
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
2
,
3
)
r
=
repeat
(
x
,
symbolic_r
,
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
None
,
3
)
r
=
repeat
(
x
,
1
,
axis
=
1
)
assert
r
.
type
.
shape
==
(
None
,
1
,
3
)
r
=
repeat
(
x
,
2
,
axis
=
2
)
assert
r
.
type
.
shape
==
(
None
,
1
,
6
)
r
=
repeat
(
x
,
[
2
,
2
,
2
],
axis
=
2
)
assert
r
.
type
.
shape
==
(
None
,
1
,
6
)
# This case could be implemented in the future
r
=
repeat
(
x
,
[
1
,
2
,
4
],
axis
=
2
)
assert
r
.
type
.
shape
==
(
None
,
1
,
None
)
class
TestBartlett
(
utt
.
InferShapeTester
):
class
TestBartlett
(
utt
.
InferShapeTester
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论