Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
856aa0b6
提交
856aa0b6
authored
8月 29, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3267 from koningrobot/tensordot-as-dot
Implement batched_tensordot in terms of batched_dot
上级
77f6b2be
8dae1fbe
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
138 行增加
和
119 行删除
+138
-119
basic.py
theano/tensor/basic.py
+138
-119
没有找到文件。
theano/tensor/basic.py
浏览文件 @
856aa0b6
...
@@ -1056,6 +1056,16 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout):
...
@@ -1056,6 +1056,16 @@ def _scal_elemwise_with_nfunc(nfunc, nin, nout):
_scal_elemwise
=
_scal_elemwise_with_nfunc
(
None
,
None
,
None
)
_scal_elemwise
=
_scal_elemwise_with_nfunc
(
None
,
None
,
None
)
def
_pack
(
x
):
"""
Convert x to a list if it is an iterable, otherwise wrap it in a list.
"""
try
:
return
list
(
x
)
except
TypeError
:
return
[
x
]
#########################
#########################
# Casting Operations
# Casting Operations
#########################
#########################
...
@@ -3357,24 +3367,11 @@ def batched_tensordot(x, y, axes=2):
...
@@ -3357,24 +3367,11 @@ def batched_tensordot(x, y, axes=2):
3rd axis of b must have the same shape; the same is true for
3rd axis of b must have the same shape; the same is true for
the 3rd axis of a and the 5th axis of b.
the 3rd axis of a and the 5th axis of b.
Like tensordot, this function uses a series of dimshuffles and
reshapes to reduce the tensor dot product to a matrix or vector
dot product. Finally, it calls batched_dot to compute the result.
"""
"""
if
isinstance
(
axes
,
(
list
,
numpy
.
ndarray
)):
return
_tensordot_as_dot
(
x
,
y
,
axes
,
dot
=
batched_dot
,
batched
=
True
)
if
isinstance
(
axes
,
list
):
axes
=
numpy
.
asarray
(
axes
)
else
:
axes
=
axes
.
copy
()
assert
numpy
.
greater
(
axes
,
0
)
.
all
(),
(
"All axes should be greater than one, as the "
"first axis is iterated over (batch-wise scan)"
)
axes
-=
1
result
,
updates
=
theano
.
scan
(
fn
=
lambda
x_mat
,
y_mat
:
theano
.
tensor
.
tensordot
(
x_mat
,
y_mat
,
axes
),
outputs_info
=
None
,
sequences
=
[
x
,
y
],
non_sequences
=
None
)
return
result
def
split
(
x
,
splits_size
,
n_splits
,
axis
=
0
):
def
split
(
x
,
splits_size
,
n_splits
,
axis
=
0
):
...
@@ -5273,6 +5270,129 @@ def dot(a, b):
...
@@ -5273,6 +5270,129 @@ def dot(a, b):
# Linalg : TensorDot
# Linalg : TensorDot
#########################
#########################
def _tensordot_as_dot(a, b, axes, dot, batched):
    """
    Reduces a tensor dot product to a matrix or vector dot product. Based
    on code from Tijmen Tieleman's gnumpy
    (http://www.cs.toronto.edu/~tijmen/gnumpy.html).

    Please see the documentation of tensordot for the meaning of the a, b
    and axes arguments.

    :param dot: a function that accepts two symbolic variables and computes
                the appropriate dot product (e.g. dot, batched_dot)
    :type dot: function

    :param batched: whether to treat the first axis of a and b as a batch
                    axis.  If so, this axis will be preserved in the output,
                    allowing this function to be used also for batched
                    tensor dot products.
    :type batched: boolean

    :returns: a tensor with shape equal to the concatenation of a's shape
              (less any dimensions that were summed over) and b's shape
              (less any dimensions that were summed over and, when batched
              is true, less its first dimension).
    :rtype: symbolic tensor

    :raises ValueError: if axes is neither a scalar nor a sequence of
                        length 2, if it names more axes than an operand
                        has or an axis index outside an operand, if the
                        two axes lists differ in length, or if batched is
                        true and the batch axis would be summed over.
    """
    a, b = as_tensor_variable(a), as_tensor_variable(b)

    if not numpy.isscalar(axes) and len(axes) != 2:
        raise ValueError('Axes should be an integer or a '
                         'list/tuple of len 2 (%s was provided)'
                         % repr(axes))

    # if 'axes' is a number of axes to multiply and sum over (trailing axes
    # of a, leading axes of b), we can just reshape and use dot.
    elif numpy.isscalar(axes):
        axes = int(axes)

        for operand_name, operand in (("a", a), ("b", b)):
            if axes > operand.ndim:
                raise ValueError(
                    'axes can not be larger than the dimension of %s '
                    '(%s.ndim=%i, axes=%i)'
                    % (operand_name, operand_name, operand.ndim, axes))
            if batched and axes == operand.ndim:
                raise ValueError(
                    'axes to sum over must not include the batch axis '
                    'of %s (%s.ndim=%i, axes=%i)'
                    % (operand_name, operand_name, operand.ndim, axes))

        # number of leading axes that are carried through untouched
        batch_axes = 1 if batched else 0
        a_outaxes = slice(0, a.ndim - axes)
        b_outaxes = slice(batch_axes + axes, b.ndim)
        outshape = concatenate([a.shape[a_outaxes], b.shape[b_outaxes]])
        outbcast = a.broadcastable[a_outaxes] + b.broadcastable[b_outaxes]
        outndim = len(outbcast)

        # a is flattened to (other, summed) and b to (summed, other)
        a_shape = [1] * 2
        b_shape = [1] * 2

        # compute total size of summed axes
        for i in xrange(0, axes):
            a_shape[1] *= a.shape[-(i + 1)]
            b_shape[0] *= b.shape[batch_axes + i]
        # compute total size of other axes
        for i in xrange(0, a.ndim - axes - batch_axes):
            a_shape[0] *= a.shape[batch_axes + i]
        for i in xrange(0, b.ndim - axes - batch_axes):
            b_shape[1] *= b.shape[-(i + 1)]

        if batched:
            # keep the batch axis in front of the two flattened axes
            a_shape.insert(0, a.shape[0])
            b_shape.insert(0, b.shape[0])

        a_reshaped = a.reshape(a_shape)
        b_reshaped = b.reshape(b_shape)

        out_reshaped = dot(a_reshaped, b_reshaped)
        out = out_reshaped.reshape(outshape, outndim)
        # Make sure the broadcastable pattern of the result is correct,
        # since some shape information can be lost in the reshapes.
        return patternbroadcast(out, outbcast)

    # if 'axes' is a list, transpose a and b such that the summed axes of a
    # are last and the summed axes of b are first.
    else:
        axes = [_pack(axes_) for axes_ in axes]

        if len(axes[0]) != len(axes[1]):
            raise ValueError('Axes elements must have the same length.')

        for i, (operand_name, operand) in enumerate((("a", a),
                                                     ("b", b))):
            # The reported index follows i so that a problem with axes[1]
            # is not labelled as axes[0].
            if len(axes[i]) > operand.ndim:
                raise ValueError(
                    'axes[%i] should be array_like with length less than '
                    'the dimensions of %s (%s.ndim=%i, len(axes[%i])=%i).' %
                    (i, operand_name, operand_name, operand.ndim,
                     i, len(axes[i])))
            # len > 0 is checked first so numpy.max never sees an empty list
            if len(axes[i]) > 0 and numpy.max(axes[i]) >= operand.ndim:
                raise ValueError(
                    'axes[%i] contains dimensions greater than or equal '
                    'to %s.ndim (%s.ndim=%i, max(axes[%i])=%i).' %
                    (i, operand_name, operand_name, operand.ndim,
                     i, numpy.max(numpy.array(axes[i]))))
            if batched and 0 in axes[i]:
                raise ValueError(
                    'axes to sum over must not contain the batch axis '
                    '(axes[%i]=%s)' %
                    (i, axes[i]))

        batch_axes = [0] if batched else []
        # axes of each operand that are neither summed over nor batch axes
        other_axes = [[x for x in xrange(operand.ndim)
                       if x not in axes[i] and x not in batch_axes]
                      for i, operand in enumerate((a, b))]

        a_shuffled = a.dimshuffle(batch_axes + other_axes[0] + axes[0])
        b_shuffled = b.dimshuffle(batch_axes + axes[1] + other_axes[1])

        # now that a and b are in the right order, recur with integer axes
        return _tensordot_as_dot(a_shuffled, b_shuffled, len(axes[0]),
                                 dot=dot, batched=batched)
def
tensordot
(
a
,
b
,
axes
=
2
):
def
tensordot
(
a
,
b
,
axes
=
2
):
"""
"""
Compute a generalized dot product over provided axes.
Compute a generalized dot product over provided axes.
...
@@ -5373,108 +5493,7 @@ def tensordot(a, b, axes=2):
...
@@ -5373,108 +5493,7 @@ def tensordot(a, b, axes=2):
See the documentation of numpy.tensordot for more examples.
See the documentation of numpy.tensordot for more examples.
"""
"""
a
,
b
=
as_tensor_variable
(
a
),
as_tensor_variable
(
b
)
return
_tensordot_as_dot
(
a
,
b
,
axes
,
dot
=
dot
,
batched
=
False
)
# axes must be a scalar or list/tuple of length 2
if
not
numpy
.
isscalar
(
axes
)
and
len
(
axes
)
!=
2
:
raise
ValueError
(
'Axes should be an integer or a '
'list/tuple of len 2 (
%
s was provided)'
%
repr
(
axes
))
# if 'axes' is a number of axes to multiply and sum over (trailing axes
# of a, leading axes of b), we can just reshape and use dot.
elif
numpy
.
isscalar
(
axes
):
axes
=
int
(
axes
)
# check if axes is valid given the dimension of a and b
if
axes
>
a
.
ndim
:
raise
ValueError
(
'axes can not be larger than the dimension of '
'a (a.ndim=
%
i, axes=
%
i)'
%
(
a
.
ndim
,
axes
))
if
axes
>
b
.
ndim
:
raise
ValueError
(
'axes can not be larger than than the dimension '
'of b (b.ndim=
%
i, axes=
%
i)'
%
(
b
.
ndim
,
axes
))
outshape
=
concatenate
([
a
.
shape
[:
a
.
ndim
-
axes
],
b
.
shape
[
axes
:]])
outbcast
=
a
.
broadcastable
[:
a
.
ndim
-
axes
]
+
b
.
broadcastable
[
axes
:]
outndim
=
a
.
ndim
+
b
.
ndim
-
(
2
*
axes
)
a_shape_0
=
b_shape_0
=
a_shape_1
=
b_shape_1
=
1
for
s0
in
xrange
(
a
.
ndim
-
axes
):
a_shape_0
*=
a
.
shape
[
s0
]
for
s0
in
xrange
(
axes
):
b_shape_0
*=
b
.
shape
[
s0
]
for
s1
in
xrange
(
a
.
ndim
-
axes
,
a
.
ndim
):
a_shape_1
*=
a
.
shape
[
s1
]
for
s1
in
xrange
(
axes
,
b
.
ndim
):
b_shape_1
*=
b
.
shape
[
s1
]
a_reshaped
=
a
.
reshape
((
a_shape_0
,
a_shape_1
),
ndim
=
2
)
b_reshaped
=
b
.
reshape
((
b_shape_0
,
b_shape_1
),
ndim
=
2
)
out
=
_dot
(
a_reshaped
,
b_reshaped
)
.
reshape
(
outshape
,
outndim
)
# Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes.
return
patternbroadcast
(
out
,
outbcast
)
# if 'axes' is a list, transpose a and b such that the summed axes of a
# are last and the summed axes of b are first.
else
:
# get first axis element as a tuple
try
:
a_axes
=
tuple
(
axes
[
0
])
except
TypeError
:
a_axes
=
tuple
([
axes
[
0
]])
# get second axis element as a tuple
try
:
b_axes
=
tuple
(
axes
[
1
])
except
TypeError
:
b_axes
=
tuple
([
axes
[
1
]])
# the two axes lists must have the same length
if
len
(
a_axes
)
!=
len
(
b_axes
):
raise
ValueError
(
'Axes elements must have the same length.'
)
# check that there aren't more axes than a has dimensions
if
len
(
a_axes
)
>
a
.
ndim
:
raise
ValueError
(
'axes[0] should be array_like with length '
'less than the dimensions of a '
'(a.ndim=
%
i, len(axes[0])=
%
i).'
%
(
a
.
ndim
,
len
(
a_axes
)))
# check that a_axes doesn't contain an axis greater than or equal to
# a's dimensions. also check if len > 0 so numpy.max won't raise an
# error.
if
len
(
a_axes
)
>
0
and
numpy
.
max
(
numpy
.
array
(
a_axes
))
>=
a
.
ndim
:
raise
ValueError
(
'axes[0] contains dimensions greater than or '
'equal to a.ndim (a.ndim=
%
i, max(axes[0])=
%
i).'
%
(
a
.
ndim
,
numpy
.
max
(
numpy
.
array
(
a_axes
))))
# check that there aren't more axes than b has dimensions
if
len
(
b_axes
)
>
b
.
ndim
:
raise
ValueError
(
'axes[1] should be array_like, of length '
'smaller than the dimension of b '
'(a.ndim=
%
i, len(axes[0])=
%
i).'
%
(
b
.
ndim
,
len
(
b_axes
)))
# check that b_axes doesn't contain an axis greater than or equal to
# b's dimensions. also check if len > 0 so numpy.max won't raise an
# error.
if
len
(
b_axes
)
>
0
and
numpy
.
max
(
numpy
.
array
(
b_axes
))
>=
b
.
ndim
:
raise
ValueError
(
'axes[1] contains dimensions greater than or '
'equal to b.ndim (b.ndim=
%
i, max(axes[1])=
%
i).'
%
(
b
.
ndim
,
numpy
.
max
(
numpy
.
array
(
b_axes
))))
a_order
=
(
tuple
(
x
for
x
in
tuple
(
xrange
(
a
.
ndim
))
if
x
not
in
a_axes
)
+
a_axes
)
b_order
=
(
b_axes
+
tuple
(
x
for
x
in
tuple
(
xrange
(
b
.
ndim
))
if
x
not
in
b_axes
))
a_shuffled
=
a
.
dimshuffle
(
a_order
)
b_shuffled
=
b
.
dimshuffle
(
b_order
)
# now that a and b are in the right order, call tensordot recursively
return
tensordot
(
a_shuffled
,
b_shuffled
,
len
(
a_axes
))
def
outer
(
x
,
y
):
def
outer
(
x
,
y
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论