Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e88117e6
提交
e88117e6
authored
8月 22, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Only require input_ndim and not input_broadcastable in DimShuffle
上级
d68f53f8
隐藏空白字符变更
内嵌
并排
正在显示
24 个修改的文件
包含
132 行增加
和
181 行删除
+132
-181
sp.py
pytensor/sparse/sandbox/sp.py
+2
-3
basic.py
pytensor/tensor/basic.py
+2
-2
elemwise.py
pytensor/tensor/elemwise.py
+66
-95
extra_ops.py
pytensor/tensor/extra_ops.py
+1
-6
inplace.py
pytensor/tensor/inplace.py
+2
-2
math.py
pytensor/tensor/math.py
+1
-3
jax.py
pytensor/tensor/random/rewriting/jax.py
+1
-1
basic.py
pytensor/tensor/rewriting/basic.py
+1
-1
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+2
-4
jax.py
pytensor/tensor/rewriting/jax.py
+1
-1
linalg.py
pytensor/tensor/rewriting/linalg.py
+1
-1
shape.py
pytensor/tensor/rewriting/shape.py
+1
-5
variable.py
pytensor/tensor/variable.py
+2
-2
test_elemwise.py
tests/link/jax/test_elemwise.py
+1
-1
test_elemwise.py
tests/link/numba/test_elemwise.py
+4
-4
test_elemwise.py
tests/link/pytorch/test_elemwise.py
+0
-6
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1
-1
test_math.py
tests/tensor/rewriting/test_math.py
+3
-3
test_basic.py
tests/tensor/test_basic.py
+1
-1
test_blas.py
tests/tensor/test_blas.py
+11
-9
test_elemwise.py
tests/tensor/test_elemwise.py
+16
-23
test_extra_ops.py
tests/tensor/test_extra_ops.py
+2
-5
test_fft.py
tests/tensor/test_fft.py
+9
-0
test_keepdims.py
tests/tensor/test_keepdims.py
+1
-2
没有找到文件。
pytensor/sparse/sandbox/sp.py
浏览文件 @
e88117e6
...
@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
...
@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
from
pytensor.tensor.math
import
dot
from
pytensor.tensor.math
import
dot
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.subtensor
import
DimShuffle
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
...
@@ -375,7 +374,7 @@ def convolve(
...
@@ -375,7 +374,7 @@ def convolve(
[
images
.
shape
[
0
],
pt
.
as_tensor
(
np
.
prod
(
outshp
)),
pt
.
as_tensor
(
nkern
)]
[
images
.
shape
[
0
],
pt
.
as_tensor
(
np
.
prod
(
outshp
)),
pt
.
as_tensor
(
nkern
)]
)
)
tensout
=
reshape
(
output
,
newshp
,
ndim
=
3
)
tensout
=
reshape
(
output
,
newshp
,
ndim
=
3
)
output
=
DimShuffle
((
False
,)
*
tensout
.
ndim
,
(
0
,
2
,
1
))(
tensout
)
output
=
tensout
.
transpose
(
0
,
2
,
1
)
if
flatten
:
if
flatten
:
output
=
pt
.
flatten
(
output
,
2
)
output
=
pt
.
flatten
(
output
,
2
)
...
@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
...
@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
)
)
out2
=
reshape
(
out1
,
pshape
,
ndim
=
3
)
out2
=
reshape
(
out1
,
pshape
,
ndim
=
3
)
out3
=
DimShuffle
(
out2
.
broadcastable
,
(
0
,
2
,
1
))(
out2
)
out3
=
out2
.
transpose
(
0
,
2
,
1
)
return
pt
.
flatten
(
out3
,
2
),
outshp
return
pt
.
flatten
(
out3
,
2
),
outshp
pytensor/tensor/basic.py
浏览文件 @
e88117e6
...
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
...
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op
# No-op
return
_x
return
_x
ret
=
DimShuffle
(
tuple
(
s
==
1
for
s
in
_x
.
type
.
shape
),
axes
)(
_x
)
ret
=
_x
.
dimshuffle
(
axes
)
if
_x
.
name
and
axes
==
tuple
(
range
((
_x
.
type
.
ndim
-
1
),
-
1
,
-
1
)):
if
_x
.
name
and
axes
==
tuple
(
range
((
_x
.
type
.
ndim
-
1
),
-
1
,
-
1
)):
ret
.
name
=
_x
.
name
+
".T"
ret
.
name
=
_x
.
name
+
".T"
...
@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
...
@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
newdims
.
append
(
i
)
newdims
.
append
(
i
)
i
+=
1
i
+=
1
gx
=
DimShuffle
(
tuple
(
s
==
1
for
s
in
gx
.
type
.
shape
),
newdims
)(
gx
)
gx
=
gx
.
dimshuffle
(
newdims
)
assert
gx
.
type
.
ndim
==
x
.
type
.
ndim
assert
gx
.
type
.
ndim
==
x
.
type
.
ndim
assert
all
(
assert
all
(
s1
==
s2
s1
==
s2
...
...
pytensor/tensor/elemwise.py
浏览文件 @
e88117e6
from
collections.abc
import
Sequence
from
copy
import
copy
from
copy
import
copy
from
textwrap
import
dedent
from
textwrap
import
dedent
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
from
numpy.core.numeric
import
normalize_axis_tuple
from
numpy.core.numeric
import
normalize_axis_tuple
...
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
...
@@ -54,15 +56,14 @@ class DimShuffle(ExternalCOp):
Parameters
Parameters
----------
----------
input_
broadcastable
input_
ndim
The expected
broadcastable patter
n of the input
The expected number of dimensions of the input
new_order
new_order
A list representing the relationship between the input's
A list representing the relationship between the input's
dimensions and the output's dimensions. Each element of the
dimensions and the output's dimensions. Each element of the
list can either be an index or 'x'. Indices must be encoded
list can either be an index or 'x'. Indices must be encoded
as python integers, not pytensor symbolic integers.
as python integers, not pytensor symbolic integers.
inplace : bool, optional
Missing indexes correspond to drop dimensions.
If True (default), the output will be a view of the input.
Notes
Notes
-----
-----
...
@@ -77,10 +78,10 @@ class DimShuffle(ExternalCOp):
...
@@ -77,10 +78,10 @@ class DimShuffle(ExternalCOp):
.. code-block:: python
.. code-block:: python
DimShuffle(
(False, False, False),
["x", 2, "x", 0, 1])
DimShuffle(
input_ndim=3, new_order=
["x", 2, "x", 0, 1])
This `Op` will only work on 3d tensors
with no broadcastable
This `Op` will only work on 3d tensors
.
dimensions. The first dimension
will be broadcastable,
The first dimension of the output
will be broadcastable,
then we will have the third dimension of the input tensor as
then we will have the third dimension of the input tensor as
the second of the resulting tensor, etc. If the tensor has
the second of the resulting tensor, etc. If the tensor has
shape (20, 30, 40), the resulting tensor will have dimensions
shape (20, 30, 40), the resulting tensor will have dimensions
...
@@ -88,39 +89,36 @@ class DimShuffle(ExternalCOp):
...
@@ -88,39 +89,36 @@ class DimShuffle(ExternalCOp):
.. code-block:: python
.. code-block:: python
DimShuffle(
(True, False),
[1])
DimShuffle(
input_ndim=2, new_order=
[1])
This `Op` will only work on 2d tensors with the first dimension
This `Op` will only work on 2d tensors with the first dimension broadcastable.
broadcastable.
The second dimension of the input tensor will be the first dimension of the resulting tensor.
The second dimension of the input tensor will be the first dimension of
If the tensor has shape (1, 20), the resulting tensor will have shape (20, ).
the resulting tensor.
If the tensor has shape (1, 20), the resulting tensor will have shape
(20, ).
Examples
Examples
--------
--------
.. code-block:: python
.. code-block:: python
DimShuffle((), ["x"]) # make a 0d (scalar) into a 1d vector
DimShuffle(input_ndim=0, new_order=["x"]) # make a 0d (scalar) into a 1d vector
DimShuffle((False, False), [0, 1]) # identity
DimShuffle(input_ndim=2, new_order=[0, 1]) # identity
DimShuffle((False, False), [1, 0]) # inverts the 1st and 2nd dimensions
DimShuffle(input_ndim=2, new_order=[1, 0]) # transposition
DimShuffle((False,), ["x", 0]) # make a row out of a 1d vector
# Make a row out of a 1d vector (N to 1xN)
# (N to 1xN)
DimShuffle(input_ndim=1, new_order=["x", 0])
DimShuffle((False,), [0, "x"]) # make a column out of a 1d vector
# Make a column out of a 1d vector (N to Nx1)
# (N to Nx1)
DimShuffle(input_ndim=1, new_order=[0, "x"])
DimShuffle((False, False, False), [2, 0, 1]) # AxBxC to CxAxB
DimShuffle(input_ndim=3, new_order=[2, 0, 1]) # AxBxC to CxAxB
DimShuffle((False, False), [0, "x", 1]) # AxB to Ax1xB
DimShuffle(input_ndim=2, new_order=[0, "x", 1]) # AxB to Ax1xB
DimShuffle((False, False), [1, "x", 0]) # AxB to Bx1xA
DimShuffle(input_ndim=2, new_order=[1, "x", 0]) # AxB to Bx1xA
The reordering of the dimensions can be done with the numpy.transpose
function.
Adding, subtracting dimensions can be done with reshape.
Notes
-----
The python implementation of this Op combines numpy.transpose for reordering of the dimensions
and numpy.reshape for subtracting and adding broadcastable dimensions.
"""
"""
_f16_ok
=
True
_f16_ok
=
True
check_input
=
False
check_input
=
False
__props__
=
(
"input_
broadcastable
"
,
"new_order"
,
"inplace"
)
__props__
=
(
"input_
ndim
"
,
"new_order"
,
"inplace"
)
c_func_file
=
"c_code/dimshuffle.c"
c_func_file
=
"c_code/dimshuffle.c"
c_func_name
=
"APPLY_SPECIFIC(cpu_dimshuffle)"
c_func_name
=
"APPLY_SPECIFIC(cpu_dimshuffle)"
...
@@ -133,16 +131,14 @@ class DimShuffle(ExternalCOp):
...
@@ -133,16 +131,14 @@ class DimShuffle(ExternalCOp):
inplace
=
scalar_bool
,
inplace
=
scalar_bool
,
)
)
def
__init__
(
self
,
input_broadcastable
,
new_order
):
def
__init__
(
self
,
*
,
input_ndim
:
int
,
new_order
:
Sequence
[
int
|
Literal
[
"x"
]]
):
super
()
.
__init__
([
self
.
c_func_file
],
self
.
c_func_name
)
super
()
.
__init__
([
self
.
c_func_file
],
self
.
c_func_name
)
self
.
input_broadcastable
=
tuple
(
input_broadcastable
)
if
not
isinstance
(
input_ndim
,
int
):
if
not
all
(
isinstance
(
bs
,
bool
|
np
.
bool_
)
for
bs
in
self
.
input_broadcastable
):
raise
TypeError
(
f
"input_ndim must be an integer, got {type(int)}"
)
raise
ValueError
(
f
"input_broadcastable must be boolean, {self.input_broadcastable}"
)
self
.
new_order
=
tuple
(
new_order
)
self
.
input_ndim
=
input_ndim
self
.
new_order
=
tuple
(
new_order
)
self
.
inplace
=
True
self
.
inplace
=
True
for
i
,
j
in
enumerate
(
new_order
):
for
i
,
j
in
enumerate
(
new_order
):
...
@@ -152,10 +148,10 @@ class DimShuffle(ExternalCOp):
...
@@ -152,10 +148,10 @@ class DimShuffle(ExternalCOp):
"DimShuffle indices must be Python ints; got "
"DimShuffle indices must be Python ints; got "
f
"{j} of type {type(j)}."
f
"{j} of type {type(j)}."
)
)
if
j
>=
len
(
input_broadcastable
)
:
if
j
>=
input_ndim
:
raise
ValueError
(
raise
ValueError
(
f
"new_order[{i}] is {j}, but the input only has "
f
"new_order[{i}] is {j}, but the input only has "
f
"{
len(input_broadcastable)
} axes."
f
"{
input_ndim
} axes."
)
)
if
j
in
new_order
[(
i
+
1
)
:]:
if
j
in
new_order
[(
i
+
1
)
:]:
raise
ValueError
(
raise
ValueError
(
...
@@ -164,19 +160,7 @@ class DimShuffle(ExternalCOp):
...
@@ -164,19 +160,7 @@ class DimShuffle(ExternalCOp):
)
)
# List of input dimensions to drop
# List of input dimensions to drop
drop
=
[]
drop
=
[
i
for
i
in
range
(
input_ndim
)
if
i
not
in
new_order
]
for
i
,
b
in
enumerate
(
input_broadcastable
):
if
i
not
in
new_order
:
# We want to drop this dimension because it's not a value in
# `new_order`
if
b
==
1
:
drop
.
append
(
i
)
else
:
# We cannot drop non-broadcastable dimensions
raise
ValueError
(
"Cannot drop a non-broadcastable dimension: "
f
"{input_broadcastable}, {new_order}"
)
# This is the list of the original dimensions that we keep
# This is the list of the original dimensions that we keep
self
.
shuffle
=
[
x
for
x
in
new_order
if
x
!=
"x"
]
self
.
shuffle
=
[
x
for
x
in
new_order
if
x
!=
"x"
]
...
@@ -186,7 +170,6 @@ class DimShuffle(ExternalCOp):
...
@@ -186,7 +170,6 @@ class DimShuffle(ExternalCOp):
self
.
augment
=
sorted
(
i
for
i
,
x
in
enumerate
(
new_order
)
if
x
==
"x"
)
self
.
augment
=
sorted
(
i
for
i
,
x
in
enumerate
(
new_order
)
if
x
==
"x"
)
self
.
drop
=
drop
self
.
drop
=
drop
input_ndim
=
len
(
input_broadcastable
)
self
.
is_left_expand_dims
=
self
.
augment
and
(
self
.
is_left_expand_dims
=
self
.
augment
and
(
input_ndim
==
0
or
new_order
[
-
input_ndim
:]
==
list
(
range
(
input_ndim
))
input_ndim
==
0
or
new_order
[
-
input_ndim
:]
==
list
(
range
(
input_ndim
))
)
)
...
@@ -204,30 +187,29 @@ class DimShuffle(ExternalCOp):
...
@@ -204,30 +187,29 @@ class DimShuffle(ExternalCOp):
# Let's just build the ExternalCOp.
# Let's just build the ExternalCOp.
super
()
.
__init__
([
self
.
c_func_file
],
self
.
c_func_name
)
super
()
.
__init__
([
self
.
c_func_file
],
self
.
c_func_name
)
def
make_node
(
self
,
_input
):
def
make_node
(
self
,
inp
):
input
=
as_tensor_variable
(
_input
)
input
=
as_tensor_variable
(
inp
)
ib
=
tuple
(
s
==
1
for
s
in
input
.
type
.
shape
)
if
input
.
type
.
ndim
!=
self
.
input_ndim
:
if
ib
!=
self
.
input_broadcastable
:
raise
TypeError
(
if
len
(
ib
)
!=
len
(
self
.
input_broadcastable
):
"The number of dimensions of the input is incorrect for this op. "
f
"Expected {self.input_ndim}, got {input.type.ndim}."
)
input_static_shape
=
input
.
type
.
shape
# Runtime check for invalid drop
for
d
in
self
.
drop
:
if
input_static_shape
[
d
]
not
in
(
1
,
None
):
raise
TypeError
(
raise
TypeError
(
"The number of dimensions of the "
f
"Input dropped dimension {d} must have length 1 but has {input_static_shape[d]}"
f
"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
)
)
for
expected
,
b
in
zip
(
self
.
input_broadcastable
,
ib
):
if
expected
and
not
b
:
raise
TypeError
(
"The broadcastable pattern of the "
f
"input is incorrect for this op. Expected {self.input_broadcastable}, got {ib}."
)
# else, expected == b or not expected and b
# Both case are good.
out_static_shape
=
[]
out_static_shape
=
[]
for
dim_idx
in
self
.
new_order
:
for
dim_idx
in
self
.
new_order
:
if
dim_idx
==
"x"
:
if
dim_idx
==
"x"
:
out_static_shape
.
append
(
1
)
out_static_shape
.
append
(
1
)
else
:
else
:
out_static_shape
.
append
(
input
.
type
.
shape
[
dim_idx
])
out_static_shape
.
append
(
input
_static_
shape
[
dim_idx
])
output
=
TensorType
(
dtype
=
input
.
type
.
dtype
,
shape
=
out_static_shape
)()
output
=
TensorType
(
dtype
=
input
.
type
.
dtype
,
shape
=
out_static_shape
)()
...
@@ -254,12 +236,14 @@ class DimShuffle(ExternalCOp):
...
@@ -254,12 +236,14 @@ class DimShuffle(ExternalCOp):
if
not
isinstance
(
res
,
np
.
ndarray
|
np
.
memmap
):
if
not
isinstance
(
res
,
np
.
ndarray
|
np
.
memmap
):
raise
TypeError
(
res
)
raise
TypeError
(
res
)
# Put dropped axis at end
res
=
res
.
transpose
(
self
.
transposition
)
res
=
res
.
transpose
(
self
.
transposition
)
shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
# Define new shape without dropped axis and including new ones
new_shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
for
augm
in
self
.
augment
:
for
augm
in
self
.
augment
:
shape
.
insert
(
augm
,
1
)
new_
shape
.
insert
(
augm
,
1
)
res
=
res
.
reshape
(
shape
)
res
=
res
.
reshape
(
new_
shape
)
if
not
self
.
inplace
:
if
not
self
.
inplace
:
res
=
np
.
copy
(
res
)
res
=
np
.
copy
(
res
)
...
@@ -284,22 +268,15 @@ class DimShuffle(ExternalCOp):
...
@@ -284,22 +268,15 @@ class DimShuffle(ExternalCOp):
def
grad
(
self
,
inp
,
grads
):
def
grad
(
self
,
inp
,
grads
):
(
x
,)
=
inp
(
x
,)
=
inp
(
gz
,)
=
grads
(
gz
,)
=
grads
gz
=
as_tensor_variable
(
gz
)
grad_order
=
[
"x"
]
*
x
.
type
.
ndim
grad_order
=
[
"x"
]
*
x
.
type
.
ndim
for
i
,
v
in
enumerate
(
self
.
new_order
):
for
i
,
v
in
enumerate
(
self
.
new_order
):
if
v
!=
"x"
:
if
v
!=
"x"
:
grad_order
[
v
]
=
i
grad_order
[
v
]
=
i
# Do not make the DimShuffle inplace as an optimization at the
# canonicalization optimization phase will remove the inplace.
if
x
.
type
.
dtype
in
discrete_dtypes
:
# The inplace will be reintroduced automatically later in the graph.
return
[
x
.
zeros_like
(
dtype
=
config
.
floatX
)]
if
inp
[
0
]
.
dtype
in
discrete_dtypes
:
return
[
inp
[
0
]
.
zeros_like
(
dtype
=
config
.
floatX
)]
else
:
else
:
return
[
return
[
gz
.
dimshuffle
(
grad_order
)]
DimShuffle
(
tuple
(
s
==
1
for
s
in
gz
.
type
.
shape
),
grad_order
)(
Elemwise
(
scalar_identity
)(
gz
)
)
]
class
DimShufflePrinter
(
Printer
):
class
DimShufflePrinter
(
Printer
):
...
@@ -409,7 +386,7 @@ class Elemwise(OpenMPOp):
...
@@ -409,7 +386,7 @@ class Elemwise(OpenMPOp):
self
.
nfunc
=
None
self
.
nfunc
=
None
self
.
inplace_pattern
=
frozendict
(
self
.
inplace_pattern
)
self
.
inplace_pattern
=
frozendict
(
self
.
inplace_pattern
)
def
get_output_info
(
self
,
dim_shuffle
,
*
inputs
):
def
get_output_info
(
self
,
*
inputs
):
"""Return the outputs dtype and broadcastable pattern and the
"""Return the outputs dtype and broadcastable pattern and the
dimshuffled inputs.
dimshuffled inputs.
...
@@ -427,12 +404,7 @@ class Elemwise(OpenMPOp):
...
@@ -427,12 +404,7 @@ class Elemwise(OpenMPOp):
if
not
difference
:
if
not
difference
:
args
.
append
(
input
)
args
.
append
(
input
)
else
:
else
:
args
.
append
(
args
.
append
(
input
.
dimshuffle
([
"x"
]
*
difference
+
list
(
range
(
length
))))
dim_shuffle
(
input
.
type
.
broadcastable
,
[
"x"
]
*
difference
+
list
(
range
(
length
)),
)(
input
)
)
inputs
=
args
inputs
=
args
# HERE: all the broadcast dims have the same length now
# HERE: all the broadcast dims have the same length now
...
@@ -489,7 +461,7 @@ class Elemwise(OpenMPOp):
...
@@ -489,7 +461,7 @@ class Elemwise(OpenMPOp):
using DimShuffle.
using DimShuffle.
"""
"""
inputs
=
[
as_tensor_variable
(
i
)
for
i
in
inputs
]
inputs
=
[
as_tensor_variable
(
i
)
for
i
in
inputs
]
out_dtypes
,
out_shapes
,
inputs
=
self
.
get_output_info
(
DimShuffle
,
*
inputs
)
out_dtypes
,
out_shapes
,
inputs
=
self
.
get_output_info
(
*
inputs
)
outputs
=
[
outputs
=
[
TensorType
(
dtype
=
dtype
,
shape
=
shape
)()
TensorType
(
dtype
=
dtype
,
shape
=
shape
)()
for
dtype
,
shape
in
zip
(
out_dtypes
,
out_shapes
)
for
dtype
,
shape
in
zip
(
out_dtypes
,
out_shapes
)
...
@@ -634,7 +606,7 @@ class Elemwise(OpenMPOp):
...
@@ -634,7 +606,7 @@ class Elemwise(OpenMPOp):
res
=
pytensor
.
tensor
.
basic
.
constant
(
res
=
pytensor
.
tensor
.
basic
.
constant
(
np
.
asarray
(
r
.
data
),
dtype
=
r
.
type
.
dtype
np
.
asarray
(
r
.
data
),
dtype
=
r
.
type
.
dtype
)
)
return
DimShuffle
((),
[
"x"
]
*
nd
)(
res
)
return
res
.
dimshuffle
([
"x"
]
*
nd
)
new_r
=
Elemwise
(
node
.
op
,
{})(
*
[
transform
(
ipt
)
for
ipt
in
node
.
inputs
])
new_r
=
Elemwise
(
node
.
op
,
{})(
*
[
transform
(
ipt
)
for
ipt
in
node
.
inputs
])
if
isinstance
(
new_r
,
list
|
tuple
):
if
isinstance
(
new_r
,
list
|
tuple
):
...
@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
...
@@ -1707,13 +1679,12 @@ def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Appl
batched_ndims
=
x
.
type
.
ndim
-
node
.
inputs
[
0
]
.
type
.
ndim
batched_ndims
=
x
.
type
.
ndim
-
node
.
inputs
[
0
]
.
type
.
ndim
if
not
batched_ndims
:
if
not
batched_ndims
:
return
node
.
op
.
make_node
(
x
)
return
node
.
op
.
make_node
(
x
)
input_broadcastable
=
x
.
type
.
broadcastable
[:
batched_ndims
]
+
op
.
input_broadcastable
# e.g., ds(input_ndim=2, order=(1, "x", 0)) -> ds(input_ndim=4, order=(0, 1, 3, "x", 2))
# e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
# e.g., ds(input_ndim=2, order=(1, "x")) -> ds(input_ndim=4, order=(0, 1, 3, "x"))
# e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
new_order
=
list
(
range
(
batched_ndims
))
+
[
new_order
=
list
(
range
(
batched_ndims
))
+
[
"x"
if
(
o
==
"x"
)
else
(
o
+
batched_ndims
)
for
o
in
op
.
new_order
"x"
if
(
o
==
"x"
)
else
(
o
+
batched_ndims
)
for
o
in
op
.
new_order
]
]
return
DimShuffle
(
input_broadcastable
,
new_order
)
.
make_node
(
x
)
return
x
.
dimshuffle
(
new_order
)
.
owner
def
get_normalized_batch_axes
(
def
get_normalized_batch_axes
(
...
...
pytensor/tensor/extra_ops.py
浏览文件 @
e88117e6
...
@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
...
@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
)
)
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.shape
import
Shape_i
,
specify_broadcastable
from
pytensor.tensor.shape
import
Shape_i
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
...
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
# Nothing could be squeezed
return
_x
return
_x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis
=
[
i
for
i
in
axis
if
not
_x
.
broadcastable
[
i
]]
_x
=
specify_broadcastable
(
_x
,
*
non_broadcastable_axis
)
return
_x
.
dimshuffle
([
i
for
i
in
range
(
_x
.
ndim
)
if
i
not
in
axis
])
return
_x
.
dimshuffle
([
i
for
i
in
range
(
_x
.
ndim
)
if
i
not
in
axis
])
...
...
pytensor/tensor/inplace.py
浏览文件 @
e88117e6
from
pytensor
import
printing
from
pytensor
import
printing
from
pytensor.printing
import
pprint
from
pytensor.printing
import
pprint
from
pytensor.tensor.elemwise
import
DimShuffle
,
scalar_elemwise
from
pytensor.tensor.elemwise
import
scalar_elemwise
@scalar_elemwise
@scalar_elemwise
...
@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
...
@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def
transpose_inplace
(
x
,
**
kwargs
):
def
transpose_inplace
(
x
,
**
kwargs
):
"Perform a transpose on a tensor without copying the underlying storage"
"Perform a transpose on a tensor without copying the underlying storage"
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
return
DimShuffle
(
x
.
broadcastable
,
dims
)(
x
)
return
x
.
dimshuffle
(
dims
)
pytensor/tensor/math.py
浏览文件 @
e88117e6
...
@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
...
@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.elemwise
import
(
from
pytensor.tensor.elemwise
import
(
CAReduce
,
CAReduce
,
DimShuffle
,
Elemwise
,
Elemwise
,
get_normalized_batch_axes
,
get_normalized_batch_axes
,
scalar_elemwise
,
scalar_elemwise
,
...
@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
...
@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
else
:
else
:
new_dims
.
append
(
i
)
new_dims
.
append
(
i
)
i
+=
1
i
+=
1
ds_op
=
DimShuffle
(
gz
.
type
.
broadcastable
,
new_dims
)
gx
=
Elemwise
(
ps
.
second
)(
x
,
gz
.
dimshuffle
(
new_dims
))
gx
=
Elemwise
(
ps
.
second
)(
x
,
ds_op
(
gz
))
return
[
gx
]
return
[
gx
]
def
R_op
(
self
,
inputs
,
eval_points
):
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
e88117e6
...
@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
...
@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
if
isinstance
(
size_node
.
op
,
MakeVector
)
or
(
if
isinstance
(
size_node
.
op
,
MakeVector
)
or
(
isinstance
(
size_node
.
op
,
DimShuffle
)
isinstance
(
size_node
.
op
,
DimShuffle
)
and
size_node
.
op
.
input_
broadcastable
==
()
and
size_node
.
op
.
input_
ndim
==
0
and
size_node
.
op
.
new_order
==
(
"x"
,)
and
size_node
.
op
.
new_order
==
(
"x"
,)
):
):
# Here PyTensor converted a tuple or list to a tensor
# Here PyTensor converted a tuple or list to a tensor
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
e88117e6
...
@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
...
@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
dimshuffle_new_order
=
[
"x"
]
*
num_dims_with_size_1_added_to_left
+
list
(
dimshuffle_new_order
=
[
"x"
]
*
num_dims_with_size_1_added_to_left
+
list
(
range
(
len
(
new_output_shape
))
range
(
len
(
new_output_shape
))
)
)
return
[
DimShuffle
(
inner
.
type
.
broadcastable
,
dimshuffle_new_order
)(
inn
er
)]
return
[
inner
.
dimshuffle
(
dimshuffle_new_ord
er
)]
@node_rewriter
([
AllocEmpty
])
@node_rewriter
([
AllocEmpty
])
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
e88117e6
...
@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
"""
"""
op
=
node
.
op
op
=
node
.
op
if
not
isinstance
(
op
,
DimShuffle
):
return
False
inp
=
node
.
inputs
[
0
]
inp
=
node
.
inputs
[
0
]
inode
=
inp
.
owner
inode
=
inp
.
owner
...
@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
# Don't use make_node to have tag.test_value set.
# Don't use make_node to have tag.test_value set.
new_inputs
=
[]
new_inputs
=
[]
for
inp
in
inode
.
inputs
:
for
inp
in
inode
.
inputs
:
new_inp
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
op
.
new_order
)(
inp
)
new_inp
=
inp
.
dimshuffle
(
op
.
new_order
)
new_inputs
.
append
(
apply_local_dimshuffle_lift
(
fgraph
,
new_inp
))
new_inputs
.
append
(
apply_local_dimshuffle_lift
(
fgraph
,
new_inp
))
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
inode
.
op
(
*
new_inputs
,
return_list
=
True
)
ret
=
inode
.
op
(
*
new_inputs
,
return_list
=
True
)
...
@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
...
@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
if
is_dimshuffle_useless
(
new_order
,
inp
):
if
is_dimshuffle_useless
(
new_order
,
inp
):
return
[
inp
]
return
[
inp
]
elif
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
elif
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
ret
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
new_order
)(
inp
)
ret
=
inp
.
dimshuffle
(
new_order
)
ret
=
apply_local_dimshuffle_lift
(
fgraph
,
ret
)
ret
=
apply_local_dimshuffle_lift
(
fgraph
,
ret
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
return
[
ret
]
...
...
pytensor/tensor/rewriting/jax.py
浏览文件 @
e88117e6
...
@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
...
@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
if
isinstance
(
shape_node
.
op
,
MakeVector
)
or
(
if
isinstance
(
shape_node
.
op
,
MakeVector
)
or
(
isinstance
(
shape_node
.
op
,
DimShuffle
)
isinstance
(
shape_node
.
op
,
DimShuffle
)
and
shape_node
.
op
.
input_
broadcastable
==
()
and
shape_node
.
op
.
input_
ndim
==
0
and
shape_node
.
op
.
new_order
==
(
"x"
,)
and
shape_node
.
op
.
new_order
==
(
"x"
,)
):
):
# Here PyTensor converted a tuple or list to a tensor
# Here PyTensor converted a tuple or list to a tensor
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
e88117e6
...
@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
...
@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if
ndims
<
2
:
if
ndims
<
2
:
return
False
return
False
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
return
cast
(
bool
,
node
.
op
.
new_order
==
transpose_order
)
return
node
.
op
.
new_order
==
transpose_order
return
False
return
False
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
e88117e6
...
@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
...
@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if
index
!=
output
.
type
.
ndim
:
if
index
!=
output
.
type
.
ndim
:
inner
=
op
.
__class__
(
len
(
new_output_shape
))(
inp
,
new_output_shape
)
inner
=
op
.
__class__
(
len
(
new_output_shape
))(
inp
,
new_output_shape
)
copy_stack_trace
(
output
,
inner
)
copy_stack_trace
(
output
,
inner
)
new_node
=
[
new_node
=
[
inner
.
dimshuffle
(
dimshuffle_new_order
)]
DimShuffle
(
tuple
(
s
==
1
for
s
in
inner
.
type
.
shape
),
dimshuffle_new_order
)(
inner
)
]
copy_stack_trace
(
output
,
new_node
)
copy_stack_trace
(
output
,
new_node
)
return
new_node
return
new_node
...
...
pytensor/tensor/variable.py
浏览文件 @
e88117e6
...
@@ -344,8 +344,8 @@ class _tensor_py_operators:
...
@@ -344,8 +344,8 @@ class _tensor_py_operators:
"""
"""
if
(
len
(
pattern
)
==
1
)
and
(
isinstance
(
pattern
[
0
],
list
|
tuple
)):
if
(
len
(
pattern
)
==
1
)
and
(
isinstance
(
pattern
[
0
],
list
|
tuple
)):
pattern
=
pattern
[
0
]
pattern
=
pattern
[
0
]
op
=
pt
.
elemwise
.
DimShuffle
(
list
(
self
.
type
.
broadcastable
),
pattern
)
ds_op
=
pt
.
elemwise
.
DimShuffle
(
input_ndim
=
self
.
type
.
ndim
,
new_order
=
pattern
)
return
op
(
self
)
return
ds_
op
(
self
)
def
flatten
(
self
,
ndim
=
1
):
def
flatten
(
self
,
ndim
=
1
):
return
pt
.
basic
.
flatten
(
self
,
ndim
)
return
pt
.
basic
.
flatten
(
self
,
ndim
)
...
...
tests/link/jax/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
...
@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
(
[
False
,
True
],
(
0
,))(
a_pt
)
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
...
...
tests/link/numba/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
...
@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from
pytensor.gradient
import
grad
from
pytensor.gradient
import
grad
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
elemwise
as
pt_elemwis
e
from
pytensor.tensor
.elemwise
import
DimShuffl
e
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
...
@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
...
@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
],
],
)
)
def
test_Dimshuffle
(
v
,
new_order
):
def
test_Dimshuffle
(
v
,
new_order
):
g
=
pt_elemwise
.
DimShuffle
(
v
.
broadcastable
,
new_order
)(
v
)
g
=
v
.
dimshuffle
(
new_order
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
g_fg
,
...
@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
...
@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
def
test_Dimshuffle_returns_array
():
def
test_Dimshuffle_returns_array
():
x
=
pt
.
vector
(
"x"
,
shape
=
(
1
,))
x
=
pt
.
vector
(
"x"
,
shape
=
(
1
,))
y
=
2
*
pt_elemwise
.
DimShuffle
([
True
],
[])(
x
)
y
=
2
*
x
.
dimshuffle
([]
)
func
=
pytensor
.
function
([
x
],
y
,
mode
=
"NUMBA"
)
func
=
pytensor
.
function
([
x
],
y
,
mode
=
"NUMBA"
)
out
=
func
(
np
.
zeros
(
1
,
dtype
=
config
.
floatX
))
out
=
func
(
np
.
zeros
(
1
,
dtype
=
config
.
floatX
))
assert
out
.
ndim
==
0
assert
out
.
ndim
==
0
...
@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
...
@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
non-contiguous arrays, make sure we work around thpt."""
non-contiguous arrays, make sure we work around thpt."""
x
=
pt
.
dvector
()
x
=
pt
.
dvector
()
idx
=
pt
.
vector
(
dtype
=
"int64"
)
idx
=
pt
.
vector
(
dtype
=
"int64"
)
op
=
pytensor
.
tensor
.
elemwise
.
DimShuffle
([
True
],
[])
op
=
DimShuffle
(
input_ndim
=
1
,
new_order
=
[])
out
=
op
(
pt
.
specify_shape
(
x
[
idx
][::
2
],
(
1
,)))
out
=
op
(
pt
.
specify_shape
(
x
[
idx
][::
2
],
(
1
,)))
func
=
pytensor
.
function
([
x
,
idx
],
out
,
mode
=
"NUMBA"
)
func
=
pytensor
.
function
([
x
,
idx
],
out
,
mode
=
"NUMBA"
)
assert
func
(
np
.
zeros
(
3
),
np
.
array
([
1
]))
.
ndim
==
0
assert
func
(
np
.
zeros
(
3
),
np
.
array
([
1
]))
.
ndim
==
0
...
...
tests/link/pytorch/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -5,7 +5,6 @@ import pytensor.tensor as pt
...
@@ -5,7 +5,6 @@ import pytensor.tensor as pt
import
pytensor.tensor.math
as
ptm
import
pytensor.tensor.math
as
ptm
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
from
pytensor.tensor.type
import
matrix
,
tensor
,
tensor3
,
vector
from
pytensor.tensor.type
import
matrix
,
tensor
,
tensor3
,
vector
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
...
@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
([
False
,
True
],
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
def
test_multiple_input_output
():
def
test_multiple_input_output
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
...
@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
def
ds
(
x
,
y
):
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
return
x
.
dimshuffle
(
y
)
def
inputs
(
xbc
=
(
0
,
0
),
ybc
=
(
0
,
0
),
zbc
=
(
0
,
0
)):
def
inputs
(
xbc
=
(
0
,
0
),
ybc
=
(
0
,
0
),
zbc
=
(
0
,
0
)):
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
e88117e6
...
@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
...
@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
def
ds
(
x
,
y
):
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
return
x
.
dimshuffle
(
y
)
def
rewrite
(
g
,
level
=
"fast_run"
):
def
rewrite
(
g
,
level
=
"fast_run"
):
...
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
...
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
check_max_log_sum_exp
(
x
,
axis
=
(
0
,
1
,
2
),
dimshuffle_op
=
None
)
check_max_log_sum_exp
(
x
,
axis
=
(
0
,
1
,
2
),
dimshuffle_op
=
None
)
# If a transpose is applied to the sum
# If a transpose is applied to the sum
transpose_op
=
DimShuffle
(
(
False
,
False
),
(
1
,
0
))
transpose_op
=
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
check_max_log_sum_exp
(
x
,
axis
=
2
,
dimshuffle_op
=
transpose_op
)
check_max_log_sum_exp
(
x
,
axis
=
2
,
dimshuffle_op
=
transpose_op
)
# If the sum is performed with keepdims=True
# If the sum is performed with keepdims=True
...
@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
...
@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
assert
np
.
allclose
(
naive_ret
,
rewritten_ret
)
assert
np
.
allclose
(
naive_ret
,
rewritten_ret
)
# If a transpose is applied
# If a transpose is applied
transpose_op
=
DimShuffle
(
(
False
,
False
),
(
1
,
0
))
transpose_op
=
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
f
=
compile_graph_log_sum_exp
(
x
,
axis
=
(
1
,),
dimshuffle_op
=
transpose_op
)
f
=
compile_graph_log_sum_exp
(
x
,
axis
=
(
1
,),
dimshuffle_op
=
transpose_op
)
naive_ret
=
np
.
log
(
np
.
sum
(
np
.
exp
(
x_val
),
axis
=
1
)
.
T
)
naive_ret
=
np
.
log
(
np
.
sum
(
np
.
exp
(
x_val
),
axis
=
1
)
.
T
)
rewritten_ret
=
f
(
x_val
)
rewritten_ret
=
f
(
x_val
)
...
...
tests/tensor/test_basic.py
浏览文件 @
e88117e6
...
@@ -3418,7 +3418,7 @@ def test_unalign():
...
@@ -3418,7 +3418,7 @@ def test_unalign():
def
test_dimshuffle_duplicate
():
def
test_dimshuffle_duplicate
():
x
=
vector
()
x
=
vector
()
with
pytest
.
raises
(
ValueError
,
match
=
"may not appear twice"
):
with
pytest
.
raises
(
ValueError
,
match
=
"may not appear twice"
):
DimShuffle
(
(
False
,),
(
0
,
0
))(
x
)
DimShuffle
(
input_ndim
=
1
,
new_order
=
(
0
,
0
))(
x
)
class
TestGetUnderlyingScalarConstantValue
:
class
TestGetUnderlyingScalarConstantValue
:
...
...
tests/tensor/test_blas.py
浏览文件 @
e88117e6
...
@@ -593,9 +593,9 @@ class TestAsScalar:
...
@@ -593,9 +593,9 @@ class TestAsScalar:
b
=
pt
.
constant
(
np
.
asarray
([[[
0.5
]]]))
b
=
pt
.
constant
(
np
.
asarray
([[[
0.5
]]]))
b2
=
b
.
dimshuffle
()
b2
=
b
.
dimshuffle
()
assert
b2
.
ndim
==
0
assert
b2
.
ndim
==
0
d_a
=
DimShuffle
(
[],
[])(
a
)
d_a
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[])(
a
)
d_b
=
DimShuffle
(
[
True
,
True
,
True
],
[
0
,
2
,
1
])(
b
)
d_b
=
DimShuffle
(
input_ndim
=
3
,
new_order
=
[
0
,
2
,
1
])(
b
)
d_a2
=
DimShuffle
(
[],
[
"x"
,
"x"
,
"x"
])(
a
)
d_a2
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[
"x"
,
"x"
,
"x"
])(
a
)
assert
_as_scalar
(
a
)
==
a
assert
_as_scalar
(
a
)
==
a
assert
_as_scalar
(
b
)
!=
b
assert
_as_scalar
(
b
)
!=
b
...
@@ -607,13 +607,13 @@ class TestAsScalar:
...
@@ -607,13 +607,13 @@ class TestAsScalar:
# Test that it fails on nonscalar constants
# Test that it fails on nonscalar constants
a
=
pt
.
constant
(
np
.
ones
(
5
))
a
=
pt
.
constant
(
np
.
ones
(
5
))
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
DimShuffle
(
[
False
],
[
0
,
"x"
])(
a
))
is
None
assert
_as_scalar
(
DimShuffle
(
input_ndim
=
1
,
new_order
=
[
0
,
"x"
])(
a
))
is
None
def
test_basic_2
(
self
):
def
test_basic_2
(
self
):
# Test that it works on scalar variables
# Test that it works on scalar variables
a
=
dscalar
()
a
=
dscalar
()
d_a
=
DimShuffle
(
[],
[])(
a
)
d_a
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[])(
a
)
d_a2
=
DimShuffle
(
[],
[
"x"
,
"x"
])(
a
)
d_a2
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[
"x"
,
"x"
])(
a
)
assert
_as_scalar
(
a
)
is
a
assert
_as_scalar
(
a
)
is
a
assert
_as_scalar
(
d_a
)
is
a
assert
_as_scalar
(
d_a
)
is
a
...
@@ -623,13 +623,15 @@ class TestAsScalar:
...
@@ -623,13 +623,15 @@ class TestAsScalar:
# Test that it fails on nonscalar variables
# Test that it fails on nonscalar variables
a
=
matrix
()
a
=
matrix
()
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
DimShuffle
(
[
False
,
False
],
[
0
,
"x"
,
1
])(
a
))
is
None
assert
_as_scalar
(
DimShuffle
(
input_ndim
=
2
,
new_order
=
[
0
,
"x"
,
1
])(
a
))
is
None
class
TestRealMatrix
:
class
TestRealMatrix
:
def
test_basic
(
self
):
def
test_basic
(
self
):
assert
_is_real_matrix
(
DimShuffle
([
False
,
False
],
[
1
,
0
])(
matrix
()))
assert
_is_real_matrix
(
DimShuffle
(
input_ndim
=
2
,
new_order
=
[
1
,
0
])(
matrix
()))
assert
not
_is_real_matrix
(
DimShuffle
([
False
],
[
"x"
,
0
])(
dvector
()))
assert
not
_is_real_matrix
(
DimShuffle
(
input_ndim
=
1
,
new_order
=
[
"x"
,
0
])(
dvector
())
)
"""
"""
...
...
tests/tensor/test_elemwise.py
浏览文件 @
e88117e6
...
@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((
1
,),
(
"x"
,
"x"
),
(
1
,
1
)),
((
1
,),
(
"x"
,
"x"
),
(
1
,
1
)),
]:
]:
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
ib
=
[
entry
==
1
for
entry
in
i_shape
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
e
=
self
.
op
(
i
b
,
shuffle
)(
x
)
e
=
self
.
op
(
i
nput_ndim
=
len
(
i_shape
),
new_order
=
shuffle
)(
x
)
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
# test that DimShuffle.infer_shape work correctly
# test that DimShuffle.infer_shape work correctly
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
e
=
self
.
op
(
i
b
,
shuffle
)(
x
)
e
=
self
.
op
(
i
nput_ndim
=
len
(
i_shape
),
new_order
=
shuffle
)(
x
)
f
=
pytensor
.
function
(
f
=
pytensor
.
function
(
[
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
),
on_unused_input
=
"ignore"
[
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
),
on_unused_input
=
"ignore"
)
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
# Test when we drop a axis that is not broadcastable
# Test when we drop a axis that is not broadcastable
ib
=
[
False
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
2
,
1
,
None
))(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
ValueError
):
self
.
op
(
input_ndim
=
3
,
new_order
=
shuffle
)(
x
)
self
.
op
(
ib
,
shuffle
)
# Test when we drop a axis that don't have shape 1
# Test when we drop a axis that don't have shape 1
ib
=
[
True
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
1
,
1
,
None
))(
"x"
)
e
=
self
.
op
(
input_ndim
=
3
,
new_order
=
(
1
,
2
))(
x
)
e
=
self
.
op
(
ib
,
(
1
,
2
))(
x
)
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
f
=
pytensor
.
function
([
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
))
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
TypeError
):
f
(
np
.
ones
((
2
,
1
,
4
),
dtype
=
self
.
dtype
))
f
(
np
.
ones
((
2
,
1
,
4
)))
# Test that we can't take a dimensions multiple time
# Test that we can't take a dimensions multiple time
xsh
,
shuffle
,
zsh
=
((
1
,
1
,
4
),
(
0
,
1
,
2
,
0
),
(
1
,
4
))
xsh
,
shuffle
,
zsh
=
((
1
,
1
,
4
),
(
0
,
1
,
2
,
0
),
(
1
,
4
))
ib
=
[
False
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
DimShuffle
(
i
b
,
shuffle
)
DimShuffle
(
i
nput_ndim
=
3
,
new_order
=
shuffle
)
def
test_perform
(
self
):
def
test_perform
(
self
):
self
.
with_linker
(
PerformLinker
())
self
.
with_linker
(
PerformLinker
())
def
test_c_or_py
(
self
):
def
test_c_or_py
(
self
):
# Shape op don't have C code.
# But This will test DimShuffle c code
self
.
with_linker
(
OpWiseCLinker
())
self
.
with_linker
(
OpWiseCLinker
())
def
test_infer_shape
(
self
):
def
test_infer_shape
(
self
):
...
@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((
1
,),
(
"x"
,
"x"
)),
((
1
,),
(
"x"
,
"x"
)),
]:
]:
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
ib
=
[(
entry
==
1
)
for
entry
in
xsh
]
adtens
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
adtens
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
adtens_val
=
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)
adtens_val
=
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
adtens
],
[
adtens
],
[
self
.
op
(
i
b
,
shuffle
)(
adtens
)],
[
self
.
op
(
i
nput_ndim
=
len
(
xsh
),
new_order
=
shuffle
)(
adtens
)],
[
adtens_val
],
[
adtens_val
],
self
.
op
,
self
.
op
,
warn
=
False
,
warn
=
False
,
...
@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y
=
x
.
dimshuffle
([
0
,
1
,
"x"
])
y
=
x
.
dimshuffle
([
0
,
1
,
"x"
])
assert
y
.
type
.
shape
==
(
1
,
2
,
1
)
assert
y
.
type
.
shape
==
(
1
,
2
,
1
)
def
test_valid_input_
broadcastable
(
self
):
def
test_valid_input_
ndim
(
self
):
assert
DimShuffle
(
[
True
,
False
],
(
1
,
0
))
.
input_broadcastable
==
(
True
,
False
)
assert
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
.
input_ndim
==
2
with
pytest
.
raises
(
ValueError
,
match
=
"input_broadcastable must be boolean
"
):
with
pytest
.
raises
(
TypeError
,
match
=
"input_ndim must be an integer
"
):
DimShuffle
(
[
None
,
None
],
(
1
,
0
))
DimShuffle
(
input_ndim
=
(
True
,
False
),
new_order
=
(
1
,
0
))
class
TestBroadcast
:
class
TestBroadcast
:
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
e88117e6
...
@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
...
@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
assert
f
([
0
])
==
0
assert
f
([
0
])
==
0
# Test that we cannot squeeze dimensions whose length is greater than 1
# Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1
=
re
.
escape
(
"SpecifyShape: Got shape (3,), expected (1,)."
)
error_txt_2
=
re
.
escape
(
"SpecifyShape: dim 0 of input has shape 3, expected 1"
)
match
=
error_txt_1
if
pytensor
.
config
.
mode
==
"FAST_COMPILE"
else
error_txt_2
with
pytest
.
raises
(
with
pytest
.
raises
(
Assertion
Error
,
Value
Error
,
match
=
match
,
match
=
"cannot reshape array of size 3 into shape ()"
,
):
):
f
([
0
,
1
,
2
])
f
([
0
,
1
,
2
])
...
...
tests/tensor/test_fft.py
浏览文件 @
e88117e6
...
@@ -204,3 +204,12 @@ class TestFFT:
...
@@ -204,3 +204,12 @@ class TestFFT:
pytensor
.
config
.
floatX
pytensor
.
config
.
floatX
)
)
utt
.
verify_grad
(
f_irfft
,
[
inputs_val
],
eps
=
eps
)
utt
.
verify_grad
(
f_irfft
,
[
inputs_val
],
eps
=
eps
)
def
test_rfft_expanded_dims_grad
(
self
):
# Regression test for https://github.com/pymc-devs/pytensor/issues/969
def
test_func
(
x
):
return
fft
.
rfft
(
x
[
None
,
:])
rng
=
np
.
random
.
default_rng
(
213
)
inputs_val
=
rng
.
random
((
N
,))
.
astype
(
pytensor
.
config
.
floatX
)
utt
.
verify_grad
(
test_func
,
[
inputs_val
],
rng
=
rng
)
tests/tensor/test_keepdims.py
浏览文件 @
e88117e6
...
@@ -4,7 +4,6 @@ import pytest
...
@@ -4,7 +4,6 @@ import pytest
import
pytensor
import
pytensor
from
pytensor
import
function
from
pytensor
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
any
as
pt_any
from
pytensor.tensor.math
import
any
as
pt_any
from
pytensor.tensor.math
import
argmax
,
argmin
,
max_and_argmax
,
mean
,
prod
,
std
,
var
from
pytensor.tensor.math
import
argmax
,
argmin
,
max_and_argmax
,
mean
,
prod
,
std
,
var
...
@@ -40,7 +39,7 @@ class TestKeepDims:
...
@@ -40,7 +39,7 @@ class TestKeepDims:
new_dims
.
append
(
i
)
new_dims
.
append
(
i
)
i
+=
1
i
+=
1
return
DimShuffle
(
y
.
type
.
broadcastable
,
new_dims
)(
y
)
return
y
.
dimshuffle
(
new_dims
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"axis"
,
"axis"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论