Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
28fa7b76
提交
28fa7b76
authored
2月 21, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
2月 25, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove unused inplace option in DimShuffle
上级
ff5df65f
隐藏空白字符变更
内嵌
并排
正在显示
5 个修改的文件
包含
17 行增加
和
55 行删除
+17
-55
elemwise.py
pytensor/link/jax/dispatch/elemwise.py
+1
-6
elemwise.py
pytensor/link/numba/dispatch/elemwise.py
+1
-7
elemwise.py
pytensor/link/pytorch/dispatch/elemwise.py
+1
-6
dimshuffle.c
pytensor/tensor/c_code/dimshuffle.c
+11
-23
elemwise.py
pytensor/tensor/elemwise.py
+3
-13
没有找到文件。
pytensor/link/jax/dispatch/elemwise.py
浏览文件 @
28fa7b76
...
@@ -79,12 +79,7 @@ def jax_funcify_DimShuffle(op, **kwargs):
...
@@ -79,12 +79,7 @@ def jax_funcify_DimShuffle(op, **kwargs):
for
augm
in
op
.
augment
:
for
augm
in
op
.
augment
:
shape
.
insert
(
augm
,
1
)
shape
.
insert
(
augm
,
1
)
res
=
jnp
.
reshape
(
res
,
shape
)
return
jnp
.
reshape
(
res
,
shape
)
if
not
op
.
inplace
:
res
=
jnp
.
copy
(
res
)
return
res
return
dimshuffle
return
dimshuffle
...
...
pytensor/link/numba/dispatch/elemwise.py
浏览文件 @
28fa7b76
...
@@ -414,7 +414,6 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
...
@@ -414,7 +414,6 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle
=
tuple
(
op
.
shuffle
)
shuffle
=
tuple
(
op
.
shuffle
)
transposition
=
tuple
(
op
.
transposition
)
transposition
=
tuple
(
op
.
transposition
)
augment
=
tuple
(
op
.
augment
)
augment
=
tuple
(
op
.
augment
)
inplace
=
op
.
inplace
ndim_new_shape
=
len
(
shuffle
)
+
len
(
augment
)
ndim_new_shape
=
len
(
shuffle
)
+
len
(
augment
)
...
@@ -474,12 +473,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
...
@@ -474,12 +473,7 @@ def numba_funcify_DimShuffle(op, node, **kwargs):
new_shape
=
find_shape
(
shuffle_shape
)
new_shape
=
find_shape
(
shuffle_shape
)
# FIXME: Numba's `array.reshape` only accepts C arrays.
# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape
=
np
.
reshape
(
np
.
ascontiguousarray
(
x
),
new_shape
)
return
np
.
reshape
(
np
.
ascontiguousarray
(
x
),
new_shape
)
if
not
inplace
:
return
res_reshape
.
copy
()
else
:
return
res_reshape
else
:
else
:
...
...
pytensor/link/pytorch/dispatch/elemwise.py
浏览文件 @
28fa7b76
...
@@ -61,12 +61,7 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
...
@@ -61,12 +61,7 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
for
augm
in
op
.
augment
:
for
augm
in
op
.
augment
:
shape
.
insert
(
augm
,
1
)
shape
.
insert
(
augm
,
1
)
res
=
torch
.
reshape
(
res
,
shape
)
return
torch
.
reshape
(
res
,
shape
)
if
not
op
.
inplace
:
res
=
res
.
clone
()
return
res
return
dimshuffle
return
dimshuffle
...
...
pytensor/tensor/c_code/dimshuffle.c
浏览文件 @
28fa7b76
...
@@ -7,10 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
...
@@ -7,10 +7,6 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
npy_intp
*
dimensions
;
npy_intp
*
dimensions
;
npy_intp
*
strides
;
npy_intp
*
strides
;
// This points to either the original input or a copy we create below.
// Either way, this is what we should be working on/with.
PyArrayObject
*
_input
;
if
(
!
PyArray_IS_C_CONTIGUOUS
(
params
->
_new_order
))
{
if
(
!
PyArray_IS_C_CONTIGUOUS
(
params
->
_new_order
))
{
PyErr_SetString
(
PyExc_RuntimeError
,
"DimShuffle: param _new_order must be C-contiguous."
);
PyErr_SetString
(
PyExc_RuntimeError
,
"DimShuffle: param _new_order must be C-contiguous."
);
return
1
;
return
1
;
...
@@ -20,7 +16,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
...
@@ -20,7 +16,7 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
nd_out
=
PyArray_SIZE
(
params
->
_new_order
);
nd_out
=
PyArray_SIZE
(
params
->
_new_order
);
if
(
PyArray_NDIM
(
input
)
!=
nd_in
)
{
if
(
PyArray_NDIM
(
input
)
!=
nd_in
)
{
PyErr_SetString
(
PyExc_
NotImplemented
Error
,
"DimShuffle: Input has less dimensions than expected."
);
PyErr_SetString
(
PyExc_
Value
Error
,
"DimShuffle: Input has less dimensions than expected."
);
return
1
;
return
1
;
}
}
...
@@ -34,12 +30,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
...
@@ -34,12 +30,12 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
return
1
;
return
1
;
};
};
npy_intp
original_size
=
PyArray_SIZE
(
_
input
);
npy_intp
original_size
=
PyArray_SIZE
(
input
);
npy_intp
new_size
=
1
;
npy_intp
new_size
=
1
;
for
(
npy_intp
i
=
0
;
i
<
nd_out
;
++
i
)
{
for
(
npy_intp
i
=
0
;
i
<
nd_out
;
++
i
)
{
if
(
new_order
[
i
]
!=
-
1
)
{
if
(
new_order
[
i
]
!=
-
1
)
{
dimensions
[
i
]
=
PyArray_DIMS
(
_
input
)[
new_order
[
i
]];
dimensions
[
i
]
=
PyArray_DIMS
(
input
)[
new_order
[
i
]];
strides
[
i
]
=
PyArray_DIMS
(
_input
)[
new_order
[
i
]]
==
1
?
0
:
PyArray_STRIDES
(
_
input
)[
new_order
[
i
]];
strides
[
i
]
=
PyArray_DIMS
(
input
)[
new_order
[
i
]]
==
1
?
0
:
PyArray_STRIDES
(
input
)[
new_order
[
i
]];
}
else
{
}
else
{
dimensions
[
i
]
=
1
;
dimensions
[
i
]
=
1
;
strides
[
i
]
=
0
;
strides
[
i
]
=
0
;
...
@@ -57,22 +53,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
...
@@ -57,22 +53,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
if
(
*
res
)
if
(
*
res
)
Py_XDECREF
(
*
res
);
Py_XDECREF
(
*
res
);
if
(
params
->
inplace
)
{
_input
=
input
;
Py_INCREF
((
PyObject
*
)
_input
);
}
else
{
_input
=
(
PyArrayObject
*
)
PyArray_FromAny
(
(
PyObject
*
)
input
,
NULL
,
0
,
0
,
NPY_ARRAY_ALIGNED
|
NPY_ARRAY_ENSURECOPY
,
NULL
);
}
// Create the new array.
// Create the new array.
*
res
=
(
PyArrayObject
*
)
PyArray_New
(
&
PyArray_Type
,
nd_out
,
dimensions
,
*
res
=
(
PyArrayObject
*
)
PyArray_New
(
&
PyArray_Type
,
nd_out
,
dimensions
,
PyArray_TYPE
(
_
input
),
strides
,
PyArray_TYPE
(
input
),
strides
,
PyArray_DATA
(
_input
),
PyArray_ITEMSIZE
(
_
input
),
PyArray_DATA
(
input
),
PyArray_ITEMSIZE
(
input
),
// borrow only the writable flag from the base
// borrow only the writable flag from the base
// the NPY_OWNDATA flag will default to 0.
// the NPY_OWNDATA flag will default to 0.
(
NPY_ARRAY_WRITEABLE
*
PyArray_ISWRITEABLE
(
_
input
)),
(
NPY_ARRAY_WRITEABLE
*
PyArray_ISWRITEABLE
(
input
)),
NULL
);
NULL
);
if
(
*
res
==
NULL
)
{
if
(
*
res
==
NULL
)
{
...
@@ -81,12 +68,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
...
@@ -81,12 +68,13 @@ int APPLY_SPECIFIC(cpu_dimshuffle)(PyArrayObject *input, PyArrayObject **res, PA
return
1
;
return
1
;
}
}
// Declare it a view of the original input
Py_INCREF
((
PyObject
*
)
input
);
PyArray_SetBaseObject
(
*
res
,
(
PyObject
*
)
input
);
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
// recalculate flags: CONTIGUOUS, FORTRAN, ALIGNED
PyArray_UpdateFlags
(
*
res
,
NPY_ARRAY_UPDATE_ALL
);
PyArray_UpdateFlags
(
*
res
,
NPY_ARRAY_UPDATE_ALL
);
// we are making a view in both inplace and non-inplace cases
PyArray_SetBaseObject
(
*
res
,
(
PyObject
*
)
_input
);
free
(
strides
);
free
(
strides
);
free
(
dimensions
);
free
(
dimensions
);
return
0
;
return
0
;
...
...
pytensor/tensor/elemwise.py
浏览文件 @
28fa7b76
...
@@ -19,7 +19,6 @@ from pytensor.misc.frozendict import frozendict
...
@@ -19,7 +19,6 @@ from pytensor.misc.frozendict import frozendict
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.printing
import
Printer
,
pprint
from
pytensor.printing
import
Printer
,
pprint
from
pytensor.scalar
import
get_scalar_type
from
pytensor.scalar
import
get_scalar_type
from
pytensor.scalar.basic
import
bool
as
scalar_bool
from
pytensor.scalar.basic
import
identity
as
scalar_identity
from
pytensor.scalar.basic
import
identity
as
scalar_identity
from
pytensor.scalar.basic
import
int64
,
transfer_type
,
upcast
from
pytensor.scalar.basic
import
int64
,
transfer_type
,
upcast
from
pytensor.tensor
import
elemwise_cgen
as
cgen
from
pytensor.tensor
import
elemwise_cgen
as
cgen
...
@@ -114,15 +113,15 @@ class DimShuffle(ExternalCOp):
...
@@ -114,15 +113,15 @@ class DimShuffle(ExternalCOp):
_f16_ok
=
True
_f16_ok
=
True
check_input
=
False
check_input
=
False
__props__
=
(
"input_ndim"
,
"new_order"
,
"inplace"
)
__props__
=
(
"input_ndim"
,
"new_order"
)
c_func_file
=
"c_code/dimshuffle.c"
c_func_file
=
"c_code/dimshuffle.c"
c_func_name
=
"APPLY_SPECIFIC(cpu_dimshuffle)"
c_func_name
=
"APPLY_SPECIFIC(cpu_dimshuffle)"
view_map
=
{
0
:
[
0
]}
@property
@property
def
params_type
(
self
):
def
params_type
(
self
):
return
ParamsType
(
return
ParamsType
(
_new_order
=
lvector
,
_new_order
=
lvector
,
inplace
=
scalar_bool
,
input_ndim
=
int64
,
input_ndim
=
int64
,
)
)
...
@@ -135,7 +134,6 @@ class DimShuffle(ExternalCOp):
...
@@ -135,7 +134,6 @@ class DimShuffle(ExternalCOp):
self
.
input_ndim
=
input_ndim
self
.
input_ndim
=
input_ndim
self
.
new_order
=
tuple
(
new_order
)
self
.
new_order
=
tuple
(
new_order
)
self
.
_new_order
=
[(
-
1
if
x
==
"x"
else
x
)
for
x
in
self
.
new_order
]
self
.
_new_order
=
[(
-
1
if
x
==
"x"
else
x
)
for
x
in
self
.
new_order
]
self
.
inplace
=
True
for
i
,
j
in
enumerate
(
new_order
):
for
i
,
j
in
enumerate
(
new_order
):
if
j
!=
"x"
:
if
j
!=
"x"
:
...
@@ -178,9 +176,6 @@ class DimShuffle(ExternalCOp):
...
@@ -178,9 +176,6 @@ class DimShuffle(ExternalCOp):
:
input_ndim
:
input_ndim
]
==
list
(
range
(
input_ndim
))
]
==
list
(
range
(
input_ndim
))
if
self
.
inplace
:
self
.
view_map
=
{
0
:
[
0
]}
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
self
.
__dict__
.
update
(
state
)
self
.
__dict__
.
update
(
state
)
if
not
hasattr
(
self
,
"func_files"
):
if
not
hasattr
(
self
,
"func_files"
):
...
@@ -248,12 +243,7 @@ class DimShuffle(ExternalCOp):
...
@@ -248,12 +243,7 @@ class DimShuffle(ExternalCOp):
new_shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
new_shape
=
list
(
res
.
shape
[:
len
(
self
.
shuffle
)])
for
augm
in
self
.
augment
:
for
augm
in
self
.
augment
:
new_shape
.
insert
(
augm
,
1
)
new_shape
.
insert
(
augm
,
1
)
res
=
res
.
reshape
(
new_shape
)
out
[
0
][
0
]
=
res
.
reshape
(
new_shape
)
if
not
self
.
inplace
:
res
=
np
.
copy
(
res
)
out
[
0
][
0
]
=
res
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
def
infer_shape
(
self
,
fgraph
,
node
,
shapes
):
(
ishp
,)
=
shapes
(
ishp
,)
=
shapes
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论