testgroup / pytensor — Commit 6898f749

Authored Jul 13, 2023 by Ricardo Vieira; committed by Ricardo Vieira on Aug 07, 2023
Parent: 5f809cfe

Remove BroadcastTo

Showing 8 changed files with 14 additions and 541 deletions:
pytensor/link/jax/dispatch/extra_ops.py    +0   -14
pytensor/link/numba/dispatch/extra_ops.py  +0   -25
pytensor/tensor/extra_ops.py               +2   -143
pytensor/tensor/rewriting/extra_ops.py     +1   -47
tests/link/jax/test_extra_ops.py           +1   -24
tests/link/numba/test_extra_ops.py         +0   -35
tests/tensor/rewriting/test_extra_ops.py   +1   -72
tests/tensor/test_extra_ops.py             +9   -181
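A minimal sketch (not part of the commit) of the net effect: after this change, `broadcast_to` lowers to the existing `Alloc` op instead of building a dedicated `BroadcastTo` node. Import paths are those of the files touched below.

import pytensor.tensor as at
from pytensor.tensor.basic import Alloc
from pytensor.tensor.extra_ops import broadcast_to

x = at.vector("x")
y = broadcast_to(x, (3, 2))
# the returned graph is a plain Alloc apply node; no BroadcastTo op remains
assert isinstance(y.owner.op, Alloc)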
pytensor/link/jax/dispatch/extra_ops.py

@@ -3,10 +3,8 @@ import warnings
 import jax.numpy as jnp
 
 from pytensor.link.jax.dispatch.basic import jax_funcify
-from pytensor.tensor.basic import infer_static_shape
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -102,18 +100,6 @@ def jax_funcify_RavelMultiIndex(op, **kwargs):
     return ravelmultiindex
 
 
-@jax_funcify.register(BroadcastTo)
-def jax_funcify_BroadcastTo(op, node, **kwargs):
-    shape = node.inputs[1:]
-    static_shape = infer_static_shape(shape)[1]
-
-    def broadcast_to(x, *shape):
-        shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
-        return jnp.broadcast_to(x, shape)
-
-    return broadcast_to
-
-
 @jax_funcify.register(FillDiagonal)
 def jax_funcify_FillDiagonal(op, **kwargs):
     def filldiagonal(value, diagonal):
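The one subtlety in the removed JAX dispatcher is the shape-fixing trick: statically inferred extents are preferred over traced values, so JAX sees concrete integers wherever possible. A standalone sketch of that `zip`, with illustrative values:

# st: statically known extent (or None); s: value only available at trace time
static_shape = (3, None)
traced_shape = (3, 2)
shape = tuple(st if st is not None else s for s, st in zip(traced_shape, static_shape))
assert shape == (3, 2)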
pytensor/link/numba/dispatch/extra_ops.py

@@ -2,7 +2,6 @@ import warnings
 import numba
 import numpy as np
-from numba.misc.special import literal_unroll
 
 from pytensor import config
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -10,7 +9,6 @@ from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
 from pytensor.raise_op import CheckAndRaise
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CumOp,
     FillDiagonal,
     FillDiagonalOffset,
@@ -353,29 +351,6 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
     return searchsorted
 
 
-@numba_funcify.register(BroadcastTo)
-def numba_funcify_BroadcastTo(op, node, **kwargs):
-    create_zeros_tuple = numba_basic.create_tuple_creator(
-        lambda _: 0, len(node.inputs) - 1
-    )
-
-    # TODO broadcastable checks
-    @numba_basic.numba_njit
-    def broadcast_to(x, *shape):
-        scalars_shape = create_zeros_tuple()
-
-        i = 0
-        for s_i in literal_unroll(shape):
-            scalars_shape = numba_basic.tuple_setitem(
-                scalars_shape, i, numba_basic.to_scalar(s_i)
-            )
-            i += 1
-
-        return np.broadcast_to(x, scalars_shape)
-
-    return broadcast_to
-
-
 @numba_funcify.register(CheckAndRaise)
 def numba_funcify_CheckAndRaise(op, node, **kwargs):
     error = op.exc_type
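The removed Numba implementation spelled out what `np.broadcast_to` needs in nopython mode: a tuple of typed scalars, built generically with `create_tuple_creator`/`tuple_setitem`/`literal_unroll` because the arity is only known per node. A fixed-arity sketch of the same end result, assuming a Numba version that supports `np.broadcast_to`:

import numba
import numpy as np

@numba.njit
def broadcast_to_sketch(x, d0, d1):
    # the generated function did this for any number of dims; two is enough here
    return np.broadcast_to(x, (d0, d1))

assert broadcast_to_sketch(np.arange(2.0), 3, 2).shape == (3, 2)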
pytensor/tensor/extra_ops.py

@@ -23,7 +23,7 @@ from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
-from pytensor.tensor.basic import get_vector_length, second
+from pytensor.tensor.basic import alloc, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
     return tuple(result_dims)
 
 
-class BroadcastTo(COp):
-    """An `Op` for `numpy.broadcast_to`."""
-
-    _output_type_depends_on_input_value = True
-
-    __props__ = ()
-
-    view_map = {0: [0]}
-
-    def __call__(self, a, shape, **kwargs):
-        return super().__call__(a, *shape, **kwargs)
-
-    def make_node(self, a, *shape):
-        a = at.as_tensor_variable(a)
-
-        shape, static_shape = at.infer_static_shape(shape)
-
-        if len(shape) < a.ndim:
-            raise ValueError(
-                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
-            )
-
-        out = TensorType(dtype=a.type.dtype, shape=static_shape)()
-
-        # Attempt to prevent in-place operations on this view-based output
-        out.tag.indestructible = True
-
-        return Apply(self, [a] + shape, [out])
-
-    def perform(self, node, inputs, output_storage):
-        a, *shape = inputs
-        z = output_storage[0]
-        z[0] = np.broadcast_to(a, shape)
-
-    def grad(self, inputs, outputs_gradients):
-        a, *shape = inputs
-        (dout,) = outputs_gradients
-
-        # Determine the dimensions that were added by broadcasting
-        new_dims = list(range(dout.ndim - a.ndim))
-
-        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
-
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
-
-        return [d_wrt_a] + [
-            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
-        ]
-
-    def infer_shape(self, fgraph, node, ins_shapes):
-        return [node.inputs[1:]]
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        inp_dims = node.inputs[0].ndim
-        out_dims = node.outputs[0].ndim
-        new_dims = out_dims - inp_dims
-
-        (x, *shape) = inputs
-        (out,) = outputs
-        fail = sub["fail"]
-
-        # TODO: Could just use `PyArray_Return`, no?
-        dims_array = ", ".join(
-            [
-                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
-                for i, shape in enumerate(shape)
-            ]
-        )
-
-        src = (
-            """
-            npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
-
-            NpyIter *iter;
-            PyArrayObject *ops[1] = {%(x)s};
-            npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
-            npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
-            PyArray_Descr *op_dtypes[1] = {NULL};
-            int oa_ndim = %(out_dims)s;
-            int* op_axes[1] = {NULL};
-            npy_intp buffersize = 0;
-
-            for(int i = 0; i < %(inp_dims)s; i++)
-            {
-                if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
-                {
-                    PyErr_Format(PyExc_ValueError,
-                                 "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
-                                 i,
-                                 (long long int) itershape[i + %(new_dims)s],
-                                 (long long int) PyArray_DIMS(%(x)s)[i]
-                    );
-                    %(fail)s
-                }
-            }
-
-            iter = NpyIter_AdvancedNew(
-                1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
-            );
-            %(out)s = NpyIter_GetIterView(iter, 0);
-
-            if(%(out)s == NULL){
-                NpyIter_Deallocate(iter);
-                %(fail)s;
-            }
-
-            if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
-                %(fail)s;
-            }
-            """
-            % locals()
-        )
-
-        return src
-
-    def c_code_cache_version(self):
-        return (2,)
-
-
-broadcast_to_ = BroadcastTo()
-
-
 def geomspace(start, end, steps, base=10.0):
     from pytensor.tensor.math import log
@@ -1762,13 +1627,7 @@ def broadcast_to(
     broadcasted array may refer to a single memory location.
 
     """
-    x = at.as_tensor(x)
-    shape_len = get_vector_length(shape)
-
-    if x.ndim == 0 and shape_len == 0:
-        return x
-
-    return broadcast_to_(x, shape)
+    return alloc(x, *shape)
 
 
 def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
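The least obvious part of the deleted class is `grad`: broadcasting is a linear map, so its adjoint sums the output gradient over the dimensions that broadcasting prepended or expanded. A NumPy-only sketch of that rule, with illustrative shapes:

import numpy as np

a_shape = (5, 1)
dout = np.random.random((6, 2, 5, 3))  # gradient w.r.t. the broadcast output

# sum over the leading dims that broadcasting prepended ...
new_dims = tuple(range(dout.ndim - len(a_shape)))
d_wrt_a = dout.sum(axis=new_dims)

# ... and over every length-1 input dim that was expanded, keeping dims
bcast_sums = tuple(
    i for i, (a_s, s_s) in enumerate(zip(a_shape, d_wrt_a.shape)) if a_s == 1 and s_s != 1
)
if bcast_sums:
    d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)

assert d_wrt_a.shape == a_shape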
pytensor/tensor/rewriting/extra_ops.py

@@ -2,7 +2,7 @@ import pytensor.scalar.basic as aes
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.basic import Alloc, as_tensor_variable
 from pytensor.tensor.elemwise import Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique
+from pytensor.tensor.extra_ops import Repeat, Unique
 from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
@@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node):
     return [new_x]
 
 
-@register_useless
-@register_canonicalize
-@node_rewriter([Unique])
-def local_Unique_BroadcastTo_lift(fgraph, node):
-    """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``.
-
-    This isn't really so much a lift as a "reduction/consumption".
-    """
-    if not isinstance(node.op, Unique):
-        return False
-
-    if (
-        node.op.return_index
-        or node.op.return_inverse
-        or node.op.return_counts
-        or node.op.axis is not None
-    ):
-        return False
-
-    bcast_var = node.inputs[0]
-
-    if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)):
-        return False
-
-    bcasted_var, *bcast_shape = bcast_var.owner.inputs
-
-    new_unique, *_ = node.op.make_node(bcasted_var).outputs
-
-    old_out = node.outputs[0]
-    new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
-    return [new_x]
-
-
 @register_useless
 @register_canonicalize
 @node_rewriter([Unique])
@@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node):
     old_out = node.outputs[0]
     new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype)
     return [new_x]
-
-
-@register_useless
-@register_canonicalize
-@node_rewriter([BroadcastTo])
-def local_remove_scalar_BroadcastTo(fgraph, node):
-    bcast_shape = node.inputs[1:]
-
-    if not bcast_shape:
-        bcasted_var = node.inputs[0]
-        # If this isn't true, the graph is invalid
-        assert bcasted_var.ndim == 0
-        return [bcasted_var]
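Both deleted rewrites matched `BroadcastTo` nodes specifically. Since `broadcast_to` now emits `Alloc`, the `Alloc`-based rewrites (e.g. `local_Unique_Alloc_lift`, visible in the context above) cover the same graphs. A hedged sketch of how one might check this; the expected outcome is stated in a comment rather than asserted:

import pytensor.tensor as at
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.tensor.extra_ops import broadcast_to, unique

x = at.vector("x")
y = unique(broadcast_to(x, (3, 2)))
y_opt = rewrite_graph(y, include=["canonicalize"])
# expected: the Alloc lift consumes the broadcast, leaving unique(x)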
tests/link/jax/test_extra_ops.py

@@ -7,7 +7,7 @@ from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
 from pytensor.tensor import extra_ops as at_extra_ops
-from pytensor.tensor.type import matrix, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py
@@ -63,29 +63,6 @@ def test_extra_ops():
     )
 
 
-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(
-                vector("x"), np.random.random(size=(2,)).astype(config.floatX)
-            ),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    out = at_extra_ops.broadcast_to(x, shape)
-    fgraph = FunctionGraph(outputs=[out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 @pytest.mark.xfail(
     version_parse(jax.__version__) >= version_parse("0.2.12"),
     reason="Omnistaging cannot be disabled",
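With the op and its dedicated test gone, the same computation still runs on the JAX backend through the existing `Alloc` dispatch. A small end-to-end sketch (not part of the commit), assuming JAX is installed:

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_to

x = at.dvector("x")
f = pytensor.function([x], broadcast_to(x, (3, 2)), mode="JAX")
assert np.array_equal(f(np.array([1.0, 2.0])), np.broadcast_to([1.0, 2.0], (3, 2)))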
tests/link/numba/test_extra_ops.py

@@ -36,41 +36,6 @@ def test_Bartlett(val):
     )
 
 
-@pytest.mark.parametrize(
-    "x, shape",
-    [
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]),
-        ),
-        (
-            set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)),
-            [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
-        ),
-    ],
-)
-def test_BroadcastTo(x, shape):
-    g = extra_ops.BroadcastTo()(x, shape)
-    g_fg = FunctionGraph(outputs=[g])
-    compare_numba_and_py(
-        g_fg,
-        [
-            i.tag.test_value
-            for i in g_fg.inputs
-            if not isinstance(i, (SharedVariable, Constant))
-        ],
-    )
-
-
 @pytest.mark.parametrize(
     "val, axis, mode",
     [
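Likewise for the Numba backend, where `Alloc` also handles shape values only known at run time; a sketch assuming Numba is installed:

import numpy as np
import pytensor
import pytensor.tensor as at
from pytensor.tensor.extra_ops import broadcast_to

x = at.dvector("x")
s = at.lscalar("s")
f = pytensor.function([x, s], broadcast_to(x, (s, 2)), mode="NUMBA")
assert f(np.array([1.0, 2.0]), 3).shape == (3, 2)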
tests/tensor/rewriting/test_extra_ops.py

@@ -8,7 +8,7 @@ from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.utils import rewrite_graph
 from pytensor.tensor.basic import Alloc, alloc, as_tensor_variable, second
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
-from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique
+from pytensor.tensor.extra_ops import Repeat, Unique, repeat, unique
 from pytensor.tensor.type import dscalar
@@ -103,64 +103,6 @@ def test_local_Unique_Alloc_lift(
     assert np.array_equal(y_exp_val, y_val)
 
 
-@pytest.mark.parametrize(
-    "x_val, axis, new_shape",
-    [
-        (np.array(-10, dtype=np.int64), None, (2, 3)),
-        (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)),
-    ],
-)
-@pytest.mark.parametrize("return_index", [False])
-@pytest.mark.parametrize("return_counts", [False])
-@pytest.mark.parametrize("return_inverse", [False])
-def test_local_Unique_BroadcastTo(
-    x_val, axis, new_shape, return_index, return_counts, return_inverse
-):
-    x = as_tensor_variable(x_val).type()
-    y = unique(
-        BroadcastTo()(x, tuple(new_shape)),
-        return_index=return_index,
-        return_counts=return_counts,
-        return_inverse=return_inverse,
-        axis=axis,
-    )
-
-    if isinstance(y, list):
-        y, *_ = y
-
-    # This approach allows us to directly confirm that `x` is in the result.
-    y_fg = FunctionGraph(outputs=[y], copy_inputs=False)
-
-    y_rewritten_fg = rewrite_graph(
-        y_fg,
-        clone=False,
-        include=["canonicalize", "local_Unique_BroadcastTo_lift"],
-        exclude=["local_Unique_scalar"],
-    )
-    y_rewritten = y_rewritten_fg.outputs[0]
-
-    y_rewritten_start = y_rewritten
-
-    assert isinstance(y_rewritten_start.owner.op, Unique)
-    assert y_rewritten_start.owner.inputs[0] == x
-    assert not any(
-        isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes
-    )
-
-    default_mode = get_default_mode()
-    # The rewrite has already been applied to `y_rewritten`, so we can--and
-    # should--exclude it from the compilation of both our reference, `y`, and
-    # the rewritten result, `y_rewritten`.
-    rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift")
-    y_fn = function([x], [y, y_rewritten], mode=rewrite_mode)
-
-    # Make sure that the original `BroadcastTo` is used to compute the
-    # reference `y` result
-    assert any(
-        isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes
-    )
-
-    y_exp_val, y_val = y_fn(x_val)
-    assert np.array_equal(y_exp_val, y_val)
-
-
 @pytest.mark.parametrize(
     "x_val, unique_axis, repeats, repeat_axis",
     [
@@ -287,16 +229,3 @@ def test_local_Unique_second(
     y_exp_val, y_val = y_fn(x_val)
     assert np.array_equal(y_exp_val, y_val)
-
-
-def test_local_remove_scalar_BroadcastTo():
-    x = dscalar()
-    y = BroadcastTo()(x, ())
-
-    assert isinstance(y.owner.op, BroadcastTo)
-
-    res = rewrite_graph(
-        y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"]
-    )
-
-    assert res is x
tests/tensor/test_extra_ops.py

@@ -8,14 +8,12 @@ from pytensor import function
 from pytensor import tensor as at
 from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Constant, applys_between
-from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
+from pytensor.graph.basic import Constant, applys_between, equal_computations
 from pytensor.raise_op import Assert
+from pytensor.tensor import alloc
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import (
     Bartlett,
-    BroadcastTo,
     CpuContiguous,
     CumOp,
     FillDiagonal,
@@ -47,7 +45,6 @@ from pytensor.tensor.extra_ops import (
     to_one_hot,
     unravel_index,
 )
-from pytensor.tensor.subtensor import AdvancedIncSubtensor
 from pytensor.tensor.type import (
     TensorType,
     dmatrix,
@@ -61,7 +58,6 @@ from pytensor.tensor.type import (
     lscalar,
     matrix,
     scalar,
-    tensor,
     tensor3,
     vector,
 )
@@ -1246,183 +1242,15 @@ def test_broadcast_shape_symbolic_one_symbolic():
     assert res_shape[2].data == 3
 
 
-class TestBroadcastTo(utt.InferShapeTester):
-    def setup_method(self):
-        super().setup_method()
-        self.op_class = BroadcastTo
-        self.op = broadcast_to
-
-    def test_avoid_useless_scalars(self):
-        x = scalar()
-        y = broadcast_to(x, ())
-        assert y is x
-
-    def test_avoid_useless_subtensors(self):
-        x = scalar()
-        y = broadcast_to(x, (1, 2))
-        # There shouldn't be any unnecessary `Subtensor` operations
-        # (e.g. from `at.as_tensor((1, 2))[0]`)
-        assert y.owner.inputs[1].owner is None
-        assert y.owner.inputs[2].owner is None
-
-    @pytest.mark.parametrize("linker", ["cvm", "py"])
-    def test_perform(self, linker):
-        a = pytensor.shared(np.full((3, 1, 1), 5))
-        s_0 = iscalar("s_0")
-        s_1 = iscalar("s_1")
-        shape = (s_0, s_1, 1)
-
-        bcast_res = broadcast_to(a, shape)
-        assert bcast_res.broadcastable == (False, False, True)
-
-        bcast_fn = pytensor.function(
-            [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
-        )
-        bcast_fn.vm.allow_gc = False
-
-        bcast_at = bcast_fn(3, 4)
-        bcast_np = np.broadcast_to(5, (3, 4, 1))
-
-        assert np.array_equal(bcast_at, bcast_np)
-
-        with pytest.raises(ValueError):
-            bcast_fn(5, 4)
-
-        if linker != "py":
-            bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
-            bcast_in = bcast_fn.vm.storage_map[a]
-            bcast_out = bcast_fn.vm.storage_map[bcast_var]
-
-            assert np.shares_memory(bcast_out[0], bcast_in[0])
-
-    def test_make_node_error_handling(self):
-        with pytest.raises(
-            ValueError,
-            match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims",
-        ):
-            broadcast_to(at.zeros((3, 4)), (5,))
-
-    @pytest.mark.skipif(
-        not config.cxx,
-        reason="G++ not available, so we need to skip this test.",
-    )
-    @pytest.mark.parametrize("valid", (True, False))
-    def test_memory_leak(self, valid):
-        import gc
-        import tracemalloc
-
-        from pytensor.link.c.cvm import CVM
-
-        n = 100_000
-        x = pytensor.shared(np.ones((1, n), dtype=np.float64))
-        y = broadcast_to(x, (5, n))
-
-        f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
-        assert isinstance(f.vm, CVM)
-
-        assert len(f.maker.fgraph.apply_nodes) == 2
-        assert any(
-            isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes
-        )
-
-        tracemalloc.start()
-
-        blocks_last = None
-        block_diffs = []
-        for i in range(1, 50):
-            if valid:
-                x.set_value(np.ones((1, n)))
-                _ = f()
-            else:
-                x.set_value(np.ones((2, n)))
-                try:
-                    _ = f()
-                except ValueError:
-                    pass
-                else:
-                    raise RuntimeError("Should have failed")
-            _ = gc.collect()
-            blocks_i, _ = tracemalloc.get_traced_memory()
-            if blocks_last is not None:
-                blocks_diff = (blocks_i - blocks_last) // 10**3
-                block_diffs.append(blocks_diff)
-            blocks_last = blocks_i
-
-        tracemalloc.stop()
-        assert np.all(np.array(block_diffs) <= (0 + 1e-8))
-
-    @pytest.mark.parametrize(
-        "fn,input_dims",
-        [
-            [lambda x: broadcast_to(x, (1,)), (1,)],
-            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
-            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
-            [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
-        ],
-    )
-    def test_gradient(self, fn, input_dims):
-        rng = np.random.default_rng(43)
-        utt.verify_grad(
-            fn,
-            [rng.random(input_dims).astype(config.floatX)],
-            n_tests=1,
-            rng=rng,
-        )
-
-    def test_infer_shape(self):
-        rng = np.random.default_rng(43)
-        a = tensor(dtype=config.floatX, shape=(None, 1, None))
-        shape = list(a.shape)
-        out = self.op(a, shape)
-        self._compile_and_check(
-            [a] + shape,
-            [out],
-            [rng.random((2, 1, 3)).astype(config.floatX), 2, 1, 3],
-            self.op_class,
-        )
-
-        a = tensor(dtype=config.floatX, shape=(None, 1, None))
-        shape = [iscalar() for i in range(4)]
-        self._compile_and_check(
-            [a] + shape,
-            [self.op(a, shape)],
-            [rng.random((2, 1, 3)).astype(config.floatX), 6, 2, 5, 3],
-            self.op_class,
-        )
-
-    def test_inplace(self):
-        """Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``."""
-        a = at.zeros((5,))
-        d = at.vector("d")
-        c = at.set_subtensor(a[np.r_[0, 1, 3]], d)
-        b = broadcast_to(c, (5,))
-        q = b[np.r_[0, 1, 3]]
-        e = at.set_subtensor(q, np.r_[0, 0, 0])
-
-        opts = RewriteDatabaseQuery(include=["inplace"])
-        py_mode = Mode("py", opts)
-        e_fn = function([d], e, mode=py_mode)
-
-        advincsub_node = e_fn.maker.fgraph.outputs[0].owner
-        assert isinstance(advincsub_node.op, AdvancedIncSubtensor)
-        assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo)
-
-        assert advincsub_node.op.inplace is False
-
-    def test_rebuild(self):
-        x = vector(shape=(50,))
-        x_test = np.zeros((50,), dtype=config.floatX)
-        i = 0
-        y = broadcast_to(i, x.shape)
-        assert y.type.shape == (50,)
-        assert y.shape.eval({x: x_test}) == (50,)
-        assert y.eval({x: x_test}).shape == (50,)
-
-        x_new = vector(shape=(100,))
-        x_new_test = np.zeros((100,), dtype=config.floatX)
-        y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
-        assert y_new.type.shape == (100,)
-        assert y_new.shape.eval({x_new: x_new_test}) == (100,)
-        assert y_new.eval({x_new: x_new_test}).shape == (100,)
+def test_broadcast_to():
+    x = vector("x")
+    y1 = scalar(dtype="int64")
+    y2 = scalar(dtype="int64")
+
+    assert equal_computations(
+        [broadcast_to(x, (y1, y2))],
+        [alloc(x, y1, y2)],
+    )
 
 
 def test_broadcast_arrays():
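One semantic nuance worth recording: the deleted Op returned a NumPy broadcast view of its input (`view_map = {0: [0]}`), whereas `Alloc` materializes its output. A NumPy-only illustration of the old Op's view semantics; this makes no claim about `Alloc` internals:

import numpy as np

x = np.arange(2.0)
view = np.broadcast_to(x, (3, 2))  # what the removed perform/c_code produced
assert np.shares_memory(view, x)   # a read-only strided view
copy = np.ascontiguousarray(view)  # an Alloc-style materialized result
assert not np.shares_memory(copy, x)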