Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e88117e6
提交
e88117e6
authored
8月 22, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Only require input_ndim and not input_broadcastable in DimShuffle
上级
d68f53f8
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
24 个修改的文件
包含
66 行增加
和
86 行删除
+66
-86
sp.py
pytensor/sparse/sandbox/sp.py
+2
-3
basic.py
pytensor/tensor/basic.py
+2
-2
elemwise.py
pytensor/tensor/elemwise.py
+0
-0
extra_ops.py
pytensor/tensor/extra_ops.py
+1
-6
inplace.py
pytensor/tensor/inplace.py
+2
-2
math.py
pytensor/tensor/math.py
+1
-3
jax.py
pytensor/tensor/random/rewriting/jax.py
+1
-1
basic.py
pytensor/tensor/rewriting/basic.py
+1
-1
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+2
-4
jax.py
pytensor/tensor/rewriting/jax.py
+1
-1
linalg.py
pytensor/tensor/rewriting/linalg.py
+1
-1
shape.py
pytensor/tensor/rewriting/shape.py
+1
-5
variable.py
pytensor/tensor/variable.py
+2
-2
test_elemwise.py
tests/link/jax/test_elemwise.py
+1
-1
test_elemwise.py
tests/link/numba/test_elemwise.py
+4
-4
test_elemwise.py
tests/link/pytorch/test_elemwise.py
+0
-6
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1
-1
test_math.py
tests/tensor/rewriting/test_math.py
+3
-3
test_basic.py
tests/tensor/test_basic.py
+1
-1
test_blas.py
tests/tensor/test_blas.py
+11
-9
test_elemwise.py
tests/tensor/test_elemwise.py
+16
-23
test_extra_ops.py
tests/tensor/test_extra_ops.py
+2
-5
test_fft.py
tests/tensor/test_fft.py
+9
-0
test_keepdims.py
tests/tensor/test_keepdims.py
+1
-2
没有找到文件。
pytensor/sparse/sandbox/sp.py
浏览文件 @
e88117e6
...
@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
...
@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
from
pytensor.tensor.math
import
dot
from
pytensor.tensor.math
import
dot
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.subtensor
import
DimShuffle
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
...
@@ -375,7 +374,7 @@ def convolve(
...
@@ -375,7 +374,7 @@ def convolve(
[
images
.
shape
[
0
],
pt
.
as_tensor
(
np
.
prod
(
outshp
)),
pt
.
as_tensor
(
nkern
)]
[
images
.
shape
[
0
],
pt
.
as_tensor
(
np
.
prod
(
outshp
)),
pt
.
as_tensor
(
nkern
)]
)
)
tensout
=
reshape
(
output
,
newshp
,
ndim
=
3
)
tensout
=
reshape
(
output
,
newshp
,
ndim
=
3
)
output
=
DimShuffle
((
False
,)
*
tensout
.
ndim
,
(
0
,
2
,
1
))(
tensout
)
output
=
tensout
.
transpose
(
0
,
2
,
1
)
if
flatten
:
if
flatten
:
output
=
pt
.
flatten
(
output
,
2
)
output
=
pt
.
flatten
(
output
,
2
)
...
@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
...
@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
)
)
out2
=
reshape
(
out1
,
pshape
,
ndim
=
3
)
out2
=
reshape
(
out1
,
pshape
,
ndim
=
3
)
out3
=
DimShuffle
(
out2
.
broadcastable
,
(
0
,
2
,
1
))(
out2
)
out3
=
out2
.
transpose
(
0
,
2
,
1
)
return
pt
.
flatten
(
out3
,
2
),
outshp
return
pt
.
flatten
(
out3
,
2
),
outshp
pytensor/tensor/basic.py
浏览文件 @
e88117e6
...
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
...
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op
# No-op
return
_x
return
_x
ret
=
DimShuffle
(
tuple
(
s
==
1
for
s
in
_x
.
type
.
shape
),
axes
)(
_x
)
ret
=
_x
.
dimshuffle
(
axes
)
if
_x
.
name
and
axes
==
tuple
(
range
((
_x
.
type
.
ndim
-
1
),
-
1
,
-
1
)):
if
_x
.
name
and
axes
==
tuple
(
range
((
_x
.
type
.
ndim
-
1
),
-
1
,
-
1
)):
ret
.
name
=
_x
.
name
+
".T"
ret
.
name
=
_x
.
name
+
".T"
...
@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
...
@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
newdims
.
append
(
i
)
newdims
.
append
(
i
)
i
+=
1
i
+=
1
gx
=
DimShuffle
(
tuple
(
s
==
1
for
s
in
gx
.
type
.
shape
),
newdims
)(
gx
)
gx
=
gx
.
dimshuffle
(
newdims
)
assert
gx
.
type
.
ndim
==
x
.
type
.
ndim
assert
gx
.
type
.
ndim
==
x
.
type
.
ndim
assert
all
(
assert
all
(
s1
==
s2
s1
==
s2
...
...
pytensor/tensor/elemwise.py
浏览文件 @
e88117e6
差异被折叠。
点击展开。
pytensor/tensor/extra_ops.py
浏览文件 @
e88117e6
...
@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
...
@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
)
)
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.shape
import
Shape_i
,
specify_broadcastable
from
pytensor.tensor.shape
import
Shape_i
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
...
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
# Nothing could be squeezed
return
_x
return
_x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis
=
[
i
for
i
in
axis
if
not
_x
.
broadcastable
[
i
]]
_x
=
specify_broadcastable
(
_x
,
*
non_broadcastable_axis
)
return
_x
.
dimshuffle
([
i
for
i
in
range
(
_x
.
ndim
)
if
i
not
in
axis
])
return
_x
.
dimshuffle
([
i
for
i
in
range
(
_x
.
ndim
)
if
i
not
in
axis
])
...
...
pytensor/tensor/inplace.py
浏览文件 @
e88117e6
from
pytensor
import
printing
from
pytensor
import
printing
from
pytensor.printing
import
pprint
from
pytensor.printing
import
pprint
from
pytensor.tensor.elemwise
import
DimShuffle
,
scalar_elemwise
from
pytensor.tensor.elemwise
import
scalar_elemwise
@scalar_elemwise
@scalar_elemwise
...
@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
...
@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def
transpose_inplace
(
x
,
**
kwargs
):
def
transpose_inplace
(
x
,
**
kwargs
):
"Perform a transpose on a tensor without copying the underlying storage"
"Perform a transpose on a tensor without copying the underlying storage"
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
return
DimShuffle
(
x
.
broadcastable
,
dims
)(
x
)
return
x
.
dimshuffle
(
dims
)
pytensor/tensor/math.py
浏览文件 @
e88117e6
...
@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
...
@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.elemwise
import
(
from
pytensor.tensor.elemwise
import
(
CAReduce
,
CAReduce
,
DimShuffle
,
Elemwise
,
Elemwise
,
get_normalized_batch_axes
,
get_normalized_batch_axes
,
scalar_elemwise
,
scalar_elemwise
,
...
@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
...
@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
else
:
else
:
new_dims
.
append
(
i
)
new_dims
.
append
(
i
)
i
+=
1
i
+=
1
ds_op
=
DimShuffle
(
gz
.
type
.
broadcastable
,
new_dims
)
gx
=
Elemwise
(
ps
.
second
)(
x
,
gz
.
dimshuffle
(
new_dims
))
gx
=
Elemwise
(
ps
.
second
)(
x
,
ds_op
(
gz
))
return
[
gx
]
return
[
gx
]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
e88117e6
...
@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
...
@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
if
isinstance
(
size_node
.
op
,
MakeVector
)
or
(
if
isinstance
(
size_node
.
op
,
MakeVector
)
or
(
isinstance
(
size_node
.
op
,
DimShuffle
)
isinstance
(
size_node
.
op
,
DimShuffle
)
and
size_node
.
op
.
input_
broadcastable
==
()
and
size_node
.
op
.
input_
ndim
==
0
and
size_node
.
op
.
new_order
==
(
"x"
,)
and
size_node
.
op
.
new_order
==
(
"x"
,)
):
):
# Here PyTensor converted a tuple or list to a tensor
# Here PyTensor converted a tuple or list to a tensor
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
e88117e6
...
@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
...
@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
dimshuffle_new_order
=
[
"x"
]
*
num_dims_with_size_1_added_to_left
+
list
(
dimshuffle_new_order
=
[
"x"
]
*
num_dims_with_size_1_added_to_left
+
list
(
range
(
len
(
new_output_shape
))
range
(
len
(
new_output_shape
))
)
)
return
[
DimShuffle
(
inner
.
type
.
broadcastable
,
dimshuffle_new_order
)(
inn
er
)]
return
[
inner
.
dimshuffle
(
dimshuffle_new_ord
er
)]
@node_rewriter
([
AllocEmpty
])
@node_rewriter
([
AllocEmpty
])
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
e88117e6
...
@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
"""
"""
op
=
node
.
op
op
=
node
.
op
if
not
isinstance
(
op
,
DimShuffle
):
return
False
inp
=
node
.
inputs
[
0
]
inp
=
node
.
inputs
[
0
]
inode
=
inp
.
owner
inode
=
inp
.
owner
...
@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
# Don't use make_node to have tag.test_value set.
# Don't use make_node to have tag.test_value set.
new_inputs
=
[]
new_inputs
=
[]
for
inp
in
inode
.
inputs
:
for
inp
in
inode
.
inputs
:
new_inp
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
op
.
new_order
)(
inp
)
new_inp
=
inp
.
dimshuffle
(
op
.
new_order
)
new_inputs
.
append
(
apply_local_dimshuffle_lift
(
fgraph
,
new_inp
))
new_inputs
.
append
(
apply_local_dimshuffle_lift
(
fgraph
,
new_inp
))
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
inode
.
op
(
*
new_inputs
,
return_list
=
True
)
ret
=
inode
.
op
(
*
new_inputs
,
return_list
=
True
)
...
@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
if
is_dimshuffle_useless
(
new_order
,
inp
):
if
is_dimshuffle_useless
(
new_order
,
inp
):
return
[
inp
]
return
[
inp
]
elif
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
elif
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
ret
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
new_order
)(
inp
)
ret
=
inp
.
dimshuffle
(
new_order
)
ret
=
apply_local_dimshuffle_lift
(
fgraph
,
ret
)
ret
=
apply_local_dimshuffle_lift
(
fgraph
,
ret
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
return
[
ret
]
...
...
pytensor/tensor/rewriting/jax.py
浏览文件 @
e88117e6
...
@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
...
@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
if
isinstance
(
shape_node
.
op
,
MakeVector
)
or
(
if
isinstance
(
shape_node
.
op
,
MakeVector
)
or
(
isinstance
(
shape_node
.
op
,
DimShuffle
)
isinstance
(
shape_node
.
op
,
DimShuffle
)
and
shape_node
.
op
.
input_
broadcastable
==
()
and
shape_node
.
op
.
input_
ndim
==
0
and
shape_node
.
op
.
new_order
==
(
"x"
,)
and
shape_node
.
op
.
new_order
==
(
"x"
,)
):
):
# Here PyTensor converted a tuple or list to a tensor
# Here PyTensor converted a tuple or list to a tensor
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
e88117e6
...
@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
...
@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if
ndims
<
2
:
if
ndims
<
2
:
return
False
return
False
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
return
cast
(
bool
,
node
.
op
.
new_order
==
transpose_order
)
return
node
.
op
.
new_order
==
transpose_order
return
False
return
False
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
e88117e6
...
@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
...
@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if
index
!=
output
.
type
.
ndim
:
if
index
!=
output
.
type
.
ndim
:
inner
=
op
.
__class__
(
len
(
new_output_shape
))(
inp
,
new_output_shape
)
inner
=
op
.
__class__
(
len
(
new_output_shape
))(
inp
,
new_output_shape
)
copy_stack_trace
(
output
,
inner
)
copy_stack_trace
(
output
,
inner
)
new_node
=
[
new_node
=
[
inner
.
dimshuffle
(
dimshuffle_new_order
)]
DimShuffle
(
tuple
(
s
==
1
for
s
in
inner
.
type
.
shape
),
dimshuffle_new_order
)(
inner
)
]
copy_stack_trace
(
output
,
new_node
)
copy_stack_trace
(
output
,
new_node
)
return
new_node
return
new_node
...
...
pytensor/tensor/variable.py
浏览文件 @
e88117e6
...
@@ -344,8 +344,8 @@ class _tensor_py_operators:
...
@@ -344,8 +344,8 @@ class _tensor_py_operators:
"""
"""
if
(
len
(
pattern
)
==
1
)
and
(
isinstance
(
pattern
[
0
],
list
|
tuple
)):
if
(
len
(
pattern
)
==
1
)
and
(
isinstance
(
pattern
[
0
],
list
|
tuple
)):
pattern
=
pattern
[
0
]
pattern
=
pattern
[
0
]
op
=
pt
.
elemwise
.
DimShuffle
(
list
(
self
.
type
.
broadcastable
),
pattern
)
ds_op
=
pt
.
elemwise
.
DimShuffle
(
input_ndim
=
self
.
type
.
ndim
,
new_order
=
pattern
)
return
op
(
self
)
return
ds_
op
(
self
)
def
flatten
(
self
,
ndim
=
1
):
def
flatten
(
self
,
ndim
=
1
):
return
pt
.
basic
.
flatten
(
self
,
ndim
)
return
pt
.
basic
.
flatten
(
self
,
ndim
)
...
...
tests/link/jax/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
...
@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
(
[
False
,
True
],
(
0
,))(
a_pt
)
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
...
...
tests/link/numba/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
...
@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from
pytensor.gradient
import
grad
from
pytensor.gradient
import
grad
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
elemwise
as
pt_elemwis
e
from
pytensor.tensor
.elemwise
import
DimShuffl
e
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
...
@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
...
@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
],
],
)
)
def
test_Dimshuffle
(
v
,
new_order
):
def
test_Dimshuffle
(
v
,
new_order
):
g
=
pt_elemwise
.
DimShuffle
(
v
.
broadcastable
,
new_order
)(
v
)
g
=
v
.
dimshuffle
(
new_order
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
g_fg
,
...
@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
...
@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
def
test_Dimshuffle_returns_array
():
def
test_Dimshuffle_returns_array
():
x
=
pt
.
vector
(
"x"
,
shape
=
(
1
,))
x
=
pt
.
vector
(
"x"
,
shape
=
(
1
,))
y
=
2
*
pt_elemwise
.
DimShuffle
([
True
],
[])(
x
)
y
=
2
*
x
.
dimshuffle
([]
)
func
=
pytensor
.
function
([
x
],
y
,
mode
=
"NUMBA"
)
func
=
pytensor
.
function
([
x
],
y
,
mode
=
"NUMBA"
)
out
=
func
(
np
.
zeros
(
1
,
dtype
=
config
.
floatX
))
out
=
func
(
np
.
zeros
(
1
,
dtype
=
config
.
floatX
))
assert
out
.
ndim
==
0
assert
out
.
ndim
==
0
...
@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
...
@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
non-contiguous arrays, make sure we work around thpt."""
non-contiguous arrays, make sure we work around thpt."""
x
=
pt
.
dvector
()
x
=
pt
.
dvector
()
idx
=
pt
.
vector
(
dtype
=
"int64"
)
idx
=
pt
.
vector
(
dtype
=
"int64"
)
op
=
pytensor
.
tensor
.
elemwise
.
DimShuffle
([
True
],
[])
op
=
DimShuffle
(
input_ndim
=
1
,
new_order
=
[])
out
=
op
(
pt
.
specify_shape
(
x
[
idx
][::
2
],
(
1
,)))
out
=
op
(
pt
.
specify_shape
(
x
[
idx
][::
2
],
(
1
,)))
func
=
pytensor
.
function
([
x
,
idx
],
out
,
mode
=
"NUMBA"
)
func
=
pytensor
.
function
([
x
,
idx
],
out
,
mode
=
"NUMBA"
)
assert
func
(
np
.
zeros
(
3
),
np
.
array
([
1
]))
.
ndim
==
0
assert
func
(
np
.
zeros
(
3
),
np
.
array
([
1
]))
.
ndim
==
0
...
...
tests/link/pytorch/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -5,7 +5,6 @@ import pytensor.tensor as pt
...
@@ -5,7 +5,6 @@ import pytensor.tensor as pt
import
pytensor.tensor.math
as
ptm
import
pytensor.tensor.math
as
ptm
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
from
pytensor.tensor.type
import
matrix
,
tensor
,
tensor3
,
vector
from
pytensor.tensor.type
import
matrix
,
tensor
,
tensor3
,
vector
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
...
@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
([
False
,
True
],
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
def
test_multiple_input_output
():
def
test_multiple_input_output
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
...
@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
def
ds
(
x
,
y
):
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
return
x
.
dimshuffle
(
y
)
def
inputs
(
xbc
=
(
0
,
0
),
ybc
=
(
0
,
0
),
zbc
=
(
0
,
0
)):
def
inputs
(
xbc
=
(
0
,
0
),
ybc
=
(
0
,
0
),
zbc
=
(
0
,
0
)):
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
e88117e6
...
@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
...
@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
def
ds
(
x
,
y
):
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
return
x
.
dimshuffle
(
y
)
def
rewrite
(
g
,
level
=
"fast_run"
):
def
rewrite
(
g
,
level
=
"fast_run"
):
...
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
...
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
check_max_log_sum_exp
(
x
,
axis
=
(
0
,
1
,
2
),
dimshuffle_op
=
None
)
check_max_log_sum_exp
(
x
,
axis
=
(
0
,
1
,
2
),
dimshuffle_op
=
None
)
# If a transpose is applied to the sum
# If a transpose is applied to the sum
transpose_op
=
DimShuffle
(
(
False
,
False
),
(
1
,
0
))
transpose_op
=
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
check_max_log_sum_exp
(
x
,
axis
=
2
,
dimshuffle_op
=
transpose_op
)
check_max_log_sum_exp
(
x
,
axis
=
2
,
dimshuffle_op
=
transpose_op
)
# If the sum is performed with keepdims=True
# If the sum is performed with keepdims=True
...
@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
...
@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
assert
np
.
allclose
(
naive_ret
,
rewritten_ret
)
assert
np
.
allclose
(
naive_ret
,
rewritten_ret
)
# If a transpose is applied
# If a transpose is applied
transpose_op
=
DimShuffle
(
(
False
,
False
),
(
1
,
0
))
transpose_op
=
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
f
=
compile_graph_log_sum_exp
(
x
,
axis
=
(
1
,),
dimshuffle_op
=
transpose_op
)
f
=
compile_graph_log_sum_exp
(
x
,
axis
=
(
1
,),
dimshuffle_op
=
transpose_op
)
naive_ret
=
np
.
log
(
np
.
sum
(
np
.
exp
(
x_val
),
axis
=
1
)
.
T
)
naive_ret
=
np
.
log
(
np
.
sum
(
np
.
exp
(
x_val
),
axis
=
1
)
.
T
)
rewritten_ret
=
f
(
x_val
)
rewritten_ret
=
f
(
x_val
)
...
...
tests/tensor/test_basic.py
浏览文件 @
e88117e6
...
@@ -3418,7 +3418,7 @@ def test_unalign():
...
@@ -3418,7 +3418,7 @@ def test_unalign():
def
test_dimshuffle_duplicate
():
def
test_dimshuffle_duplicate
():
x
=
vector
()
x
=
vector
()
with
pytest
.
raises
(
ValueError
,
match
=
"may not appear twice"
):
with
pytest
.
raises
(
ValueError
,
match
=
"may not appear twice"
):
DimShuffle
(
(
False
,),
(
0
,
0
))(
x
)
DimShuffle
(
input_ndim
=
1
,
new_order
=
(
0
,
0
))(
x
)
class
TestGetUnderlyingScalarConstantValue
:
class
TestGetUnderlyingScalarConstantValue
:
...
...
tests/tensor/test_blas.py
浏览文件 @
e88117e6
...
@@ -593,9 +593,9 @@ class TestAsScalar:
...
@@ -593,9 +593,9 @@ class TestAsScalar:
b
=
pt
.
constant
(
np
.
asarray
([[[
0.5
]]]))
b
=
pt
.
constant
(
np
.
asarray
([[[
0.5
]]]))
b2
=
b
.
dimshuffle
()
b2
=
b
.
dimshuffle
()
assert
b2
.
ndim
==
0
assert
b2
.
ndim
==
0
d_a
=
DimShuffle
(
[],
[])(
a
)
d_a
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[])(
a
)
d_b
=
DimShuffle
(
[
True
,
True
,
True
],
[
0
,
2
,
1
])(
b
)
d_b
=
DimShuffle
(
input_ndim
=
3
,
new_order
=
[
0
,
2
,
1
])(
b
)
d_a2
=
DimShuffle
(
[],
[
"x"
,
"x"
,
"x"
])(
a
)
d_a2
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[
"x"
,
"x"
,
"x"
])(
a
)
assert
_as_scalar
(
a
)
==
a
assert
_as_scalar
(
a
)
==
a
assert
_as_scalar
(
b
)
!=
b
assert
_as_scalar
(
b
)
!=
b
...
@@ -607,13 +607,13 @@ class TestAsScalar:
...
@@ -607,13 +607,13 @@ class TestAsScalar:
# Test that it fails on nonscalar constants
# Test that it fails on nonscalar constants
a
=
pt
.
constant
(
np
.
ones
(
5
))
a
=
pt
.
constant
(
np
.
ones
(
5
))
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
DimShuffle
(
[
False
],
[
0
,
"x"
])(
a
))
is
None
assert
_as_scalar
(
DimShuffle
(
input_ndim
=
1
,
new_order
=
[
0
,
"x"
])(
a
))
is
None
def
test_basic_2
(
self
):
def
test_basic_2
(
self
):
# Test that it works on scalar variables
# Test that it works on scalar variables
a
=
dscalar
()
a
=
dscalar
()
d_a
=
DimShuffle
(
[],
[])(
a
)
d_a
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[])(
a
)
d_a2
=
DimShuffle
(
[],
[
"x"
,
"x"
])(
a
)
d_a2
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[
"x"
,
"x"
])(
a
)
assert
_as_scalar
(
a
)
is
a
assert
_as_scalar
(
a
)
is
a
assert
_as_scalar
(
d_a
)
is
a
assert
_as_scalar
(
d_a
)
is
a
...
@@ -623,13 +623,15 @@ class TestAsScalar:
...
@@ -623,13 +623,15 @@ class TestAsScalar:
# Test that it fails on nonscalar variables
# Test that it fails on nonscalar variables
a
=
matrix
()
a
=
matrix
()
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
DimShuffle
(
[
False
,
False
],
[
0
,
"x"
,
1
])(
a
))
is
None
assert
_as_scalar
(
DimShuffle
(
input_ndim
=
2
,
new_order
=
[
0
,
"x"
,
1
])(
a
))
is
None
class
TestRealMatrix
:
class
TestRealMatrix
:
def
test_basic
(
self
):
def
test_basic
(
self
):
assert
_is_real_matrix
(
DimShuffle
([
False
,
False
],
[
1
,
0
])(
matrix
()))
assert
_is_real_matrix
(
DimShuffle
(
input_ndim
=
2
,
new_order
=
[
1
,
0
])(
matrix
()))
assert
not
_is_real_matrix
(
DimShuffle
([
False
],
[
"x"
,
0
])(
dvector
()))
assert
not
_is_real_matrix
(
DimShuffle
(
input_ndim
=
1
,
new_order
=
[
"x"
,
0
])(
dvector
())
)
"""
"""
...
...
tests/tensor/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((
1
,),
(
"x"
,
"x"
),
(
1
,
1
)),
((
1
,),
(
"x"
,
"x"
),
(
1
,
1
)),
]:
]:
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
ib
=
[
entry
==
1
for
entry
in
i_shape
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
e
=
self
.
op
(
i
b
,
shuffle
)(
x
)
e
=
self
.
op
(
i
nput_ndim
=
len
(
i_shape
),
new_order
=
shuffle
)(
x
)
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
# test that DimShuffle.infer_shape work correctly
# test that DimShuffle.infer_shape work correctly
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
e
=
self
.
op
(
i
b
,
shuffle
)(
x
)
e
=
self
.
op
(
i
nput_ndim
=
len
(
i_shape
),
new_order
=
shuffle
)(
x
)
f
=
pytensor
.
function
(
f
=
pytensor
.
function
(
[
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
),
on_unused_input
=
"ignore"
[
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
),
on_unused_input
=
"ignore"
)
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
# Test when we drop a axis that is not broadcastable
# Test when we drop a axis that is not broadcastable
ib
=
[
False
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
2
,
1
,
None
))(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
ValueError
):
self
.
op
(
input_ndim
=
3
,
new_order
=
shuffle
)(
x
)
self
.
op
(
ib
,
shuffle
)
# Test when we drop a axis that don't have shape 1
# Test when we drop a axis that don't have shape 1
ib
=
[
True
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
1
,
1
,
None
))(
"x"
)
e
=
self
.
op
(
input_ndim
=
3
,
new_order
=
(
1
,
2
))(
x
)
e
=
self
.
op
(
ib
,
(
1
,
2
))(
x
)
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
f
=
pytensor
.
function
([
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
))
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
TypeError
):
f
(
np
.
ones
((
2
,
1
,
4
),
dtype
=
self
.
dtype
))
f
(
np
.
ones
((
2
,
1
,
4
)))
# Test that we can't take a dimensions multiple time
# Test that we can't take a dimensions multiple time
xsh
,
shuffle
,
zsh
=
((
1
,
1
,
4
),
(
0
,
1
,
2
,
0
),
(
1
,
4
))
xsh
,
shuffle
,
zsh
=
((
1
,
1
,
4
),
(
0
,
1
,
2
,
0
),
(
1
,
4
))
ib
=
[
False
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
DimShuffle
(
i
b
,
shuffle
)
DimShuffle
(
i
nput_ndim
=
3
,
new_order
=
shuffle
)
def
test_perform
(
self
):
def
test_perform
(
self
):
self
.
with_linker
(
PerformLinker
())
self
.
with_linker
(
PerformLinker
())
def
test_c_or_py
(
self
):
def
test_c_or_py
(
self
):
# Shape op don't have C code.
# But This will test DimShuffle c code
self
.
with_linker
(
OpWiseCLinker
())
self
.
with_linker
(
OpWiseCLinker
())
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
...
@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((
1
,),
(
"x"
,
"x"
)),
((
1
,),
(
"x"
,
"x"
)),
]:
]:
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
ib
=
[(
entry
==
1
)
for
entry
in
xsh
]
adtens
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
adtens
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
adtens_val
=
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)
adtens_val
=
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
adtens
],
[
adtens
],
[
self
.
op
(
i
b
,
shuffle
)(
adtens
)],
[
self
.
op
(
i
nput_ndim
=
len
(
xsh
),
new_order
=
shuffle
)(
adtens
)],
[
adtens_val
],
[
adtens_val
],
self
.
op
,
self
.
op
,
warn
=
False
,
warn
=
False
,
...
@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y
=
x
.
dimshuffle
([
0
,
1
,
"x"
])
y
=
x
.
dimshuffle
([
0
,
1
,
"x"
])
assert
y
.
type
.
shape
==
(
1
,
2
,
1
)
assert
y
.
type
.
shape
==
(
1
,
2
,
1
)
def
test_valid_input_
broadcastable
(
self
):
def
test_valid_input_
ndim
(
self
):
assert
DimShuffle
(
[
True
,
False
],
(
1
,
0
))
.
input_broadcastable
==
(
True
,
False
)
assert
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
.
input_ndim
==
2
with
pytest
.
raises
(
ValueError
,
match
=
"input_broadcastable must be boolean
"
):
with
pytest
.
raises
(
TypeError
,
match
=
"input_ndim must be an integer
"
):
DimShuffle
(
[
None
,
None
],
(
1
,
0
))
DimShuffle
(
input_ndim
=
(
True
,
False
),
new_order
=
(
1
,
0
))
class
TestBroadcast
:
class
TestBroadcast
:
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
e88117e6
...
@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
...
@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
assert
f
([
0
])
==
0
assert
f
([
0
])
==
0
# Test that we cannot squeeze dimensions whose length is greater than 1
# Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1
=
re
.
escape
(
"SpecifyShape: Got shape (3,), expected (1,)."
)
error_txt_2
=
re
.
escape
(
"SpecifyShape: dim 0 of input has shape 3, expected 1"
)
match
=
error_txt_1
if
pytensor
.
config
.
mode
==
"FAST_COMPILE"
else
error_txt_2
with
pytest
.
raises
(
with
pytest
.
raises
(
Assertion
Error
,
Value
Error
,
match
=
match
,
match
=
"cannot reshape array of size 3 into shape ()"
,
):
):
f
([
0
,
1
,
2
])
f
([
0
,
1
,
2
])
...
...
tests/tensor/test_fft.py
浏览文件 @
e88117e6
...
@@ -204,3 +204,12 @@ class TestFFT:
...
@@ -204,3 +204,12 @@ class TestFFT:
pytensor
.
config
.
floatX
pytensor
.
config
.
floatX
)
)
utt
.
verify_grad
(
f_irfft
,
[
inputs_val
],
eps
=
eps
)
utt
.
verify_grad
(
f_irfft
,
[
inputs_val
],
eps
=
eps
)
def
test_rfft_expanded_dims_grad
(
self
):
# Regression test for https://github.com/pymc-devs/pytensor/issues/969
def
test_func
(
x
):
return
fft
.
rfft
(
x
[
None
,
:])
rng
=
np
.
random
.
default_rng
(
213
)
inputs_val
=
rng
.
random
((
N
,))
.
astype
(
pytensor
.
config
.
floatX
)
utt
.
verify_grad
(
test_func
,
[
inputs_val
],
rng
=
rng
)
tests/tensor/test_keepdims.py
浏览文件 @
e88117e6
...
@@ -4,7 +4,6 @@ import pytest
...
@@ -4,7 +4,6 @@ import pytest
import
pytensor
import
pytensor
from
pytensor
import
function
from
pytensor
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
any
as
pt_any
from
pytensor.tensor.math
import
any
as
pt_any
from
pytensor.tensor.math
import
argmax
,
argmin
,
max_and_argmax
,
mean
,
prod
,
std
,
var
from
pytensor.tensor.math
import
argmax
,
argmin
,
max_and_argmax
,
mean
,
prod
,
std
,
var
...
@@ -40,7 +39,7 @@ class TestKeepDims:
...
@@ -40,7 +39,7 @@ class TestKeepDims:
new_dims
.
append
(
i
)
new_dims
.
append
(
i
)
i
+=
1
i
+=
1
return
DimShuffle
(
y
.
type
.
broadcastable
,
new_dims
)(
y
)
return
y
.
dimshuffle
(
new_dims
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"axis"
,
"axis"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论