Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
edcbac8e
提交
edcbac8e
authored
12月 31, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
1月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Clean up Join.make_node
上级
59f09d09
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
26 行增加
和
53 行删除
+26
-53
basic.py
aesara/tensor/basic.py
+26
-53
没有找到文件。
aesara/tensor/basic.py
浏览文件 @
edcbac8e
...
@@ -2240,47 +2240,32 @@ class Join(COp):
...
@@ -2240,47 +2240,32 @@ class Join(COp):
if
not
hasattr
(
self
,
"view"
):
if
not
hasattr
(
self
,
"view"
):
self
.
view
=
-
1
self
.
view
=
-
1
def
make_node
(
self
,
*
axis_and_
tensors
):
def
make_node
(
self
,
axis
,
*
tensors
):
"""
"""
Parameters
Parameters
----------
----------
axis: an Int or integer-valued Variable
axis
The axis upon which to join `tensors`.
tensors
tensors
A variable number (but not zero) of tensors to
A variable number of tensors to join along the specified axis.
concatenate along the specified axis. These tensors must have
These tensors must have the same shape along all dimensions other
the same shape along all dimensions other than this axis.
than `axis`.
Returns
-------
A symbolic Variable
It has the same ndim as the input tensors, and the most inclusive
dtype.
"""
"""
axis
,
tens
=
axis_and_tensors
[
0
],
axis_and_tensors
[
1
:]
if
not
tensors
:
if
not
tens
:
raise
ValueError
(
"Cannot join an empty list of tensors"
)
raise
ValueError
(
"Cannot join an empty list of tensors"
)
as_tensor_variable_args
=
[
as_tensor_variable
(
x
)
for
x
in
tens
]
dtypes
=
[
x
.
type
.
dtype
for
x
in
as_tensor_variable_args
]
out_dtype
=
aes
.
upcast
(
*
dtypes
)
def
output_maker
(
bcastable
):
return
tensor
(
dtype
=
out_dtype
,
broadcastable
=
bcastable
)
return
self
.
_make_node_internal
(
tensors
=
[
as_tensor_variable
(
x
)
for
x
in
tensors
]
axis
,
tens
,
as_tensor_variable_args
,
output_maker
out_dtype
=
aes
.
upcast
(
*
[
x
.
type
.
dtype
for
x
in
tensors
])
)
def
_make_node_internal
(
self
,
axis
,
tens
,
as_tensor_variable_args
,
output_maker
):
if
not
builtins
.
all
(
targs
.
type
.
ndim
for
targs
in
tensors
):
if
not
builtins
.
all
(
targs
.
type
.
ndim
for
targs
in
as_tensor_variable_args
):
raise
TypeError
(
raise
TypeError
(
"Join cannot handle arguments of dimension 0."
"Join cannot handle arguments of dimension 0."
"
For joining scalar values, see @stack
"
"
Use `stack` to join scalar values.
"
)
)
# Handle single-tensor joins immediately.
# Handle single-tensor joins immediately.
if
len
(
as_tensor_variable_arg
s
)
==
1
:
if
len
(
tensor
s
)
==
1
:
bcastable
=
list
(
as_tensor_variable_arg
s
[
0
]
.
type
.
broadcastable
)
bcastable
=
list
(
tensor
s
[
0
]
.
type
.
broadcastable
)
else
:
else
:
# When the axis is fixed, a dimension should be
# When the axis is fixed, a dimension should be
# broadcastable if at least one of the inputs is
# broadcastable if at least one of the inputs is
...
@@ -2288,17 +2273,15 @@ class Join(COp):
...
@@ -2288,17 +2273,15 @@ class Join(COp):
# except for the axis dimension.
# except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with
# Initialize bcastable all false, and then fill in some trues with
# the loops.
# the loops.
bcastable
=
[
False
]
*
len
(
as_tensor_variable_arg
s
[
0
]
.
type
.
broadcastable
)
bcastable
=
[
False
]
*
len
(
tensor
s
[
0
]
.
type
.
broadcastable
)
ndim
=
len
(
bcastable
)
ndim
=
len
(
bcastable
)
# Axis can also be a constant
if
not
isinstance
(
axis
,
int
):
if
not
isinstance
(
axis
,
int
):
try
:
try
:
# Note : `get_scalar_constant_value` returns a ndarray not
# an int
axis
=
int
(
get_scalar_constant_value
(
axis
))
axis
=
int
(
get_scalar_constant_value
(
axis
))
except
NotScalarConstantError
:
except
NotScalarConstantError
:
pass
pass
if
isinstance
(
axis
,
int
):
if
isinstance
(
axis
,
int
):
# Basically, broadcastable -> length 1, but the
# Basically, broadcastable -> length 1, but the
# converse does not hold. So we permit e.g. T/F/T
# converse does not hold. So we permit e.g. T/F/T
...
@@ -2310,12 +2293,12 @@ class Join(COp):
...
@@ -2310,12 +2293,12 @@ class Join(COp):
if
axis
<
-
ndim
:
if
axis
<
-
ndim
:
raise
IndexError
(
raise
IndexError
(
f
"
Join axis {int(axis)} out of bounds [0, {int(ndim)})
"
f
"
Axis value {axis} is out of range for the given input dimensions
"
)
)
if
axis
<
0
:
if
axis
<
0
:
axis
+=
ndim
axis
+=
ndim
for
x
in
as_tensor_variable_arg
s
:
for
x
in
tensor
s
:
for
current_axis
,
bflag
in
enumerate
(
x
.
type
.
broadcastable
):
for
current_axis
,
bflag
in
enumerate
(
x
.
type
.
broadcastable
):
# Constant negative axis can no longer be negative at
# Constant negative axis can no longer be negative at
# this point. It safe to compare this way.
# this point. It safe to compare this way.
...
@@ -2327,34 +2310,24 @@ class Join(COp):
...
@@ -2327,34 +2310,24 @@ class Join(COp):
bcastable
[
axis
]
=
False
bcastable
[
axis
]
=
False
except
IndexError
:
except
IndexError
:
raise
ValueError
(
raise
ValueError
(
'Join argument "axis" is out of range'
f
"Axis value {axis} is out of range for the given input dimensions"
" (given input dimensions)"
)
)
else
:
else
:
# When the axis may vary, no dimension can be guaranteed to be
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
# broadcastable.
bcastable
=
[
False
]
*
len
(
as_tensor_variable_arg
s
[
0
]
.
type
.
broadcastable
)
bcastable
=
[
False
]
*
len
(
tensor
s
[
0
]
.
type
.
broadcastable
)
if
not
builtins
.
all
(
if
not
builtins
.
all
([
x
.
ndim
==
len
(
bcastable
)
for
x
in
tensors
]):
[
x
.
ndim
==
len
(
bcastable
)
for
x
in
as_tensor_variable_args
[
1
:]]
):
raise
TypeError
(
raise
TypeError
(
"
Join() can only join tensors with the same "
"number of dimensions.
"
"
Only tensors with the same number of dimensions can be joined
"
)
)
inputs
=
[
as_tensor_variable
(
axis
)]
+
list
(
as_tensor_variable_args
)
inputs
=
[
as_tensor_variable
(
axis
)]
+
list
(
tensors
)
if
inputs
[
0
]
.
type
not
in
int_types
:
raise
TypeError
(
"Axis could not be cast to an integer type"
,
axis
,
inputs
[
0
]
.
type
,
int_types
,
)
outputs
=
[
output_maker
(
bcastable
)]
if
inputs
[
0
]
.
type
.
dtype
not
in
int_dtypes
:
raise
TypeError
(
f
"Axis value {inputs[0]} must be an integer type"
)
node
=
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
[
tensor
(
dtype
=
out_dtype
,
broadcastable
=
bcastable
)])
return
node
def
perform
(
self
,
node
,
axis_and_tensors
,
out_
):
def
perform
(
self
,
node
,
axis_and_tensors
,
out_
):
(
out
,)
=
out_
(
out
,)
=
out_
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论