Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e593b0ac
提交
e593b0ac
authored
12月 15, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
12月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use NumPy C API to perform DimShuffle steps in its C implementation
上级
223ee154
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
115 行增加
和
151 行删除
+115
-151
elemwise.py
aesara/gpuarray/elemwise.py
+1
-1
dispatch.py
aesara/link/jax/dispatch.py
+1
-1
elemwise.py
aesara/link/numba/dispatch/elemwise.py
+2
-2
dimshuffle.c
aesara/tensor/c_code/dimshuffle.c
+57
-80
elemwise.py
aesara/tensor/elemwise.py
+27
-52
inplace.py
aesara/tensor/inplace.py
+1
-1
test_jax.py
tests/link/test_jax.py
+1
-1
test_numba.py
tests/link/test_numba.py
+3
-10
test_elemwise.py
tests/tensor/test_elemwise.py
+22
-3
没有找到文件。
aesara/gpuarray/elemwise.py
浏览文件 @
e593b0ac
...
@@ -468,7 +468,7 @@ class GpuDimShuffle(DimShuffle):
...
@@ -468,7 +468,7 @@ class GpuDimShuffle(DimShuffle):
res
=
input
res
=
input
res
=
res
.
transpose
(
self
.
shuffle
+
self
.
drop
)
res
=
res
.
transpose
(
self
.
transposition
)
shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
for
augm
in
self
.
augment
:
for
augm
in
self
.
augment
:
...
...
aesara/link/jax/dispatch.py
浏览文件 @
e593b0ac
...
@@ -710,7 +710,7 @@ def jax_funcify_Reshape(op, **kwargs):
...
@@ -710,7 +710,7 @@ def jax_funcify_Reshape(op, **kwargs):
def
jax_funcify_DimShuffle
(
op
,
**
kwargs
):
def
jax_funcify_DimShuffle
(
op
,
**
kwargs
):
def
dimshuffle
(
x
):
def
dimshuffle
(
x
):
res
=
jnp
.
transpose
(
x
,
op
.
shuffle
+
op
.
drop
)
res
=
jnp
.
transpose
(
x
,
op
.
transposition
)
shape
=
list
(
res
.
shape
[:
len
(
op
.
shuffle
)])
shape
=
list
(
res
.
shape
[:
len
(
op
.
shuffle
)])
...
...
aesara/link/numba/dispatch/elemwise.py
浏览文件 @
e593b0ac
...
@@ -319,7 +319,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
...
@@ -319,7 +319,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register
(
DimShuffle
)
@numba_funcify.register
(
DimShuffle
)
def
numba_funcify_DimShuffle
(
op
,
**
kwargs
):
def
numba_funcify_DimShuffle
(
op
,
**
kwargs
):
shuffle
=
tuple
(
op
.
shuffle
)
shuffle
=
tuple
(
op
.
shuffle
)
drop
=
tuple
(
op
.
drop
)
transposition
=
tuple
(
op
.
transposition
)
augment
=
tuple
(
op
.
augment
)
augment
=
tuple
(
op
.
augment
)
inplace
=
op
.
inplace
inplace
=
op
.
inplace
...
@@ -352,7 +352,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
...
@@ -352,7 +352,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
@numba.njit
@numba.njit
def
dimshuffle_inner
(
x
,
shuffle
):
def
dimshuffle_inner
(
x
,
shuffle
):
res
=
np
.
transpose
(
x
,
shuffle
+
drop
)
res
=
np
.
transpose
(
x
,
transposition
)
shuffle_shape
=
res
.
shape
[:
len
(
shuffle
)]
shuffle_shape
=
res
.
shape
[:
len
(
shuffle
)]
new_shape
=
create_zeros_tuple
()
new_shape
=
create_zeros_tuple
()
...
...
aesara/tensor/c_code/dimshuffle.c
浏览文件 @
e593b0ac
#section support_code_apply
#section support_code_apply
int
APPLY_SPECIFIC
(
cpu_dimshuffle
)(
PyArrayObject
*
input
,
PyArrayObject
**
res
,
PARAMS_TYPE
*
params
)
{
int
APPLY_SPECIFIC
(
cpu_dimshuffle
)(
PyArrayObject
*
input
,
PyArrayObject
**
res
,
npy_bool
*
input_broadcastable
;
PARAMS_TYPE
*
params
)
{
npy_int64
*
new_order
;
npy_intp
nd_in
;
// This points to either the original input or a copy we create below.
npy_intp
nd_out
;
// Either way, this is what we should be working on/with.
PyArrayObject
*
basename
;
PyArrayObject
*
_input
;
npy_intp
*
dimensions
;
npy_intp
*
strides
;
if
(
!
PyArray_IS_C_CONTIGUOUS
(
params
->
input_broadcastable
))
{
PyErr_SetString
(
PyExc_RuntimeError
,
"DimShuffle: param input_broadcastable must be C-contiguous."
);
return
1
;
}
if
(
!
PyArray_IS_C_CONTIGUOUS
(
params
->
_new_order
))
{
PyErr_SetString
(
PyExc_RuntimeError
,
"DimShuffle: param _new_order must be C-contiguous."
);
return
1
;
}
input_broadcastable
=
(
npy_bool
*
)
PyArray_DATA
(
params
->
input_broadcastable
);
new_order
=
(
npy_int64
*
)
PyArray_DATA
(
params
->
_new_order
);
nd_in
=
PyArray_SIZE
(
params
->
input_broadcastable
);
nd_out
=
PyArray_SIZE
(
params
->
_new_order
);
/* check_input_nd */
if
(
PyArray_NDIM
(
input
)
!=
nd_in
)
{
PyErr_SetString
(
PyExc_NotImplementedError
,
"input nd"
);
return
1
;
}
/* clear_output */
if
(
*
res
)
if
(
*
res
)
Py_XDECREF
(
*
res
);
Py_XDECREF
(
*
res
);
/* get_base */
if
(
params
->
inplace
)
{
if
(
params
->
inplace
)
{
basename
=
input
;
_input
=
input
;
Py_INCREF
((
PyObject
*
)
basename
);
Py_INCREF
((
PyObject
*
)
_input
);
}
else
{
}
else
{
basename
=
_input
=
(
PyArrayObject
*
)
PyArray_FromAny
(
(
PyArrayObject
*
)
PyArray_FromAny
((
PyObject
*
)
input
,
(
PyObject
*
)
input
,
NULL
,
0
,
0
,
NPY_ARRAY_ALIGNED
|
NPY_ARRAY_ENSURECOPY
,
NULL
,
0
,
0
,
NPY_ARRAY_ALIGNED
|
NPY_ARRAY_ENSURECOPY
,
NULL
);
NULL
);
}
}
/* shape_statements and strides_statements */
PyArray_Dims
permute
;
dimensions
=
(
npy_intp
*
)
malloc
(
nd_out
*
sizeof
(
npy_intp
));
strides
=
(
npy_intp
*
)
malloc
(
nd_out
*
sizeof
(
npy_intp
));
if
(
dimensions
==
NULL
||
strides
==
NULL
)
{
PyErr_NoMemory
();
free
(
dimensions
);
free
(
strides
);
return
1
;
};
for
(
npy_intp
i
=
0
;
i
<
nd_out
;
++
i
)
{
if
(
!
PyArray_IntpConverter
((
PyObject
*
)
params
->
transposition
,
&
permute
))
{
if
(
new_order
[
i
]
!=
-
1
)
{
return
1
;
dimensions
[
i
]
=
PyArray_DIMS
(
basename
)[
new_order
[
i
]];
strides
[
i
]
=
PyArray_DIMS
(
basename
)[
new_order
[
i
]]
==
1
?
0
:
PyArray_STRIDES
(
basename
)[
new_order
[
i
]];
}
else
{
dimensions
[
i
]
=
1
;
strides
[
i
]
=
0
;
}
}
/*
res = res.transpose(self.transposition)
*/
PyArrayObject
*
transposed_input
=
(
PyArrayObject
*
)
PyArray_Transpose
(
_input
,
&
permute
);
PyDimMem_FREE
(
permute
.
ptr
);
npy_intp
*
res_shape
=
PyArray_DIMS
(
transposed_input
);
npy_intp
N_shuffle
=
PyArray_SIZE
(
params
->
shuffle
);
npy_intp
N_augment
=
PyArray_SIZE
(
params
->
augment
);
npy_intp
N
=
N_augment
+
N_shuffle
;
npy_intp
*
_reshape_shape
=
(
npy_intp
*
)
malloc
(
N
*
sizeof
(
npy_intp
));
if
(
_reshape_shape
==
NULL
)
{
PyErr_NoMemory
();
free
(
_reshape_shape
);
return
1
;
}
}
/* set the strides of the broadcasted dimensions.
/*
* This algorithm is from numpy: PyArray_Newshape() in
shape = list(res.shape[: len(self.shuffle)])
* cvs/numpy/numpy/core/src/multiarraymodule.c */
for augm in self.augment:
if
(
nd_out
>
0
)
{
shape.insert(augm, 1)
if
(
strides
[
nd_out
-
1
]
==
0
)
*/
strides
[
nd_out
-
1
]
=
PyArray_DESCR
(
basename
)
->
elsize
;
npy_intp
aug_idx
=
0
;
for
(
npy_intp
i
=
nd_out
-
2
;
i
>
-
1
;
--
i
)
{
int
res_idx
=
0
;
if
(
strides
[
i
]
==
0
)
for
(
npy_intp
i
=
0
;
i
<
N
;
i
++
)
{
strides
[
i
]
=
strides
[
i
+
1
]
*
dimensions
[
i
+
1
];
if
(
aug_idx
<
N_augment
&&
i
==
*
((
npy_intp
*
)
PyArray_GetPtr
(
params
->
augment
,
&
aug_idx
)))
{
_reshape_shape
[
i
]
=
1
;
aug_idx
++
;
}
else
{
_reshape_shape
[
i
]
=
res_shape
[
res_idx
];
res_idx
++
;
}
}
}
}
/* close_bracket */
PyArray_Dims
reshape_shape
=
{.
ptr
=
_reshape_shape
,
.
len
=
(
int
)
N
};
// create a new array.
*
res
=
(
PyArrayObject
*
)
PyArray_New
(
&
PyArray_Type
,
nd_out
,
dimensions
,
PyArray_TYPE
(
basename
),
strides
,
PyArray_DATA
(
basename
),
PyArray_ITEMSIZE
(
basename
),
// borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0.
(
NPY_ARRAY_WRITEABLE
*
PyArray_ISWRITEABLE
(
basename
)),
NULL
);
if
(
*
res
==
NULL
)
{
/* res = res.reshape(shape) */
free
(
dimensions
);
*
res
=
(
PyArrayObject
*
)
PyArray_Newshape
(
transposed_input
,
&
reshape_shape
,
free
(
strides
);
NPY_CORDER
);
return
1
;
}
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
/* Py_XDECREF(transposed_input); */
PyArray_UpdateFlags
(
*
res
,
NPY_ARRAY_UPDATE_ALL
);
// we are making a view in both inplace and non-inplace cases
PyDimMem_FREE
(
reshape_shape
.
ptr
);
PyArray_SetBaseObject
(
*
res
,
(
PyObject
*
)
basename
);
free
(
strides
);
if
(
!*
res
)
{
free
(
dimensions
);
return
1
;
}
return
0
;
return
0
;
}
}
aesara/tensor/elemwise.py
浏览文件 @
e593b0ac
...
@@ -119,47 +119,27 @@ class DimShuffle(ExternalCOp):
...
@@ -119,47 +119,27 @@ class DimShuffle(ExternalCOp):
@property
@property
def
params_type
(
self
):
def
params_type
(
self
):
# We can't directly create `params_type` as class attribute
# because of importation issues related to TensorType.
return
ParamsType
(
return
ParamsType
(
input_broadcastable
=
TensorType
(
dtype
=
"bool"
,
broadcastable
=
(
False
,))
,
shuffle
=
lvector
,
_new_order
=
lvector
,
augment
=
lvector
,
transposition
=
TensorType
(
dtype
=
"uint32"
,
broadcastable
=
(
False
,))
,
transposition
=
lvector
,
inplace
=
scalar_bool
,
inplace
=
scalar_bool
,
)
)
@property
def
__init__
(
self
,
input_broadcastable
,
new_order
):
def
_new_order
(
self
):
# Param for C code.
# self.new_order may contain 'x', which is not a valid integer value.
# We replace it with -1.
return
[(
-
1
if
x
==
"x"
else
x
)
for
x
in
self
.
new_order
]
@property
def
transposition
(
self
):
return
self
.
shuffle
+
self
.
drop
def
__init__
(
self
,
input_broadcastable
,
new_order
,
inplace
=
True
):
super
()
.
__init__
([
self
.
c_func_file
],
self
.
c_func_name
)
super
()
.
__init__
([
self
.
c_func_file
],
self
.
c_func_name
)
self
.
input_broadcastable
=
tuple
(
input_broadcastable
)
self
.
input_broadcastable
=
tuple
(
input_broadcastable
)
self
.
new_order
=
tuple
(
new_order
)
self
.
new_order
=
tuple
(
new_order
)
if
inplace
is
True
:
self
.
inplace
=
inplace
self
.
inplace
=
True
else
:
raise
ValueError
(
"DimShuffle is inplace by default and hence the inplace for DimShuffle must be true"
)
for
i
,
j
in
enumerate
(
new_order
):
for
i
,
j
in
enumerate
(
new_order
):
if
j
!=
"x"
:
if
j
!=
"x"
:
# There is a bug in numpy that results in
# isinstance(x, integer_types) returning False for
# numpy integers. See
# <http://projects.scipy.org/numpy/ticket/2235>.
if
not
isinstance
(
j
,
(
int
,
np
.
integer
)):
if
not
isinstance
(
j
,
(
int
,
np
.
integer
)):
raise
TypeError
(
raise
TypeError
(
"DimShuffle indices must be
python ints.
"
"DimShuffle indices must be
Python ints; got
"
f
"
Got: '{j}' of type '{type(j)}'
."
f
"
{j} of type {type(j)}
."
)
)
if
j
>=
len
(
input_broadcastable
):
if
j
>=
len
(
input_broadcastable
):
raise
ValueError
(
raise
ValueError
(
...
@@ -169,31 +149,30 @@ class DimShuffle(ExternalCOp):
...
@@ -169,31 +149,30 @@ class DimShuffle(ExternalCOp):
if
j
in
new_order
[(
i
+
1
)
:]:
if
j
in
new_order
[(
i
+
1
)
:]:
raise
ValueError
(
raise
ValueError
(
"The same input dimension may not appear "
"The same input dimension may not appear "
"twice in the list of output dimensions"
,
f
"twice in the list of output dimensions: {new_order}"
new_order
,
)
)
#
list of dimensions of the input
to drop
#
List of input dimensions
to drop
self
.
drop
=
[]
drop
=
[]
for
i
,
b
in
enumerate
(
input_broadcastable
):
for
i
,
b
in
enumerate
(
input_broadcastable
):
if
i
not
in
new_order
:
if
i
not
in
new_order
:
#
w
e want to drop this dimension because it's not a value in
#
W
e want to drop this dimension because it's not a value in
#
new_order
#
`new_order`
if
b
==
1
:
# 1 aka True
if
b
==
1
:
self
.
drop
.
append
(
i
)
drop
.
append
(
i
)
else
:
else
:
#
w
e cannot drop non-broadcastable dimensions
#
W
e cannot drop non-broadcastable dimensions
raise
ValueError
(
raise
ValueError
(
"
You cannot drop a non-broadcastable dimension:"
,
"
Cannot drop a non-broadcastable dimension: "
f
"
{input_broadcastable}, {new_order}"
,
f
"
{input_broadcastable}, {new_order}"
)
)
#
t
his is the list of the original dimensions that we keep
#
T
his is the list of the original dimensions that we keep
self
.
shuffle
=
[
x
for
x
in
new_order
if
x
!=
"x"
]
self
.
shuffle
=
[
x
for
x
in
new_order
if
x
!=
"x"
]
self
.
transposition
=
self
.
shuffle
+
drop
#
l
ist of dimensions of the output that are broadcastable and were not
#
L
ist of dimensions of the output that are broadcastable and were not
# in the original input
# in the original input
self
.
augment
=
[
i
for
i
,
x
in
enumerate
(
new_order
)
if
x
==
"x"
]
self
.
augment
=
sorted
([
i
for
i
,
x
in
enumerate
(
new_order
)
if
x
==
"x"
])
if
self
.
inplace
:
if
self
.
inplace
:
self
.
view_map
=
{
0
:
[
0
]}
self
.
view_map
=
{
0
:
[
0
]}
...
@@ -241,27 +220,23 @@ class DimShuffle(ExternalCOp):
...
@@ -241,27 +220,23 @@ class DimShuffle(ExternalCOp):
return
"DimShuffle{
%
s}"
%
","
.
join
(
str
(
x
)
for
x
in
self
.
new_order
)
return
"DimShuffle{
%
s}"
%
","
.
join
(
str
(
x
)
for
x
in
self
.
new_order
)
def
perform
(
self
,
node
,
inp
,
out
,
params
):
def
perform
(
self
,
node
,
inp
,
out
,
params
):
(
input
,)
=
inp
(
res
,)
=
inp
(
storage
,)
=
out
(
storage
,)
=
out
# drop
res
=
input
if
type
(
res
)
!=
np
.
ndarray
and
type
(
res
)
!=
np
.
memmap
:
if
type
(
res
)
!=
np
.
ndarray
and
type
(
res
)
!=
np
.
memmap
:
raise
TypeError
(
res
)
raise
TypeError
(
res
)
# transpose
res
=
res
.
transpose
(
self
.
transposition
)
res
=
res
.
transpose
(
self
.
shuffle
+
self
.
drop
)
# augment
shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
for
augm
in
self
.
augment
:
for
augm
in
self
.
augment
:
shape
.
insert
(
augm
,
1
)
shape
.
insert
(
augm
,
1
)
res
=
res
.
reshape
(
shape
)
res
=
res
.
reshape
(
shape
)
# copy (if not inplace)
if
not
self
.
inplace
:
if
not
self
.
inplace
:
res
=
np
.
copy
(
res
)
res
=
np
.
copy
(
res
)
storage
[
0
]
=
np
.
asarray
(
res
)
# asarray puts scalars back into array
storage
[
0
]
=
np
.
asarray
(
res
)
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
(
ishp
,)
=
shapes
(
ishp
,)
=
shapes
...
...
aesara/tensor/inplace.py
浏览文件 @
e593b0ac
...
@@ -399,4 +399,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
...
@@ -399,4 +399,4 @@ pprint.assign(pow_inplace, printing.OperatorPrinter("**=", 1, "right"))
def
transpose_inplace
(
x
,
**
kwargs
):
def
transpose_inplace
(
x
,
**
kwargs
):
"Perform a transpose on a tensor without copying the underlying storage"
"Perform a transpose on a tensor without copying the underlying storage"
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
dims
=
list
(
range
(
x
.
ndim
-
1
,
-
1
,
-
1
))
return
DimShuffle
(
x
.
broadcastable
,
dims
,
inplace
=
True
)(
x
)
return
DimShuffle
(
x
.
broadcastable
,
dims
)(
x
)
tests/link/test_jax.py
浏览文件 @
e593b0ac
...
@@ -856,7 +856,7 @@ def test_jax_Dimshuffle():
...
@@ -856,7 +856,7 @@ def test_jax_Dimshuffle():
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_aet
=
tensor
(
dtype
=
config
.
floatX
,
broadcastable
=
[
False
,
True
])
a_aet
=
tensor
(
dtype
=
config
.
floatX
,
broadcastable
=
[
False
,
True
])
x
=
aet_elemwise
.
DimShuffle
([
False
,
True
],
(
0
,)
,
inplace
=
True
)(
a_aet
)
x
=
aet_elemwise
.
DimShuffle
([
False
,
True
],
(
0
,))(
a_aet
)
x_fg
=
FunctionGraph
([
a_aet
],
[
x
])
x_fg
=
FunctionGraph
([
a_aet
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
...
...
tests/link/test_numba.py
浏览文件 @
e593b0ac
...
@@ -653,7 +653,7 @@ def test_AllocDiag(v, offset):
...
@@ -653,7 +653,7 @@ def test_AllocDiag(v, offset):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, new_order
, inplace
"
,
"v, new_order"
,
[
[
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
(
(
...
@@ -662,7 +662,6 @@ def test_AllocDiag(v, offset):
...
@@ -662,7 +662,6 @@ def test_AllocDiag(v, offset):
np
.
array
(
1
,
dtype
=
np
.
int64
),
np
.
array
(
1
,
dtype
=
np
.
int64
),
),
),
(
"x"
,
"x"
),
(
"x"
,
"x"
),
True
,
),
),
# I.e. `a_aet.T`
# I.e. `a_aet.T`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
...
@@ -671,7 +670,6 @@ def test_AllocDiag(v, offset):
...
@@ -671,7 +670,6 @@ def test_AllocDiag(v, offset):
aet
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
aet
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
),
),
(
1
,
0
),
(
1
,
0
),
True
,
),
),
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
(
(
...
@@ -679,7 +677,6 @@ def test_AllocDiag(v, offset):
...
@@ -679,7 +677,6 @@ def test_AllocDiag(v, offset):
aet
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
aet
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
),
),
(
1
,
0
,
"x"
),
(
1
,
0
,
"x"
),
True
,
),
),
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
(
(
...
@@ -688,7 +685,6 @@ def test_AllocDiag(v, offset):
...
@@ -688,7 +685,6 @@ def test_AllocDiag(v, offset):
np
.
array
([[[
1.0
,
2.0
]],
[[
3.0
,
4.0
]]],
dtype
=
config
.
floatX
),
np
.
array
([[[
1.0
,
2.0
]],
[[
3.0
,
4.0
]]],
dtype
=
config
.
floatX
),
),
),
(
"x"
,
2
,
"x"
,
0
,
"x"
),
(
"x"
,
2
,
"x"
,
0
,
"x"
),
True
,
),
),
# I.e. `a_aet.dimshuffle((0,))`
# I.e. `a_aet.dimshuffle((0,))`
# `{'drop': [1], 'shuffle': [0], 'augment': []}`
# `{'drop': [1], 'shuffle': [0], 'augment': []}`
...
@@ -698,7 +694,6 @@ def test_AllocDiag(v, offset):
...
@@ -698,7 +694,6 @@ def test_AllocDiag(v, offset):
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
),
),
(
0
,),
(
0
,),
True
,
),
),
(
(
set_test_value
(
set_test_value
(
...
@@ -706,7 +701,6 @@ def test_AllocDiag(v, offset):
...
@@ -706,7 +701,6 @@ def test_AllocDiag(v, offset):
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
),
),
(
0
,),
(
0
,),
True
,
),
),
(
(
set_test_value
(
set_test_value
(
...
@@ -714,12 +708,11 @@ def test_AllocDiag(v, offset):
...
@@ -714,12 +708,11 @@ def test_AllocDiag(v, offset):
np
.
array
([[[
1.0
]]],
dtype
=
config
.
floatX
),
np
.
array
([[[
1.0
]]],
dtype
=
config
.
floatX
),
),
),
(),
(),
True
,
),
),
],
],
)
)
def
test_Dimshuffle
(
v
,
new_order
,
inplace
):
def
test_Dimshuffle
(
v
,
new_order
):
g
=
aet_elemwise
.
DimShuffle
(
v
.
broadcastable
,
new_order
,
inplace
=
inplace
)(
v
)
g
=
aet_elemwise
.
DimShuffle
(
v
.
broadcastable
,
new_order
)(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
g_fg
,
...
...
tests/tensor/test_elemwise.py
浏览文件 @
e593b0ac
...
@@ -52,12 +52,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -52,12 +52,12 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
ib
=
[(
entry
==
1
)
for
entry
in
xsh
]
ib
=
[(
entry
==
1
)
for
entry
in
xsh
]
x
=
self
.
type
(
self
.
dtype
,
ib
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
ib
)(
"x"
)
e
=
self
.
op
(
ib
,
shuffle
)(
x
)
e
=
self
.
op
(
ib
,
shuffle
)(
x
)
f
=
copy
(
linker
)
.
accept
(
FunctionGraph
([
x
],
[
e
]))
.
make_function
(
)
f
=
aesara
.
function
([
x
],
e
,
mode
=
Mode
(
linker
=
linker
)
)
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
assert
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
))
.
shape
==
zsh
# test that DimShuffle.infer_shape work correctly
# test that DimShuffle.infer_shape work correctly
x
=
self
.
type
(
self
.
dtype
,
ib
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
ib
)(
"x"
)
e
=
self
.
op
(
ib
,
shuffle
)(
x
)
e
=
self
.
op
(
ib
,
shuffle
)(
x
)
f
=
copy
(
linker
)
.
accept
(
FunctionGraph
([
x
],
[
e
.
shape
]))
.
make_function
(
)
f
=
aesara
.
function
([
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
)
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
assert
all
(
f
(
np
.
ones
(
xsh
,
dtype
=
self
.
dtype
)))
==
all
(
zsh
)
# Test when we drop a axis that is not broadcastable
# Test when we drop a axis that is not broadcastable
...
@@ -70,7 +70,7 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -70,7 +70,7 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
ib
=
[
True
,
True
,
False
]
ib
=
[
True
,
True
,
False
]
x
=
self
.
type
(
self
.
dtype
,
ib
)(
"x"
)
x
=
self
.
type
(
self
.
dtype
,
ib
)(
"x"
)
e
=
self
.
op
(
ib
,
(
1
,
2
))(
x
)
e
=
self
.
op
(
ib
,
(
1
,
2
))(
x
)
f
=
copy
(
linker
)
.
accept
(
FunctionGraph
([
x
],
[
e
.
shape
]))
.
make_function
(
)
f
=
aesara
.
function
([
x
],
e
.
shape
,
mode
=
Mode
(
linker
=
linker
)
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
f
(
np
.
ones
((
2
,
1
,
4
)))
f
(
np
.
ones
((
2
,
1
,
4
)))
...
@@ -119,6 +119,25 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
...
@@ -119,6 +119,25 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
y
.
eval
({
x
:
0
})
y
.
eval
({
x
:
0
})
def
test_c_views
(
self
):
x_at
=
vector
()
thunk
,
inputs
,
outputs
=
(
CLinker
()
.
accept
(
FunctionGraph
([
x_at
],
[
x_at
[
None
]]))
.
make_thunk
()
)
# This is a little hackish, but we're hoping that--by running this more than
# a few times--we're more likely to run into random memory that isn't the same
# as the broadcasted value; that way, we'll be able to tell that we're getting
# junk data from a poorly constructed array view.
x_val
=
np
.
broadcast_to
(
2039
,
(
5000
,))
for
i
in
range
(
1000
):
inputs
[
0
]
.
storage
[
0
]
=
x_val
thunk
()
# Make sure it's a view of the original data
assert
np
.
shares_memory
(
x_val
,
outputs
[
0
]
.
storage
[
0
])
# Confirm the broadcasted value in the output
assert
np
.
array_equiv
(
outputs
[
0
]
.
storage
[
0
],
2039
)
class
TestBroadcast
:
class
TestBroadcast
:
# this is to allow other types to reuse this class to test their ops
# this is to allow other types to reuse this class to test their ops
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论