Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e88117e6
提交
e88117e6
authored
8月 22, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 08, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Only require input_ndim and not input_broadcastable in DimShuffle
上级
d68f53f8
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
24 个修改的文件
包含
66 行增加
和
86 行删除
+66
-86
sp.py
pytensor/sparse/sandbox/sp.py
+2
-3
basic.py
pytensor/tensor/basic.py
+2
-2
elemwise.py
pytensor/tensor/elemwise.py
+0
-0
extra_ops.py
pytensor/tensor/extra_ops.py
+1
-6
inplace.py
pytensor/tensor/inplace.py
+2
-2
math.py
pytensor/tensor/math.py
+1
-3
jax.py
pytensor/tensor/random/rewriting/jax.py
+1
-1
basic.py
pytensor/tensor/rewriting/basic.py
+1
-1
elemwise.py
pytensor/tensor/rewriting/elemwise.py
+2
-4
jax.py
pytensor/tensor/rewriting/jax.py
+1
-1
linalg.py
pytensor/tensor/rewriting/linalg.py
+1
-1
shape.py
pytensor/tensor/rewriting/shape.py
+1
-5
variable.py
pytensor/tensor/variable.py
+2
-2
test_elemwise.py
tests/link/jax/test_elemwise.py
+1
-1
test_elemwise.py
tests/link/numba/test_elemwise.py
+4
-4
test_elemwise.py
tests/link/pytorch/test_elemwise.py
+0
-6
test_elemwise.py
tests/tensor/rewriting/test_elemwise.py
+1
-1
test_math.py
tests/tensor/rewriting/test_math.py
+3
-3
test_basic.py
tests/tensor/test_basic.py
+1
-1
test_blas.py
tests/tensor/test_blas.py
+11
-9
test_elemwise.py
tests/tensor/test_elemwise.py
+16
-23
test_extra_ops.py
tests/tensor/test_extra_ops.py
+2
-5
test_fft.py
tests/tensor/test_fft.py
+9
-0
test_keepdims.py
tests/tensor/test_keepdims.py
+1
-2
没有找到文件。
pytensor/sparse/sandbox/sp.py
浏览文件 @
e88117e6
...
...
@@ -19,7 +19,6 @@ from pytensor.graph.op import Op
from
pytensor.tensor.math
import
dot
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.shape
import
reshape
from
pytensor.tensor.subtensor
import
DimShuffle
def
register_specialize
(
lopt
,
*
tags
,
**
kwargs
):
...
...
@@ -375,7 +374,7 @@ def convolve(
[
images
.
shape
[
0
],
pt
.
as_tensor
(
np
.
prod
(
outshp
)),
pt
.
as_tensor
(
nkern
)]
)
tensout
=
reshape
(
output
,
newshp
,
ndim
=
3
)
output
=
DimShuffle
((
False
,)
*
tensout
.
ndim
,
(
0
,
2
,
1
))(
tensout
)
output
=
tensout
.
transpose
(
0
,
2
,
1
)
if
flatten
:
output
=
pt
.
flatten
(
output
,
2
)
...
...
@@ -443,6 +442,6 @@ def max_pool(images, imgshp, maxpoolshp):
)
out2
=
reshape
(
out1
,
pshape
,
ndim
=
3
)
out3
=
DimShuffle
(
out2
.
broadcastable
,
(
0
,
2
,
1
))(
out2
)
out3
=
out2
.
transpose
(
0
,
2
,
1
)
return
pt
.
flatten
(
out3
,
2
),
outshp
pytensor/tensor/basic.py
浏览文件 @
e88117e6
...
...
@@ -2042,7 +2042,7 @@ def transpose(x, axes=None):
# No-op
return
_x
ret
=
DimShuffle
(
tuple
(
s
==
1
for
s
in
_x
.
type
.
shape
),
axes
)(
_x
)
ret
=
_x
.
dimshuffle
(
axes
)
if
_x
.
name
and
axes
==
tuple
(
range
((
_x
.
type
.
ndim
-
1
),
-
1
,
-
1
)):
ret
.
name
=
_x
.
name
+
".T"
...
...
@@ -3518,7 +3518,7 @@ class PermuteRowElements(Op):
newdims
.
append
(
i
)
i
+=
1
gx
=
DimShuffle
(
tuple
(
s
==
1
for
s
in
gx
.
type
.
shape
),
newdims
)(
gx
)
gx
=
gx
.
dimshuffle
(
newdims
)
assert
gx
.
type
.
ndim
==
x
.
type
.
ndim
assert
all
(
s1
==
s2
...
...
pytensor/tensor/elemwise.py
浏览文件 @
e88117e6
差异被折叠。
点击展开。
pytensor/tensor/extra_ops.py
浏览文件 @
e88117e6
...
...
@@ -41,7 +41,7 @@ from pytensor.tensor.math import (
)
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.shape
import
Shape_i
,
specify_broadcastable
from
pytensor.tensor.shape
import
Shape_i
from
pytensor.tensor.subtensor
import
advanced_inc_subtensor1
,
set_subtensor
from
pytensor.tensor.type
import
TensorType
,
dvector
,
int_dtypes
,
integer_dtypes
,
vector
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -609,11 +609,6 @@ def squeeze(x, axis=None):
# Nothing could be squeezed
return
_x
# `Dimshuffle` raises when we try to drop an axis that is not statically broadcastable.
# We add a `specify_broadcastable` instead of raising.
non_broadcastable_axis
=
[
i
for
i
in
axis
if
not
_x
.
broadcastable
[
i
]]
_x
=
specify_broadcastable
(
_x
,
*
non_broadcastable_axis
)
return
_x
.
dimshuffle
([
i
for
i
in
range
(
_x
.
ndim
)
if
i
not
in
axis
])
...
...
pytensor/tensor/inplace.py
浏览文件 @
e88117e6
from
pytensor
import
printing
from
pytensor.printing
import
pprint
from
pytensor.tensor.elemwise
import
DimShuffle
,
scalar_elemwise
from
pytensor.tensor.elemwise
import
scalar_elemwise
@scalar_elemwise
...
...
@@ -429,4 +429,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def
transpose_inplace
(
x
,
**
kwargs
):
"Perform a transpose on a tensor without copying the underlying storage"
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
return
DimShuffle
(
x
.
broadcastable
,
dims
)(
x
)
return
x
.
dimshuffle
(
dims
)
pytensor/tensor/math.py
浏览文件 @
e88117e6
...
...
@@ -33,7 +33,6 @@ from pytensor.tensor.basic import (
from
pytensor.tensor.blockwise
import
Blockwise
,
vectorize_node_fallback
from
pytensor.tensor.elemwise
import
(
CAReduce
,
DimShuffle
,
Elemwise
,
get_normalized_batch_axes
,
scalar_elemwise
,
...
...
@@ -2338,8 +2337,7 @@ class Sum(FixedOpCAReduce):
else
:
new_dims
.
append
(
i
)
i
+=
1
ds_op
=
DimShuffle
(
gz
.
type
.
broadcastable
,
new_dims
)
gx
=
Elemwise
(
ps
.
second
)(
x
,
ds_op
(
gz
))
gx
=
Elemwise
(
ps
.
second
)(
x
,
gz
.
dimshuffle
(
new_dims
))
return
[
gx
]
def
R_op
(
self
,
inputs
,
eval_points
):
...
...
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
e88117e6
...
...
@@ -65,7 +65,7 @@ def size_parameter_as_tuple(fgraph, node):
if
isinstance
(
size_node
.
op
,
MakeVector
)
or
(
isinstance
(
size_node
.
op
,
DimShuffle
)
and
size_node
.
op
.
input_
broadcastable
==
()
and
size_node
.
op
.
input_
ndim
==
0
and
size_node
.
op
.
new_order
==
(
"x"
,)
):
# Here PyTensor converted a tuple or list to a tensor
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
e88117e6
...
...
@@ -494,7 +494,7 @@ def local_alloc_sink_dimshuffle(fgraph, node):
dimshuffle_new_order
=
[
"x"
]
*
num_dims_with_size_1_added_to_left
+
list
(
range
(
len
(
new_output_shape
))
)
return
[
DimShuffle
(
inner
.
type
.
broadcastable
,
dimshuffle_new_order
)(
inn
er
)]
return
[
inner
.
dimshuffle
(
dimshuffle_new_ord
er
)]
@node_rewriter
([
AllocEmpty
])
...
...
pytensor/tensor/rewriting/elemwise.py
浏览文件 @
e88117e6
...
...
@@ -422,8 +422,6 @@ def local_dimshuffle_lift(fgraph, node):
"""
op
=
node
.
op
if
not
isinstance
(
op
,
DimShuffle
):
return
False
inp
=
node
.
inputs
[
0
]
inode
=
inp
.
owner
...
...
@@ -437,7 +435,7 @@ def local_dimshuffle_lift(fgraph, node):
# Don't use make_node to have tag.test_value set.
new_inputs
=
[]
for
inp
in
inode
.
inputs
:
new_inp
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
op
.
new_order
)(
inp
)
new_inp
=
inp
.
dimshuffle
(
op
.
new_order
)
new_inputs
.
append
(
apply_local_dimshuffle_lift
(
fgraph
,
new_inp
))
copy_stack_trace
(
node
.
outputs
[
0
],
new_inputs
)
ret
=
inode
.
op
(
*
new_inputs
,
return_list
=
True
)
...
...
@@ -449,7 +447,7 @@ def local_dimshuffle_lift(fgraph, node):
if
is_dimshuffle_useless
(
new_order
,
inp
):
return
[
inp
]
elif
inode
and
isinstance
(
inode
.
op
,
DimShuffle
):
ret
=
op
.
__class__
(
inp
.
type
.
broadcastable
,
new_order
)(
inp
)
ret
=
inp
.
dimshuffle
(
new_order
)
ret
=
apply_local_dimshuffle_lift
(
fgraph
,
ret
)
copy_stack_trace
(
node
.
outputs
[
0
],
ret
)
return
[
ret
]
...
...
pytensor/tensor/rewriting/jax.py
浏览文件 @
e88117e6
...
...
@@ -130,7 +130,7 @@ def shape_parameter_as_tuple(fgraph, node):
if
isinstance
(
shape_node
.
op
,
MakeVector
)
or
(
isinstance
(
shape_node
.
op
,
DimShuffle
)
and
shape_node
.
op
.
input_
broadcastable
==
()
and
shape_node
.
op
.
input_
ndim
==
0
and
shape_node
.
op
.
new_order
==
(
"x"
,)
):
# Here PyTensor converted a tuple or list to a tensor
...
...
pytensor/tensor/rewriting/linalg.py
浏览文件 @
e88117e6
...
...
@@ -65,7 +65,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
if
ndims
<
2
:
return
False
transpose_order
=
(
*
range
(
ndims
-
2
),
ndims
-
1
,
ndims
-
2
)
return
cast
(
bool
,
node
.
op
.
new_order
==
transpose_order
)
return
node
.
op
.
new_order
==
transpose_order
return
False
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
e88117e6
...
...
@@ -925,11 +925,7 @@ def local_reshape_to_dimshuffle(fgraph, node):
if
index
!=
output
.
type
.
ndim
:
inner
=
op
.
__class__
(
len
(
new_output_shape
))(
inp
,
new_output_shape
)
copy_stack_trace
(
output
,
inner
)
new_node
=
[
DimShuffle
(
tuple
(
s
==
1
for
s
in
inner
.
type
.
shape
),
dimshuffle_new_order
)(
inner
)
]
new_node
=
[
inner
.
dimshuffle
(
dimshuffle_new_order
)]
copy_stack_trace
(
output
,
new_node
)
return
new_node
...
...
pytensor/tensor/variable.py
浏览文件 @
e88117e6
...
...
@@ -344,8 +344,8 @@ class _tensor_py_operators:
"""
if
(
len
(
pattern
)
==
1
)
and
(
isinstance
(
pattern
[
0
],
list
|
tuple
)):
pattern
=
pattern
[
0
]
op
=
pt
.
elemwise
.
DimShuffle
(
list
(
self
.
type
.
broadcastable
),
pattern
)
return
op
(
self
)
ds_op
=
pt
.
elemwise
.
DimShuffle
(
input_ndim
=
self
.
type
.
ndim
,
new_order
=
pattern
)
return
ds_
op
(
self
)
def
flatten
(
self
,
ndim
=
1
):
return
pt
.
basic
.
flatten
(
self
,
ndim
)
...
...
tests/link/jax/test_elemwise.py
浏览文件 @
e88117e6
...
...
@@ -39,7 +39,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
(
[
False
,
True
],
(
0
,))(
a_pt
)
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
...
...
tests/link/numba/test_elemwise.py
浏览文件 @
e88117e6
...
...
@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from
pytensor.gradient
import
grad
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
elemwise
as
pt_elemwis
e
from
pytensor.tensor
.elemwise
import
DimShuffl
e
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Mean
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
pytensor.tensor.special
import
LogSoftmax
,
Softmax
,
SoftmaxGrad
from
tests.link.numba.test_basic
import
(
...
...
@@ -205,7 +205,7 @@ def test_elemwise_speed(benchmark):
],
)
def
test_Dimshuffle
(
v
,
new_order
):
g
=
pt_elemwise
.
DimShuffle
(
v
.
broadcastable
,
new_order
)(
v
)
g
=
v
.
dimshuffle
(
new_order
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
g_fg
,
...
...
@@ -219,7 +219,7 @@ def test_Dimshuffle(v, new_order):
def
test_Dimshuffle_returns_array
():
x
=
pt
.
vector
(
"x"
,
shape
=
(
1
,))
y
=
2
*
pt_elemwise
.
DimShuffle
([
True
],
[])(
x
)
y
=
2
*
x
.
dimshuffle
([]
)
func
=
pytensor
.
function
([
x
],
y
,
mode
=
"NUMBA"
)
out
=
func
(
np
.
zeros
(
1
,
dtype
=
config
.
floatX
))
assert
out
.
ndim
==
0
...
...
@@ -230,7 +230,7 @@ def test_Dimshuffle_non_contiguous():
non-contiguous arrays, make sure we work around thpt."""
x
=
pt
.
dvector
()
idx
=
pt
.
vector
(
dtype
=
"int64"
)
op
=
pytensor
.
tensor
.
elemwise
.
DimShuffle
([
True
],
[])
op
=
DimShuffle
(
input_ndim
=
1
,
new_order
=
[])
out
=
op
(
pt
.
specify_shape
(
x
[
idx
][::
2
],
(
1
,)))
func
=
pytensor
.
function
([
x
,
idx
],
out
,
mode
=
"NUMBA"
)
assert
func
(
np
.
zeros
(
3
),
np
.
array
([
1
]))
.
ndim
==
0
...
...
tests/link/pytorch/test_elemwise.py
浏览文件 @
e88117e6
...
...
@@ -5,7 +5,6 @@ import pytensor.tensor as pt
import
pytensor.tensor.math
as
ptm
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
from
pytensor.tensor.type
import
matrix
,
tensor
,
tensor3
,
vector
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
...
@@ -27,11 +26,6 @@ def test_pytorch_Dimshuffle():
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
([
False
,
True
],
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
def
test_multiple_input_output
():
x
=
vector
(
"x"
)
...
...
tests/tensor/rewriting/test_elemwise.py
浏览文件 @
e88117e6
...
...
@@ -79,7 +79,7 @@ dimshuffle_lift = out2in(local_dimshuffle_lift)
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
return
x
.
dimshuffle
(
y
)
def
inputs
(
xbc
=
(
0
,
0
),
ybc
=
(
0
,
0
),
zbc
=
(
0
,
0
)):
...
...
tests/tensor/rewriting/test_math.py
浏览文件 @
e88117e6
...
...
@@ -160,7 +160,7 @@ _fast_run_rewrites = optdb.query(_fast_run_rewrites)
def
ds
(
x
,
y
):
return
DimShuffle
(
x
.
type
.
broadcastable
,
y
)(
x
)
return
x
.
dimshuffle
(
y
)
def
rewrite
(
g
,
level
=
"fast_run"
):
...
...
@@ -3749,7 +3749,7 @@ def test_local_log_sum_exp_maximum():
check_max_log_sum_exp
(
x
,
axis
=
(
0
,
1
,
2
),
dimshuffle_op
=
None
)
# If a transpose is applied to the sum
transpose_op
=
DimShuffle
(
(
False
,
False
),
(
1
,
0
))
transpose_op
=
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
check_max_log_sum_exp
(
x
,
axis
=
2
,
dimshuffle_op
=
transpose_op
)
# If the sum is performed with keepdims=True
...
...
@@ -3770,7 +3770,7 @@ def test_local_log_sum_exp_near_one():
assert
np
.
allclose
(
naive_ret
,
rewritten_ret
)
# If a transpose is applied
transpose_op
=
DimShuffle
(
(
False
,
False
),
(
1
,
0
))
transpose_op
=
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
f
=
compile_graph_log_sum_exp
(
x
,
axis
=
(
1
,),
dimshuffle_op
=
transpose_op
)
naive_ret
=
np
.
log
(
np
.
sum
(
np
.
exp
(
x_val
),
axis
=
1
)
.
T
)
rewritten_ret
=
f
(
x_val
)
...
...
tests/tensor/test_basic.py
浏览文件 @
e88117e6
...
...
@@ -3418,7 +3418,7 @@ def test_unalign():
def
test_dimshuffle_duplicate
():
x
=
vector
()
with
pytest
.
raises
(
ValueError
,
match
=
"may not appear twice"
):
DimShuffle
(
(
False
,),
(
0
,
0
))(
x
)
DimShuffle
(
input_ndim
=
1
,
new_order
=
(
0
,
0
))(
x
)
class
TestGetUnderlyingScalarConstantValue
:
...
...
tests/tensor/test_blas.py
浏览文件 @
e88117e6
...
...
@@ -593,9 +593,9 @@ class TestAsScalar:
b
=
pt
.
constant
(
np
.
asarray
([[[
0.5
]]]))
b2
=
b
.
dimshuffle
()
assert
b2
.
ndim
==
0
d_a
=
DimShuffle
(
[],
[])(
a
)
d_b
=
DimShuffle
(
[
True
,
True
,
True
],
[
0
,
2
,
1
])(
b
)
d_a2
=
DimShuffle
(
[],
[
"x"
,
"x"
,
"x"
])(
a
)
d_a
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[])(
a
)
d_b
=
DimShuffle
(
input_ndim
=
3
,
new_order
=
[
0
,
2
,
1
])(
b
)
d_a2
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[
"x"
,
"x"
,
"x"
])(
a
)
assert
_as_scalar
(
a
)
==
a
assert
_as_scalar
(
b
)
!=
b
...
...
@@ -607,13 +607,13 @@ class TestAsScalar:
# Test that it fails on nonscalar constants
a
=
pt
.
constant
(
np
.
ones
(
5
))
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
DimShuffle
(
[
False
],
[
0
,
"x"
])(
a
))
is
None
assert
_as_scalar
(
DimShuffle
(
input_ndim
=
1
,
new_order
=
[
0
,
"x"
])(
a
))
is
None
def
test_basic_2
(
self
):
# Test that it works on scalar variables
a
=
dscalar
()
d_a
=
DimShuffle
(
[],
[])(
a
)
d_a2
=
DimShuffle
(
[],
[
"x"
,
"x"
])(
a
)
d_a
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[])(
a
)
d_a2
=
DimShuffle
(
input_ndim
=
0
,
new_order
=
[
"x"
,
"x"
])(
a
)
assert
_as_scalar
(
a
)
is
a
assert
_as_scalar
(
d_a
)
is
a
...
...
@@ -623,13 +623,15 @@ class TestAsScalar:
# Test that it fails on nonscalar variables
a
=
matrix
()
assert
_as_scalar
(
a
)
is
None
assert
_as_scalar
(
DimShuffle
(
[
False
,
False
],
[
0
,
"x"
,
1
])(
a
))
is
None
assert
_as_scalar
(
DimShuffle
(
input_ndim
=
2
,
new_order
=
[
0
,
"x"
,
1
])(
a
))
is
None
class
TestRealMatrix
:
def
test_basic
(
self
):
assert
_is_real_matrix
(
DimShuffle
([
False
,
False
],
[
1
,
0
])(
matrix
()))
assert
not
_is_real_matrix
(
DimShuffle
([
False
],
[
"x"
,
0
])(
dvector
()))
assert
_is_real_matrix
(
DimShuffle
(
input_ndim
=
2
,
new_order
=
[
1
,
0
])(
matrix
()))
assert
not
_is_real_matrix
(
DimShuffle
(
input_ndim
=
1
,
new_order
=
[
"x"
,
0
])(
dvector
())
)
"""
...
...
tests/tensor/test_elemwise.py
浏览文件 @
e88117e6
...
...
@@ -60,46 +60,40 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((
1
,),
(
"x"
,
"x"
),
(
1
,
1
)),
]:
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
ib
=
[
entry
==
1
for
entry
in
i_shape
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
e
=
self
.
op
(
i
b
,
shuffle
)(
x
)
e
=
self
.
op
(
i
nput_ndim
=
len
(
i_shape
),
new_order
=
shuffle
)(
x
)
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
# test that DimShuffle.infer_shape work correctly
x
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
e
=
self
.
op
(
i
b
,
shuffle
)(
x
)
e
=
self
.
op
(
i
nput_ndim
=
len
(
i_shape
),
new_order
=
shuffle
)(
x
)
f
=
pytensor
.
function
(
[
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
),
on_unused_input
=
"ignore"
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
# Test when we drop a axis that is not broadcastable
ib
=
[
False
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
ValueError
):
self
.
op
(
ib
,
shuffle
)
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
2
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
TypeError
):
self
.
op
(
input_ndim
=
3
,
new_order
=
shuffle
)(
x
)
# Test when we drop a axis that don't have shape 1
ib
=
[
True
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
1
,
1
,
None
))(
"x"
)
e
=
self
.
op
(
ib
,
(
1
,
2
))(
x
)
f
=
pytensor
.
function
([
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
))
with
pytest
.
raises
(
TypeError
):
f
(
np
.
ones
((
2
,
1
,
4
)))
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
e
=
self
.
op
(
input_ndim
=
3
,
new_order
=
(
1
,
2
))(
x
)
f
=
pytensor
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
))
with
pytest
.
raises
(
ValueError
):
f
(
np
.
ones
((
2
,
1
,
4
),
dtype
=
self
.
dtype
))
# Test that we can't take a dimensions multiple time
xsh
,
shuffle
,
zsh
=
((
1
,
1
,
4
),
(
0
,
1
,
2
,
0
),
(
1
,
4
))
ib
=
[
False
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
shape
=
(
None
,
1
,
None
))(
"x"
)
with
pytest
.
raises
(
ValueError
):
DimShuffle
(
i
b
,
shuffle
)
DimShuffle
(
i
nput_ndim
=
3
,
new_order
=
shuffle
)
def
test_perform
(
self
):
self
.
with_linker
(
PerformLinker
())
def
test_c_or_py
(
self
):
# Shape op don't have C code.
# But This will test DimShuffle c code
self
.
with_linker
(
OpWiseCLinker
())
def
test_infer_shape
(
self
):
...
...
@@ -115,12 +109,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
((
1
,),
(
"x"
,
"x"
)),
]:
i_shape
=
[
entry
if
entry
==
1
else
None
for
entry
in
xsh
]
ib
=
[(
entry
==
1
)
for
entry
in
xsh
]
adtens
=
self
.
type
(
self
.
dtype
,
shape
=
i_shape
)(
"x"
)
adtens_val
=
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)
self
.
_compile_and_check
(
[
adtens
],
[
self
.
op
(
i
b
,
shuffle
)(
adtens
)],
[
self
.
op
(
i
nput_ndim
=
len
(
xsh
),
new_order
=
shuffle
)(
adtens
)],
[
adtens_val
],
self
.
op
,
warn
=
False
,
...
...
@@ -191,11 +184,11 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
y
=
x
.
dimshuffle
([
0
,
1
,
"x"
])
assert
y
.
type
.
shape
==
(
1
,
2
,
1
)
def
test_valid_input_
broadcastable
(
self
):
assert
DimShuffle
(
[
True
,
False
],
(
1
,
0
))
.
input_broadcastable
==
(
True
,
False
)
def
test_valid_input_
ndim
(
self
):
assert
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
1
,
0
))
.
input_ndim
==
2
with
pytest
.
raises
(
ValueError
,
match
=
"input_broadcastable must be boolean
"
):
DimShuffle
(
[
None
,
None
],
(
1
,
0
))
with
pytest
.
raises
(
TypeError
,
match
=
"input_ndim must be an integer
"
):
DimShuffle
(
input_ndim
=
(
True
,
False
),
new_order
=
(
1
,
0
))
class
TestBroadcast
:
...
...
tests/tensor/test_extra_ops.py
浏览文件 @
e88117e6
...
...
@@ -480,12 +480,9 @@ class TestSqueeze(utt.InferShapeTester):
assert
f
([
0
])
==
0
# Test that we cannot squeeze dimensions whose length is greater than 1
error_txt_1
=
re
.
escape
(
"SpecifyShape: Got shape (3,), expected (1,)."
)
error_txt_2
=
re
.
escape
(
"SpecifyShape: dim 0 of input has shape 3, expected 1"
)
match
=
error_txt_1
if
pytensor
.
config
.
mode
==
"FAST_COMPILE"
else
error_txt_2
with
pytest
.
raises
(
Assertion
Error
,
match
=
match
,
Value
Error
,
match
=
"cannot reshape array of size 3 into shape ()"
,
):
f
([
0
,
1
,
2
])
...
...
tests/tensor/test_fft.py
浏览文件 @
e88117e6
...
...
@@ -204,3 +204,12 @@ class TestFFT:
pytensor
.
config
.
floatX
)
utt
.
verify_grad
(
f_irfft
,
[
inputs_val
],
eps
=
eps
)
def
test_rfft_expanded_dims_grad
(
self
):
# Regression test for https://github.com/pymc-devs/pytensor/issues/969
def
test_func
(
x
):
return
fft
.
rfft
(
x
[
None
,
:])
rng
=
np
.
random
.
default_rng
(
213
)
inputs_val
=
rng
.
random
((
N
,))
.
astype
(
pytensor
.
config
.
floatX
)
utt
.
verify_grad
(
test_func
,
[
inputs_val
],
rng
=
rng
)
tests/tensor/test_keepdims.py
浏览文件 @
e88117e6
...
...
@@ -4,7 +4,6 @@ import pytest
import
pytensor
from
pytensor
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
any
as
pt_any
from
pytensor.tensor.math
import
argmax
,
argmin
,
max_and_argmax
,
mean
,
prod
,
std
,
var
...
...
@@ -40,7 +39,7 @@ class TestKeepDims:
new_dims
.
append
(
i
)
i
+=
1
return
DimShuffle
(
y
.
type
.
broadcastable
,
new_dims
)(
y
)
return
y
.
dimshuffle
(
new_dims
)
@pytest.mark.parametrize
(
"axis"
,
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论