Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
42a7adb9
提交
42a7adb9
authored
3月 11, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
3月 18, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove Unbroadcast Op
上级
a24f5345
全部展开
显示空白字符变更
内嵌
并排
正在显示
22 个修改的文件
包含
25 行增加
和
512 行删除
+25
-512
basic.rst
doc/library/tensor/basic.rst
+6
-9
pfunc.py
pytensor/compile/function/pfunc.py
+1
-7
ifelse.py
pytensor/ifelse.py
+1
-2
shape.py
pytensor/link/jax/dispatch/shape.py
+1
-9
tensor_basic.py
pytensor/link/numba/dispatch/tensor_basic.py
+0
-10
shape.py
pytensor/link/pytorch/dispatch/shape.py
+1
-9
basic.py
pytensor/scan/basic.py
+5
-5
op.py
pytensor/scan/op.py
+1
-2
basic.py
pytensor/tensor/basic.py
+1
-12
shape.py
pytensor/tensor/rewriting/shape.py
+0
-77
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+0
-37
shape.py
pytensor/tensor/shape.py
+0
-116
test_shape.py
tests/link/jax/test_shape.py
+1
-5
test_tensor_basic.py
tests/link/numba/test_tensor_basic.py
+0
-11
test_shape.py
tests/link/pytorch/test_shape.py
+1
-8
test_printing.py
tests/scan/test_printing.py
+0
-0
test_rewriting.py
tests/scan/test_rewriting.py
+1
-1
test_basic.py
tests/tensor/rewriting/test_basic.py
+0
-44
test_subtensor.py
tests/tensor/rewriting/test_subtensor.py
+0
-60
test_basic.py
tests/tensor/test_basic.py
+5
-5
test_shape.py
tests/tensor/test_shape.py
+0
-75
test_rop.py
tests/test_rop.py
+0
-8
没有找到文件。
doc/library/tensor/basic.rst
浏览文件 @
42a7adb9
...
@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
...
@@ -619,9 +619,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padleft(x, n_ones=1)
.. function:: shape_padleft(x, n_ones=1)
Reshape `x` by left padding the shape with `n_ones` 1s. Note that all
Reshape `x` by left padding the shape with `n_ones` 1s.
this new dimension will be broadcastable. To make them non-broadcastable
All new dimensions will be broadcastable.
see the :func:`unbroadcast`.
:param x: variable to be reshaped
:param x: variable to be reshaped
:type x: any `TensorVariable` (or compatible)
:type x: any `TensorVariable` (or compatible)
...
@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
...
@@ -633,9 +632,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padright(x, n_ones=1)
.. function:: shape_padright(x, n_ones=1)
Reshape `x` by right padding the shape with `n_ones` ones. Note that all
Reshape `x` by right padding the shape with `n_ones` ones.
this new dimension will be broadcastable. To make them non-broadcastable
All new dimensions will be broadcastable.
see the :func:`unbroadcast`.
:param x: variable to be reshaped
:param x: variable to be reshaped
:type x: any TensorVariable (or compatible)
:type x: any TensorVariable (or compatible)
...
@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
...
@@ -646,9 +644,8 @@ dimensions, see :meth:`_tensor_py_operators.dimshuffle`.
.. function:: shape_padaxis(t, axis)
.. function:: shape_padaxis(t, axis)
Reshape `t` by inserting ``1`` at the dimension `axis`. Note that this new
Reshape `t` by inserting ``1`` at the dimension `axis`.
dimension will be broadcastable. To make it non-broadcastable
All new dimensions will be broadcastable.
see the :func:`unbroadcast`.
:type x: any `TensorVariable` (or compatible)
:type x: any `TensorVariable` (or compatible)
:param x: variable to be reshaped
:param x: variable to be reshaped
...
...
pytensor/compile/function/pfunc.py
浏览文件 @
42a7adb9
...
@@ -292,14 +292,8 @@ def rebuild_collect_shared(
...
@@ -292,14 +292,8 @@ def rebuild_collect_shared(
f
" shared_var.type={store_into.type},"
f
" shared_var.type={store_into.type},"
f
" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
f
" update_val={update_val}, update_val.type={getattr(update_val, 'type', None)})."
)
)
err_sug
=
(
"If the difference is related to the broadcast pattern,"
" you can call the"
" tensor.shape.unbroadcast(var, axis_to_unbroadcast[, ...])"
" function to mask broadcastable dimensions."
)
raise
TypeError
(
err_msg
,
err_sug
)
raise
TypeError
(
err_msg
)
assert
store_into
.
type
.
is_super
(
update_val
.
type
)
assert
store_into
.
type
.
is_super
(
update_val
.
type
)
update_d
[
store_into
]
=
update_val
update_d
[
store_into
]
=
update_val
...
...
pytensor/ifelse.py
浏览文件 @
42a7adb9
...
@@ -26,7 +26,7 @@ from pytensor.graph.op import _NoPythonOp
...
@@ -26,7 +26,7 @@ from pytensor.graph.op import _NoPythonOp
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.rewriting.basic
import
GraphRewriter
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
GraphRewriter
,
in2out
,
node_rewriter
from
pytensor.graph.type
import
HasDataType
,
HasShape
from
pytensor.graph.type
import
HasDataType
,
HasShape
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
,
Unbroadcast
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
SpecifyShape
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -481,7 +481,6 @@ acceptable_ops = (
...
@@ -481,7 +481,6 @@ acceptable_ops = (
Shape
,
Shape
,
SpecifyShape
,
SpecifyShape
,
Reshape
,
Reshape
,
Unbroadcast
,
pt
.
math
.
Dot
,
pt
.
math
.
Dot
,
pt
.
math
.
Max
,
pt
.
math
.
Max
,
pt
.
math
.
Argmax
,
pt
.
math
.
Argmax
,
...
...
pytensor/link/jax/dispatch/shape.py
浏览文件 @
42a7adb9
...
@@ -4,7 +4,7 @@ from pytensor.graph import Constant
...
@@ -4,7 +4,7 @@ from pytensor.graph import Constant
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type
import
TensorType
...
@@ -104,11 +104,3 @@ def jax_funcify_SpecifyShape(op, node, **kwargs):
...
@@ -104,11 +104,3 @@ def jax_funcify_SpecifyShape(op, node, **kwargs):
return
x
return
x
return
specifyshape
return
specifyshape
@jax_funcify.register
(
Unbroadcast
)
def
jax_funcify_Unbroadcast
(
op
,
**
kwargs
):
def
unbroadcast
(
x
):
return
x
return
unbroadcast
pytensor/link/numba/dispatch/tensor_basic.py
浏览文件 @
42a7adb9
...
@@ -17,7 +17,6 @@ from pytensor.tensor.basic import (
...
@@ -17,7 +17,6 @@ from pytensor.tensor.basic import (
Split
,
Split
,
TensorFromScalar
,
TensorFromScalar
,
)
)
from
pytensor.tensor.shape
import
Unbroadcast
@numba_funcify.register
(
AllocEmpty
)
@numba_funcify.register
(
AllocEmpty
)
...
@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}):
...
@@ -232,15 +231,6 @@ def makevector({", ".join(input_names)}):
return
numba_basic
.
numba_njit
(
makevector_fn
)
return
numba_basic
.
numba_njit
(
makevector_fn
)
@numba_funcify.register
(
Unbroadcast
)
def
numba_funcify_Unbroadcast
(
op
,
**
kwargs
):
@numba_basic.numba_njit
def
unbroadcast
(
x
):
return
x
return
unbroadcast
@numba_funcify.register
(
TensorFromScalar
)
@numba_funcify.register
(
TensorFromScalar
)
def
numba_funcify_TensorFromScalar
(
op
,
**
kwargs
):
def
numba_funcify_TensorFromScalar
(
op
,
**
kwargs
):
@numba_basic.numba_njit
(
inline
=
"always"
)
@numba_basic.numba_njit
(
inline
=
"always"
)
...
...
pytensor/link/pytorch/dispatch/shape.py
浏览文件 @
42a7adb9
...
@@ -2,7 +2,7 @@ import torch
...
@@ -2,7 +2,7 @@ import torch
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.basic
import
Constant
from
pytensor.link.pytorch.dispatch.basic
import
pytorch_funcify
from
pytensor.link.pytorch.dispatch.basic
import
pytorch_funcify
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
,
Unbroadcast
from
pytensor.tensor.shape
import
Reshape
,
Shape
,
Shape_i
,
SpecifyShape
@pytorch_funcify.register
(
Reshape
)
@pytorch_funcify.register
(
Reshape
)
...
@@ -56,11 +56,3 @@ def pytorch_funcify_SpecifyShape(op, node, **kwargs):
...
@@ -56,11 +56,3 @@ def pytorch_funcify_SpecifyShape(op, node, **kwargs):
return
x
return
x
return
specifyshape
return
specifyshape
@pytorch_funcify.register
(
Unbroadcast
)
def
pytorch_funcify_Unbroadcast
(
op
,
**
kwargs
):
def
unbroadcast
(
x
):
return
x
return
unbroadcast
pytensor/scan/basic.py
浏览文件 @
42a7adb9
...
@@ -15,7 +15,7 @@ from pytensor.scan.utils import expand_empty, safe_new, until
...
@@ -15,7 +15,7 @@ from pytensor.scan.utils import expand_empty, safe_new, until
from
pytensor.tensor.basic
import
get_underlying_scalar_constant_value
from
pytensor.tensor.basic
import
get_underlying_scalar_constant_value
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.math
import
minimum
from
pytensor.tensor.math
import
minimum
from
pytensor.tensor.shape
import
shape_padleft
,
unbroadcast
from
pytensor.tensor.shape
import
shape_padleft
from
pytensor.tensor.type
import
TensorType
,
integer_dtypes
from
pytensor.tensor.type
import
TensorType
,
integer_dtypes
from
pytensor.updates
import
OrderedUpdates
from
pytensor.updates
import
OrderedUpdates
...
@@ -748,7 +748,7 @@ def scan(
...
@@ -748,7 +748,7 @@ def scan(
# defined in scan utils
# defined in scan utils
sit_sot_scan_inputs
.
append
(
sit_sot_scan_inputs
.
append
(
expand_empty
(
expand_empty
(
unbroadcast
(
shape_padleft
(
actual_arg
),
0
),
shape_padleft
(
actual_arg
),
actual_n_steps
,
actual_n_steps
,
)
)
)
)
...
@@ -865,13 +865,13 @@ def scan(
...
@@ -865,13 +865,13 @@ def scan(
if
n_fixed_steps
in
(
1
,
-
1
):
if
n_fixed_steps
in
(
1
,
-
1
):
for
pos
,
inner_out
in
enumerate
(
outputs
):
for
pos
,
inner_out
in
enumerate
(
outputs
):
# we need to see if we need to pad our sequences with an
# we need to see if we need to pad our sequences with an
#
unbroadcastable
dimension; case example : we return an
#
extra
dimension; case example : we return an
# output for which we want all intermediate. If n_steps is 1
# output for which we want all intermediate. If n_steps is 1
# then, if we return the output as given by the innner function
# then, if we return the output as given by the innner function
# this will represent only a slice and it will have one
# this will represent only a slice and it will have one
# dimension less.
# dimension less.
if
isinstance
(
inner_out
.
type
,
TensorType
)
and
return_steps
.
get
(
pos
,
0
)
!=
1
:
if
isinstance
(
inner_out
.
type
,
TensorType
)
and
return_steps
.
get
(
pos
,
0
)
!=
1
:
outputs
[
pos
]
=
unbroadcast
(
shape_padleft
(
inner_out
),
0
)
outputs
[
pos
]
=
shape_padleft
(
inner_out
)
if
not
return_list
and
len
(
outputs
)
==
1
:
if
not
return_list
and
len
(
outputs
)
==
1
:
outputs
=
outputs
[
0
]
outputs
=
outputs
[
0
]
...
@@ -1002,7 +1002,7 @@ def scan(
...
@@ -1002,7 +1002,7 @@ def scan(
sit_sot_inner_inputs
.
append
(
new_var
)
sit_sot_inner_inputs
.
append
(
new_var
)
sit_sot_scan_inputs
.
append
(
sit_sot_scan_inputs
.
append
(
expand_empty
(
expand_empty
(
unbroadcast
(
shape_padleft
(
input
.
variable
),
0
),
shape_padleft
(
input
.
variable
),
actual_n_steps
,
actual_n_steps
,
)
)
)
)
...
...
pytensor/scan/op.py
浏览文件 @
42a7adb9
...
@@ -166,8 +166,7 @@ def check_broadcast(v1, v2):
...
@@ -166,8 +166,7 @@ def check_broadcast(v1, v2):
"axis
%
d in `output_info`. This can happen if one of the "
"axis
%
d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using pytensor.tensor."
"them consistent, e.g. using pytensor.tensor.specify_broadcastable."
"{unbroadcast, specify_broadcastable}."
)
)
size
=
min
(
v1
.
type
.
ndim
,
v2
.
type
.
ndim
)
size
=
min
(
v1
.
type
.
ndim
,
v2
.
type
.
ndim
)
for
n
,
(
b1
,
b2
)
in
enumerate
(
for
n
,
(
b1
,
b2
)
in
enumerate
(
...
...
pytensor/tensor/basic.py
浏览文件 @
42a7adb9
...
@@ -53,7 +53,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
...
@@ -53,7 +53,6 @@ from pytensor.tensor.exceptions import NotScalarConstantError
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
Shape
,
Shape
,
Shape_i
,
Shape_i
,
Unbroadcast
,
shape
,
shape
,
shape_padaxis
,
shape_padaxis
,
shape_padleft
,
shape_padleft
,
...
@@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value(
...
@@ -334,9 +333,7 @@ def _get_underlying_scalar_constant_value(
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
if
not
only_process_constants
and
getattr
(
v
,
"owner"
,
None
)
and
max_recur
>
0
:
op
=
v
.
owner
.
op
op
=
v
.
owner
.
op
max_recur
-=
1
max_recur
-=
1
if
isinstance
(
if
isinstance
(
op
,
Alloc
|
DimShuffle
|
OutputGuard
|
DeepCopyOp
):
op
,
Alloc
|
DimShuffle
|
Unbroadcast
|
OutputGuard
|
DeepCopyOp
):
# OutputGuard is only used in debugmode but we
# OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles
# keep it here to avoid problems with old pickles
v
=
v
.
owner
.
inputs
[
0
]
v
=
v
.
owner
.
inputs
[
0
]
...
@@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value(
...
@@ -498,14 +495,6 @@ def _get_underlying_scalar_constant_value(
grandparent
=
leftmost_parent
.
owner
.
inputs
[
0
]
grandparent
=
leftmost_parent
.
owner
.
inputs
[
0
]
gp_shape
=
grandparent
.
type
.
shape
gp_shape
=
grandparent
.
type
.
shape
ndim
=
grandparent
.
type
.
ndim
ndim
=
grandparent
.
type
.
ndim
if
grandparent
.
owner
and
isinstance
(
grandparent
.
owner
.
op
,
Unbroadcast
):
ggp_shape
=
grandparent
.
owner
.
inputs
[
0
]
.
type
.
shape
l
=
[
_get_underlying_scalar_constant_value
(
s
)
for
s
in
ggp_shape
]
gp_shape
=
tuple
(
l
)
if
not
(
idx
<
ndim
):
if
not
(
idx
<
ndim
):
msg
=
(
msg
=
(
...
...
pytensor/tensor/rewriting/shape.py
浏览文件 @
42a7adb9
...
@@ -42,9 +42,7 @@ from pytensor.tensor.shape import (
...
@@ -42,9 +42,7 @@ from pytensor.tensor.shape import (
Shape
,
Shape
,
Shape_i
,
Shape_i
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
pytensor.tensor.subtensor
import
Subtensor
,
get_idx_list
from
pytensor.tensor.subtensor
import
Subtensor
,
get_idx_list
from
pytensor.tensor.type
import
TensorType
,
discrete_dtypes
,
integer_dtypes
from
pytensor.tensor.type
import
TensorType
,
discrete_dtypes
,
integer_dtypes
...
@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node):
...
@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node):
# structure.
# structure.
replacement
=
shape_feature
.
scheduled
[
node
]
replacement
=
shape_feature
.
scheduled
[
node
]
return
[
shape_feature
.
shape_of
[
replacement
][
node
.
op
.
i
]]
return
[
shape_feature
.
shape_of
[
replacement
][
node
.
op
.
i
]]
@register_useless
@register_canonicalize
@register_specialize
@node_rewriter
([
Unbroadcast
])
def
local_useless_unbroadcast
(
fgraph
,
node
):
"""Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
if
isinstance
(
node
.
op
,
Unbroadcast
):
x
=
node
.
inputs
[
0
]
if
x
.
type
.
ndim
==
node
.
outputs
[
0
]
.
type
.
ndim
and
all
(
s1
==
s2
for
s1
,
s2
in
zip
(
x
.
type
.
shape
,
node
.
outputs
[
0
]
.
type
.
shape
,
strict
=
True
)
if
s1
==
1
or
s2
==
1
):
# No broadcastable flag was modified
# No need to copy over stack trace,
# because x should already have a stack trace.
return
[
x
]
else
:
# Keep the flags that modify something
new_axes
=
tuple
(
ax
for
ax
in
node
.
op
.
axes
if
x
.
type
.
shape
[
ax
]
==
1
)
if
new_axes
==
node
.
op
.
axes
:
# All flags are useful
return
None
else
:
r
=
unbroadcast
(
x
,
*
new_axes
)
# Copy over stacktrace from previous output
copy_stack_trace
(
node
.
outputs
,
r
)
return
[
r
]
@register_canonicalize
@register_specialize
@node_rewriter
([
Unbroadcast
])
def
local_unbroadcast_lift
(
fgraph
,
node
):
"""
Lifts `Unbroadcast` through unary Elemwise operations,
and merges consecutive `Unbroadcast`s.
Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
TODO: Implement equivalent Elemwise lift for SpecifyShape
"""
op
=
node
.
op
if
not
isinstance
(
op
,
Unbroadcast
):
return
False
inp
=
node
.
inputs
[
0
]
inode
=
inp
.
owner
if
inode
and
isinstance
(
inode
.
op
,
Elemwise
)
and
len
(
inode
.
inputs
)
==
1
:
if
len
(
fgraph
.
clients
.
get
(
inp
,
()))
==
1
:
unbroadcasted
=
unbroadcast
(
inode
.
inputs
[
0
],
*
op
.
axes
)
copy_stack_trace
(
node
.
outputs
,
unbroadcasted
)
rval
=
inode
.
op
.
make_node
(
unbroadcasted
)
.
outputs
# Copy over stacktrace from previous output (after unbroadcasting)
# and input (after elemwise operation) to new output, because an
# error in the new graph could have been caused by either of the
# two ops.
copy_stack_trace
(
node
.
outputs
+
node
.
inputs
,
rval
)
return
rval
if
inode
and
isinstance
(
inode
.
op
,
Unbroadcast
):
# Merge axis of each unbroadcast
axis
=
tuple
(
set
(
inode
.
op
.
axes
)
.
union
(
set
(
op
.
axes
)))
iinput
=
inode
.
inputs
[
0
]
rval
=
[
unbroadcast
(
iinput
,
*
axis
)]
# Copy over stacktrace from previous output (after second unbroadcasting)
# and from previous input (after first unbroadcasting) because an error in
# the new graph could have been caused by either of the two Unbroadcast ops.
copy_stack_trace
(
node
.
outputs
+
node
.
inputs
,
rval
)
return
rval
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
42a7adb9
...
@@ -59,11 +59,9 @@ from pytensor.tensor.rewriting.basic import (
...
@@ -59,11 +59,9 @@ from pytensor.tensor.rewriting.basic import (
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
Shape
,
Shape
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
shape_padleft
,
shape_padleft
,
shape_tuple
,
shape_tuple
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
pytensor.tensor.sharedvar
import
TensorSharedVariable
from
pytensor.tensor.sharedvar
import
TensorSharedVariable
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
...
@@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node):
...
@@ -429,7 +427,6 @@ def local_subtensor_lift(fgraph, node):
Handles the following unary ops:
Handles the following unary ops:
elemwise(x,...)[idx] -> elemwise(x[idx],...)
elemwise(x,...)[idx] -> elemwise(x[idx],...)
when x,... are broadcasted scalar or not broadcasted at all
when x,... are broadcasted scalar or not broadcasted at all
Unbroadcast(x)[idx] => Unbroadcast(x[idx])
"""
"""
if
isinstance
(
node
.
op
,
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
...
@@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node):
...
@@ -488,40 +485,6 @@ def local_subtensor_lift(fgraph, node):
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
ret
)
return
[
ret
]
return
[
ret
]
if
isinstance
(
u
.
owner
.
op
,
Unbroadcast
):
# Subtensor might reduce dim., adapt broadcast pattern accordingly
old_axes
=
u
.
owner
.
op
.
axes
new_axes
=
[]
# loop through indices being subtensor-ed
# i indexes broadcastable pattern before subtensor
# j indexes broadcastable pattern after subtensor
j
=
0
for
i
,
x
in
enumerate
(
node
.
op
.
idx_list
):
# if it is not a slice, it will reduce the dimension, should
# not appear in the broascastable dimensions
if
isinstance
(
x
,
slice
):
if
i
in
old_axes
:
new_axes
.
append
(
j
)
j
+=
1
# now keep the broadcastable pattern of all
# items not appearing in subtensor list
for
i
in
range
(
len
(
node
.
op
.
idx_list
),
len
(
u
.
broadcastable
)):
if
i
in
old_axes
:
new_axes
.
append
(
j
)
j
+=
1
subt_x
=
node
.
op
(
u
.
owner
.
inputs
[
0
],
*
node
.
inputs
[
1
:])
# Copy over previous output stacktrace
copy_stack_trace
(
node
.
outputs
[
0
],
subt_x
)
rbcast_subt_x
=
unbroadcast
(
subt_x
,
*
new_axes
)
# Copy over previous output stacktrace
# and stacktrace from previous unary operation
copy_stack_trace
([
node
.
outputs
[
0
],
node
.
inputs
[
0
]],
rbcast_subt_x
)
return
[
rbcast_subt_x
]
@register_canonicalize
@register_canonicalize
@register_specialize
@register_specialize
...
...
pytensor/tensor/shape.py
浏览文件 @
42a7adb9
...
@@ -18,7 +18,6 @@ from pytensor.link.c.params_type import ParamsType
...
@@ -18,7 +18,6 @@ from pytensor.link.c.params_type import ParamsType
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
pytensor.tensor
import
_get_vector_length
,
as_tensor_variable
,
get_vector_length
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor
import
basic
as
ptb
from
pytensor.tensor.elemwise
import
get_normalized_batch_axes
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
int_dtypes
,
tensor
from
pytensor.tensor.type
import
DenseTensorType
,
TensorType
,
int_dtypes
,
tensor
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
from
pytensor.tensor.type_other
import
NoneConst
,
NoneTypeT
...
@@ -1008,118 +1007,3 @@ def specify_broadcastable(x, *axes):
...
@@ -1008,118 +1007,3 @@ def specify_broadcastable(x, *axes):
axes
=
normalize_axis_tuple
(
axes
,
x
.
type
.
ndim
)
axes
=
normalize_axis_tuple
(
axes
,
x
.
type
.
ndim
)
shape_info
=
[
1
if
i
in
axes
else
s
for
i
,
s
in
enumerate
(
x
.
type
.
shape
)]
shape_info
=
[
1
if
i
in
axes
else
s
for
i
,
s
in
enumerate
(
x
.
type
.
shape
)]
return
specify_shape
(
x
,
shape_info
)
return
specify_shape
(
x
,
shape_info
)
class
Unbroadcast
(
COp
):
"""
Mask static broadcastable dimensions of input as `None`
See Also
--------
unbroadcast <pytensor.tensor.shape.unbroadcast>
Examples
--------
``Unbroadcast((1,))(x)`` would make `x` second static dimension be `None`
"""
view_map
=
{
0
:
[
0
]}
_f16_ok
=
True
# Mapping from Type to C code (and version) to use.
# In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s.
c_code_and_version
:
dict
=
{}
check_input
=
False
__props__
=
(
"axes"
,)
_f16_ok
=
True
def
__init__
(
self
,
*
axis
):
# Sort them to make sure we merge all possible case.
items
=
tuple
(
sorted
(
axis
))
self
.
axes
=
items
for
axis
in
self
.
axes
:
if
not
isinstance
(
axis
,
np
.
integer
|
int
):
raise
TypeError
(
f
"Unbroadcast needs integer axes. Got {axis}"
)
def
__str__
(
self
):
return
f
"{self.__class__.__name__}{{{','.join(str(i) for i in self.axes)}}}"
def
make_node
(
self
,
x
):
x
=
as_tensor_variable
(
x
)
if
x
.
type
.
ndim
<=
max
(
self
.
axes
):
raise
ValueError
(
"Trying to unbroadcast of non-existent dimension"
)
shape
=
[
None
if
(
sh
==
1
and
i
in
self
.
axes
)
else
sh
for
i
,
sh
in
enumerate
(
x
.
type
.
shape
)
]
return
Apply
(
self
,
[
x
],
[
x
.
type
.
clone
(
shape
=
shape
)()])
def
perform
(
self
,
node
,
inp
,
out_
):
(
x
,)
=
inp
(
out
,)
=
out_
out
[
0
]
=
x
def
grad
(
self
,
inp
,
grads
):
(
x
,)
=
inp
(
gz
,)
=
grads
# restore the broadcasting pattern of the input
return
[
specify_shape
(
gz
,
x
.
type
.
shape
)]
def
infer_shape
(
self
,
fgraph
,
node
,
ishapes
):
assert
len
(
ishapes
)
==
1
return
[
tuple
(
ishapes
[
0
])]
def
R_op
(
self
,
inputs
,
eval_points
):
if
eval_points
[
0
]
is
None
:
return
[
None
]
return
self
(
*
eval_points
,
return_list
=
True
)
def
c_code
(
self
,
node
,
nodename
,
inp
,
out
,
sub
):
(
iname
,)
=
inp
(
oname
,)
=
out
return
f
"""
Py_XDECREF({oname});
{oname} = {iname};
Py_XINCREF({oname});
"""
def
c_code_cache_version
(
self
):
return
(
3
,)
def
unbroadcast
(
x
,
*
axes
):
"""
Mask static broadcastable dimensions of input as `None`
Parameters
----------
x : tensor_like
Input pytensor tensor.
axis : an int or an iterable object such as list or tuple of int values
The broadcastable dimensions of x that should be unbroadcasted.
Returns
-------
tensor
A pytensor tensor, with static broadcastable dimensions masked as `None`
"""
x
=
as_tensor_variable
(
x
)
unbroadcasted_axes
=
[
axis
for
axis
in
axes
if
x
.
type
.
shape
[
axis
]
==
1
]
if
not
unbroadcasted_axes
:
return
x
return
Unbroadcast
(
*
unbroadcasted_axes
)(
x
)
@_vectorize_node.register
(
Unbroadcast
)
def
_vectorize_unbroadcast
(
op
:
Unbroadcast
,
node
:
Apply
,
batch_x
:
TensorVariable
)
->
Apply
:
core_ndim
=
node
.
inputs
[
0
]
.
type
.
ndim
batch_ndim
=
batch_x
.
type
.
ndim
-
core_ndim
batch_axes
=
get_normalized_batch_axes
(
op
.
axes
,
core_ndim
,
batch_ndim
)
return
cast
(
Apply
,
unbroadcast
(
batch_x
,
*
batch_axes
)
.
owner
)
tests/link/jax/test_shape.py
浏览文件 @
42a7adb9
...
@@ -4,7 +4,7 @@ import pytest
...
@@ -4,7 +4,7 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile.ops
import
DeepCopyOp
,
ViewOp
from
pytensor.compile.ops
import
DeepCopyOp
,
ViewOp
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
reshape
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -70,10 +70,6 @@ def test_jax_compile_ops():
...
@@ -70,10 +70,6 @@ def test_jax_compile_ops():
compare_jax_and_py
([],
[
x
],
[])
compare_jax_and_py
([],
[
x
],
[])
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
compare_jax_and_py
([],
[
x
],
[])
x
=
ViewOp
()(
pt
.
as_tensor_variable
(
x_np
))
x
=
ViewOp
()(
pt
.
as_tensor_variable
(
x_np
))
compare_jax_and_py
([],
[
x
],
[])
compare_jax_and_py
([],
[
x
],
[])
tests/link/numba/test_tensor_basic.py
浏览文件 @
42a7adb9
...
@@ -7,7 +7,6 @@ import pytensor.tensor.basic as ptb
...
@@ -7,7 +7,6 @@ import pytensor.tensor.basic as ptb
from
pytensor
import
config
,
function
from
pytensor
import
config
,
function
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.scalar
import
Add
from
pytensor.scalar
import
Add
from
pytensor.tensor.shape
import
Unbroadcast
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
compare_numba_and_py
,
compare_shape_dtype
,
compare_shape_dtype
,
...
@@ -75,16 +74,6 @@ def test_ScalarFromTensor():
...
@@ -75,16 +74,6 @@ def test_ScalarFromTensor():
)
)
def
test_Unbroadcast
():
v
,
v_test
=
pt
.
row
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)
g
=
Unbroadcast
(
0
)(
v
)
compare_numba_and_py
(
[
v
],
g
,
[
v_test
],
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"vals, dtype"
,
"vals, dtype"
,
[
[
...
...
tests/link/pytorch/test_shape.py
浏览文件 @
42a7adb9
...
@@ -2,7 +2,7 @@ import numpy as np
...
@@ -2,7 +2,7 @@ import numpy as np
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
reshape
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
vector
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -50,10 +50,3 @@ def test_pytorch_Reshape_dynamic():
...
@@ -50,10 +50,3 @@ def test_pytorch_Reshape_dynamic():
compare_pytorch_and_py
(
compare_pytorch_and_py
(
[
a
,
shape_pt
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
]
[
a
,
shape_pt
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
]
)
)
def
test_pytorch_unbroadcast
():
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
compare_pytorch_and_py
([],
[
x
],
[])
tests/scan/test_printing.py
浏览文件 @
42a7adb9
差异被折叠。
点击展开。
tests/scan/test_rewriting.py
浏览文件 @
42a7adb9
...
@@ -1621,7 +1621,7 @@ class TestSaveMem:
...
@@ -1621,7 +1621,7 @@ class TestSaveMem:
np
.
testing
.
assert_allclose
(
f
(
x0
=
0
,
seq
=
test_seq
,
n_steps
=
200
),
100
)
np
.
testing
.
assert_allclose
(
f
(
x0
=
0
,
seq
=
test_seq
,
n_steps
=
200
),
100
)
np
.
testing
.
assert_allclose
(
f
(
x0
=
1
,
seq
=
test_seq
,
n_steps
=
20
),
21
)
np
.
testing
.
assert_allclose
(
f
(
x0
=
1
,
seq
=
test_seq
,
n_steps
=
20
),
21
)
np
.
testing
.
assert_allclose
(
f
(
x0
=
np
.
e
,
seq
=
test_seq
,
n_steps
=
1
),
np
.
e
+
1
)
np
.
testing
.
assert_allclose
(
f
(
x0
=
np
.
e
,
seq
=
test_seq
,
n_steps
=
1
),
np
.
e
+
1
)
with
pytest
.
raises
(
AssertionError
,
match
=
"n_steps > 0"
):
with
pytest
.
raises
(
(
AssertionError
,
IndexError
)
):
f
(
x0
=
0
,
seq
=
test_seq
,
n_steps
=
0
)
f
(
x0
=
0
,
seq
=
test_seq
,
n_steps
=
0
)
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
# Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
...
...
tests/tensor/rewriting/test_basic.py
浏览文件 @
42a7adb9
...
@@ -77,9 +77,7 @@ from pytensor.tensor.shape import (
...
@@ -77,9 +77,7 @@ from pytensor.tensor.shape import (
Reshape
,
Reshape
,
Shape_i
,
Shape_i
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor1
,
AdvancedIncSubtensor1
,
...
@@ -558,48 +556,6 @@ class TestTile:
...
@@ -558,48 +556,6 @@ class TestTile:
f
(
data
)
f
(
data
)
class
TestUnbroadcast
:
def
setup_method
(
self
):
self
.
mode
=
get_default_mode
()
.
including
(
"canonicalize"
)
def
test_local_useless_unbroadcast
(
self
):
x1
=
tensor
(
dtype
=
"float64"
,
shape
=
(
1
,
2
))
x2
=
tensor
(
dtype
=
"float64"
,
shape
=
(
2
,
1
))
unbroadcast_op
=
Unbroadcast
(
0
)
f
=
function
([
x1
],
unbroadcast_op
(
x1
),
mode
=
self
.
mode
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
f
.
maker
.
fgraph
.
toposort
())
==
1
)
f
=
function
([
x2
],
unbroadcast_op
(
x2
),
mode
=
self
.
mode
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
f
.
maker
.
fgraph
.
toposort
())
==
0
)
def
test_local_unbroadcast_lift
(
self
):
x
=
tensor
(
dtype
=
"float64"
,
shape
=
(
1
,
1
))
y
=
unbroadcast
(
pt
.
exp
(
unbroadcast
(
x
,
0
)),
1
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
FunctionGraph
([
x
],
[
y
],
copy_inputs
=
False
)
.
toposort
()
)
==
2
)
f
=
function
([
x
],
y
,
mode
=
self
.
mode
)
assert
(
sum
(
isinstance
(
node
.
op
,
Unbroadcast
)
for
node
in
f
.
maker
.
fgraph
.
toposort
())
==
1
)
np
.
testing
.
assert_almost_equal
(
f
([[
1
]]),
np
.
exp
([[
1
]]))
class
TestUselessElemwise
:
class
TestUselessElemwise
:
def
setup_method
(
self
):
def
setup_method
(
self
):
self
.
mode
=
get_default_mode
()
.
including
(
"canonicalize"
,
"local_fill_to_alloc"
)
self
.
mode
=
get_default_mode
()
.
including
(
"canonicalize"
,
"local_fill_to_alloc"
)
...
...
tests/tensor/rewriting/test_subtensor.py
浏览文件 @
42a7adb9
...
@@ -28,7 +28,6 @@ from pytensor.tensor.rewriting.subtensor import (
...
@@ -28,7 +28,6 @@ from pytensor.tensor.rewriting.subtensor import (
)
)
from
pytensor.tensor.shape
import
(
from
pytensor.tensor.shape
import
(
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
_shape
,
_shape
,
shape
,
shape
,
specify_shape
,
specify_shape
,
...
@@ -55,7 +54,6 @@ from pytensor.tensor.type import (
...
@@ -55,7 +54,6 @@ from pytensor.tensor.type import (
lscalar
,
lscalar
,
lscalars
,
lscalars
,
matrix
,
matrix
,
row
,
scalar
,
scalar
,
tensor
,
tensor
,
tensor3
,
tensor3
,
...
@@ -921,64 +919,6 @@ class TestLocalSubtensorLift:
...
@@ -921,64 +919,6 @@ class TestLocalSubtensorLift:
assert
len
(
prog
)
==
2
assert
len
(
prog
)
==
2
f
([
1
,
2
,
3
],
4
)
# let debugmode test something
f
([
1
,
2
,
3
],
4
)
# let debugmode test something
def test_basic_8(self):
    # Test that Subtensor(Unbroadcast(x)) gets optimized into
    # Unbroadcast(Subtensor(x)).

    # test basic case
    x = row("x")
    xval = np.random.random((1, 10)).astype(config.floatX)
    assert x.broadcastable == (True, False)
    newx = Unbroadcast(0)(x)
    assert newx.broadcastable == (False, False)

    # After the lift, Subtensor must appear BEFORE Unbroadcast in the toposort.
    f1 = function([x], newx[:2, :5], mode=mode_opt)
    # Check stacktrace was copied over correctly after opt was applied
    assert check_stack_trace(f1, ops_to_check=[Subtensor, Unbroadcast])
    prog = f1.maker.fgraph.toposort()
    assert isinstance(prog[0].op, Subtensor)
    assert isinstance(prog[1].op, Unbroadcast)
    assert (f1(xval) == xval[:2, :5]).all()

    # corner case 1: Unbroadcast changes dims which are dropped through subtensor
    y = tensor(dtype="float64", shape=(1, 10, 1, 3), name="x")
    yval = np.random.random((1, 10, 1, 3)).astype(config.floatX)
    assert y.broadcastable == (True, False, True, False)
    newy = Unbroadcast(0, 2)(y)
    assert newy.broadcastable == (False, False, False, False)

    # Indexing drops axis 2 (one of the unbroadcast axes); the lifted
    # Unbroadcast must be re-targeted to the surviving axes only.
    f2 = function([y], newy[:, 3, 0, :], mode=mode_opt)
    # Check stacktrace was copied over correctly after opt was applied
    assert check_stack_trace(f2, ops_to_check=[Subtensor, Unbroadcast])
    prog = f2.maker.fgraph.toposort()
    assert isinstance(prog[0].op, Subtensor)
    assert isinstance(prog[1].op, Unbroadcast)
    assert (f2(yval) == yval[:, 3, 0, :]).all()

    # corner case 2: subtensor idx_list is shorter than resulting broadcast pattern
    f3 = function([y], newy[:, 3, 0], mode=mode_opt)
    # Check stacktrace was copied over correctly after opt was applied
    assert check_stack_trace(f3, ops_to_check=[Subtensor, Unbroadcast])
    prog = f3.maker.fgraph.toposort()
    assert isinstance(prog[0].op, Subtensor)
    assert isinstance(prog[1].op, Unbroadcast)
    assert (f3(yval) == yval[:, 3, 0]).all()

    # corner case 3: subtensor idx_list is shorter than Unbroadcast.axis
    z = tensor(dtype="float64", shape=(4, 10, 3, 1), name="x")
    zval = np.random.random((4, 10, 3, 1)).astype(config.floatX)
    assert z.broadcastable == (False, False, False, True)
    newz = Unbroadcast(3)(z)
    assert newz.broadcastable == (False, False, False, False)

    # The index list [:, 3, 0] never reaches axis 3, which is the axis
    # being unbroadcast; the lift must still apply cleanly.
    f4 = function([z], newz[:, 3, 0], mode=mode_opt)
    # Check stacktrace was copied over correctly after opt was applied
    assert check_stack_trace(f4, ops_to_check=[Subtensor, Unbroadcast])
    prog = f4.maker.fgraph.toposort()
    assert isinstance(prog[0].op, Subtensor)
    assert isinstance(prog[1].op, Unbroadcast)
    assert (f4(zval) == zval[:, 3, 0]).all()
class
TestLocalSubtensorMerge
:
class
TestLocalSubtensorMerge
:
def
setup_method
(
self
):
def
setup_method
(
self
):
...
...
tests/tensor/test_basic.py
浏览文件 @
42a7adb9
...
@@ -287,7 +287,7 @@ TestAlloc13GradBroadcast = makeBroadcastTester(
...
@@ -287,7 +287,7 @@ TestAlloc13GradBroadcast = makeBroadcastTester(
),
),
)
)
#
un
broadcast a row to a matrix
# broadcast a row to a matrix
TestAllocb1GradBroadcast
=
makeBroadcastTester
(
TestAllocb1GradBroadcast
=
makeBroadcastTester
(
name
=
"Allocb1GradTester"
,
name
=
"Allocb1GradTester"
,
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
),
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
),
...
@@ -299,7 +299,7 @@ TestAllocb1GradBroadcast = makeBroadcastTester(
...
@@ -299,7 +299,7 @@ TestAllocb1GradBroadcast = makeBroadcastTester(
),
),
)
)
#
un
broadcast a row to a tensor3
# broadcast a row to a tensor3
TestAllocb2GradBroadcast
=
makeBroadcastTester
(
TestAllocb2GradBroadcast
=
makeBroadcastTester
(
name
=
"Allocb2GradTester"
,
name
=
"Allocb2GradTester"
,
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
,
s3
),
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
,
s3
),
...
@@ -311,7 +311,7 @@ TestAllocb2GradBroadcast = makeBroadcastTester(
...
@@ -311,7 +311,7 @@ TestAllocb2GradBroadcast = makeBroadcastTester(
),
),
)
)
#
un
broadcast a col to a matrix
# broadcast a col to a matrix
TestAllocb3GradBroadcast
=
makeBroadcastTester
(
TestAllocb3GradBroadcast
=
makeBroadcastTester
(
name
=
"Allocb3GradTester"
,
name
=
"Allocb3GradTester"
,
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
),
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
),
...
@@ -323,7 +323,7 @@ TestAllocb3GradBroadcast = makeBroadcastTester(
...
@@ -323,7 +323,7 @@ TestAllocb3GradBroadcast = makeBroadcastTester(
),
),
)
)
#
un
broadcast a col to a tensor3
# broadcast a col to a tensor3
TestAllocb4GradBroadcast
=
makeBroadcastTester
(
TestAllocb4GradBroadcast
=
makeBroadcastTester
(
name
=
"Allocb4GradTester"
,
name
=
"Allocb4GradTester"
,
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
,
s3
),
op
=
lambda
x
:
alloc
(
x
,
s1
,
s2
,
s3
),
...
@@ -336,7 +336,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
...
@@ -336,7 +336,7 @@ TestAllocb4GradBroadcast = makeBroadcastTester(
)
)
# Partial
un
broadcast of a dimshuffled input
# Partial broadcast of a dimshuffled input
TestAllocDimshuffleGradBroadcast
=
makeBroadcastTester
(
TestAllocDimshuffleGradBroadcast
=
makeBroadcastTester
(
name
=
"Allocb4GradTester"
,
name
=
"Allocb4GradTester"
,
op
=
lambda
x
:
alloc
(
x
.
dimshuffle
(
"x"
,
"x"
,
0
),
1
,
s2
,
s3
),
op
=
lambda
x
:
alloc
(
x
.
dimshuffle
(
"x"
,
"x"
,
0
),
1
,
s2
,
s3
),
...
...
tests/tensor/test_shape.py
浏览文件 @
42a7adb9
...
@@ -19,14 +19,12 @@ from pytensor.tensor.shape import (
...
@@ -19,14 +19,12 @@ from pytensor.tensor.shape import (
Shape
,
Shape
,
Shape_i
,
Shape_i
,
SpecifyShape
,
SpecifyShape
,
Unbroadcast
,
_specify_shape
,
_specify_shape
,
reshape
,
reshape
,
shape
,
shape
,
shape_tuple
,
shape_tuple
,
specify_broadcastable
,
specify_broadcastable
,
specify_shape
,
specify_shape
,
unbroadcast
,
)
)
from
pytensor.tensor.subtensor
import
Subtensor
from
pytensor.tensor.subtensor
import
Subtensor
from
pytensor.tensor.type
import
(
from
pytensor.tensor.type
import
(
...
@@ -696,66 +694,6 @@ def test_get_vector_length():
...
@@ -696,66 +694,6 @@ def test_get_vector_length():
assert
get_vector_length
(
x
)
==
10
assert
get_vector_length
(
x
)
==
10
class TestUnbroadcast:
    """Tests for the `Unbroadcast` Op and the `unbroadcast` helper."""

    def test_basic(self):
        # A matrix with unknown dims has nothing broadcastable, so
        # `unbroadcast` is a no-op and must return the input itself.
        x = matrix()
        assert unbroadcast(x, 0) is x
        assert unbroadcast(x, 1) is x
        assert unbroadcast(x, 1, 0) is x
        assert unbroadcast(x, 0, 1) is x

        # A row is broadcastable on axis 0: unbroadcasting that axis must
        # produce a new variable, while axis 1 alone remains a no-op.
        x = row()
        assert unbroadcast(x, 0) is not x
        assert unbroadcast(x, 1) is x
        assert unbroadcast(x, 1, 0) is not x
        assert unbroadcast(x, 0, 1) is not x

        # Applying the same unbroadcast twice does not stack a second node.
        assert unbroadcast(unbroadcast(x, 0), 0).owner.inputs[0] is x

    def test_infer_shape(self):
        # Shape inference through Unbroadcast: the compiled shape graph
        # should reduce to Shape_i nodes feeding a MakeVector.
        x = matrix()
        y = unbroadcast(x, 0)
        f = pytensor.function([x], y.shape)
        assert (f(np.zeros((2, 5), dtype=config.floatX)) == [2, 5]).all()
        topo = f.maker.fgraph.toposort()
        if config.mode != "FAST_COMPILE":
            assert len(topo) == 3
            assert isinstance(topo[0].op, Shape_i)
            assert isinstance(topo[1].op, Shape_i)
            assert isinstance(topo[2].op, MakeVector)

        # For a row, dim 0 is statically 1, so only one Shape_i is needed.
        x = row()
        y = unbroadcast(x, 0)
        f = pytensor.function([x], y.shape)
        assert (f(np.zeros((1, 5), dtype=config.floatX)) == [1, 5]).all()
        topo = f.maker.fgraph.toposort()
        if config.mode != "FAST_COMPILE":
            assert len(topo) == 2
            assert isinstance(topo[0].op, Shape_i)
            assert isinstance(topo[1].op, MakeVector)

    def test_error_checks(self):
        # Axes must be integers.
        with pytest.raises(TypeError, match="needs integer axes"):
            Unbroadcast(0.0)

        # Axis out of range for the input's number of dimensions.
        with pytest.raises(ValueError, match="^Trying to unbroadcast"):
            Unbroadcast(1)(vector())
class TestUnbroadcastInferShape(utt.InferShapeTester):
    """Shape-inference check for `Unbroadcast` via the generic infer-shape tester."""

    def test_basic(self):
        rng = np.random.default_rng(3453)
        # Axes 0 and 2 are broadcastable (static size 1) and get unbroadcast.
        adtens4 = tensor(dtype="float64", shape=(1, 1, 1, None))
        adtens4_val = rng.random((1, 1, 1, 3)).astype(config.floatX)
        self._compile_and_check(
            [adtens4],
            [Unbroadcast(0, 2)(adtens4)],
            [adtens4_val],
            Unbroadcast,
            warn=False,
        )
def
test_shape_tuple
():
def
test_shape_tuple
():
x
=
Variable
(
MyType2
(),
None
,
None
)
x
=
Variable
(
MyType2
(),
None
,
None
)
assert
shape_tuple
(
x
)
==
()
assert
shape_tuple
(
x
)
==
()
...
@@ -882,16 +820,3 @@ class TestVectorize:
...
@@ -882,16 +820,3 @@ class TestVectorize:
match
=
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
,
match
=
"Invalid number of shape arguments passed into vectorize node of SpecifyShape"
,
):
):
vectorize_node
(
node
,
tns
,
*
(
5
,
3
,
2
,
x
))
vectorize_node
(
node
,
tns
,
*
(
5
,
3
,
2
,
x
))
def test_unbroadcast(self):
    """Vectorizing an `Unbroadcast` node shifts its axes by the batch dims."""
    core_input = tensor(
        shape=(
            1,
            1,
        )
    )
    batched_input = tensor(shape=(4, 1, 1, 1))

    core_node = unbroadcast(core_input, 0).owner
    batched_node = vectorize_node(core_node, batched_input)

    # Two batch dimensions are prepended, so core axis 0 becomes axis 2.
    expected = [unbroadcast(batched_input, 2)]
    assert equal_computations(batched_node.outputs, expected)
tests/test_rop.py
浏览文件 @
42a7adb9
...
@@ -28,7 +28,6 @@ from pytensor.graph.basic import Apply
...
@@ -28,7 +28,6 @@ from pytensor.graph.basic import Apply
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.tensor.math
import
argmax
,
dot
from
pytensor.tensor.math
import
argmax
,
dot
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.shape
import
unbroadcast
from
pytensor.tensor.type
import
matrix
,
vector
from
pytensor.tensor.type
import
matrix
,
vector
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
...
@@ -252,13 +251,6 @@ class TestRopLop(RopLopChecker):
...
@@ -252,13 +251,6 @@ class TestRopLop(RopLopChecker):
# vector
# vector
self
.
check_rop_lop
(
self
.
x
[:
4
]
.
dimshuffle
(
"x"
,
0
)
.
sum
(
axis
=
0
),
(
4
,))
self
.
check_rop_lop
(
self
.
x
[:
4
]
.
dimshuffle
(
"x"
,
0
)
.
sum
(
axis
=
0
),
(
4
,))
def test_unbroadcast(self):
    # I need the sum, because the setup expects the output to be a
    # vector
    # x[:4] is padded to shape (1, 4) via dimshuffle, axis 0 is then
    # unbroadcast, and the sum reduces back to a length-1 vector so the
    # R-op/L-op checker receives the expected output shape.
    self.check_rop_lop(
        unbroadcast(self.x[:4].dimshuffle("x", 0), 0).sum(axis=1), (1,)
    )
def
test_join
(
self
):
def
test_join
(
self
):
tv
=
np
.
asarray
(
self
.
rng
.
uniform
(
size
=
(
10
,)),
pytensor
.
config
.
floatX
)
tv
=
np
.
asarray
(
self
.
rng
.
uniform
(
size
=
(
10
,)),
pytensor
.
config
.
floatX
)
t
=
pytensor
.
shared
(
tv
)
t
=
pytensor
.
shared
(
tv
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论