提交 2043d632 authored 作者: Tanjay94's avatar Tanjay94

Moved kron to slinalg and added deprecation_warning.

上级 bb7f0aae
from theano import tensor from theano import tensor
from theano.tensor.slinalg import kron
def deprecation_warning():
# Make sure the warning is displayed only once.
if deprecation_warning.already_displayed:
return
def kron(a, b): warnings.warn(
""" Kronecker product "theano modules are deprecated and will be removed in release 0.7",
stacklevel=3)
Same as scipy.linalg.kron(a, b). deprecation_warning.already_displayed = True
\ No newline at end of file
:note: numpy.kron(a, b) != scipy.linalg.kron(a, b)!
They don't have the same shape and order when
a.ndim != b.ndim != 2.
:param a: array_like
:param b: array_like
:return: array_like with a.ndim + b.ndim - 2 dimensions.
"""
a = tensor.as_tensor_variable(a)
b = tensor.as_tensor_variable(b)
if (a.ndim + b.ndim <= 2):
raise TypeError('kron: inputs dimensions must sum to 3 or more. '
'You passed %d and %d.' % (a.ndim, b.ndim))
o = tensor.outer(a, b)
o = o.reshape(tensor.concatenate((a.shape, b.shape)),
a.ndim + b.ndim)
shf = o.dimshuffle(0, 2, 1, * range(3, o.ndim))
if shf.ndim == 3:
shf = o.dimshuffle(1, 0, 2)
o = shf.flatten()
else:
o = shf.reshape((o.shape[0] * o.shape[2],
o.shape[1] * o.shape[3]) +
tuple([o.shape[i] for i in range(4, o.ndim)]))
return o
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论