提交 c819f527 authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

add Tri op, tri/tril/triu functions

上级 bbc25bc6
......@@ -3184,6 +3184,112 @@ def ones(shape, dtype=None):
return alloc(numpy.array(1, dtype=dtype), *shape)
class Tri(gof.Op):
def __init__(self, dtype=None):
if dtype is None:
dtype = config.floatX
self.dtype = dtype
def make_node(self, N, M, k):
N = as_tensor_variable(N)
M = as_tensor_variable(M)
k = as_tensor_variable(k)
return gof.Apply(self, [N, M, k],
[TensorType(dtype=self.dtype, broadcastable=(False, False))()])
def perform(self, node, inp, out_):
N, M, k = inp
out, = out_
out[0] = numpy.tri(N, M, k, dtype=self.dtype)
def infer_shape(self, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]]
return [out_shape]
def grad(self, inp, grads):
return [grad_undefined(self, i, inp[i]) for i in xrange(3)]
def __eq__(self, other):
return type(self) == type(other) and self.dtype == other.dtype
def __hash__(self):
return hash(self.dtype) ^ hash(type(self))
def tri(N, M=None, k=0, dtype=None):
"""
An array with ones at and below the given diagonal and zeros elsewhere.
Parameters
----------
N : int
Number of rows in the array.
M : int, optional
Number of columns in the array.
By default, `M` is taken equal to `N`.
k : int, optional
The sub-diagonal at and below which the array is filled.
`k` = 0 is the main diagonal, while `k` < 0 is below it,
and `k` > 0 is above. The default is 0.
dtype : dtype, optional
Data type of the returned array. The default is float.
Returns
-------
tri : Array of shape (N, M)
Array with its lower triangle filled with ones and zero elsewhere;
in other words ``T[i,j] == 1`` for ``i <= j + k``, 0 otherwise.
"""
if dtype is None:
dtype = config.floatX
if M is None:
M = N
op = Tri(dtype)
return op(N, M, k)
def tril(m, k=0):
"""
Lower triangle of an array.
Return a copy of an array with elements above the `k`-th diagonal zeroed.
Parameters
----------
m : array_like, shape (M, N)
Input array.
k : int, optional
Diagonal above which to zero elements. `k = 0` (the default) is the
main diagonal, `k < 0` is below it and `k > 0` is above.
Returns
-------
tril : array, shape (M, N)
Lower triangle of `m`, of same shape and data-type as `m`.
See Also
--------
triu : same thing, only for the upper triangle
"""
return m * tri(m.shape[0], m.shape[1], k=k, dtype=m.dtype)
def triu(m, k=0):
"""
Upper triangle of an array.
Return a copy of a matrix with the elements below the `k`-th diagonal
zeroed.
Please refer to the documentation for `tril` for further details.
See Also
--------
tril : lower triangle of an array
"""
return m * (1 - tri(m.shape[0], m.shape[1], k=k-1, dtype=m.dtype))
class Eye(gof.Op):
def __init__(self, dtype=None):
if dtype is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论