提交 264a74f3 authored 作者: Zhouhan LIN's avatar Zhouhan LIN

clip and minimum, change ExtractDiag names

上级 466d8929
......@@ -8,8 +8,7 @@ from six.moves import StringIO
from theano import tensor, gof, Op
from theano.gradient import grad_not_implemented
import theano.tensor.clip
import theano.tensor.minimum
import theano.tensor as T
from theano.tensor.subtensor import IncSubtensor, Subtensor, get_idx_list
try:
......@@ -1177,10 +1176,10 @@ class GpuDiagonal(Subtensor):
# The following logic is inspired by C code of PyArray_Diagonal().
offset = self.offset
if offset > 0:
diag_size = theano.tensor.clip(dim2 - offset, 0, dim1)
diag_size = T.clip(dim2 - offset, 0, dim1)
elif offset < 0:
diag_size = theano.tensor.clip(dim1 + offset, 0, dim2)
diag_size = T.clip(dim1 + offset, 0, dim2)
else:
diag_size = theano.tensor.minimum(dim1, dim2)
diag_size = T.minimum(dim1, dim2)
out_shape.append(diag_size)
return [tuple(out_shape)]
......@@ -6045,10 +6045,10 @@ class ExtractDiag(Op):
def __init__(self, offset=0, axis1=0, axis2=1, view=False):
self.view = view
if self.view and not numpy_diagonal_return_view:
warnings.warn("View will forced to False. Diagonal property view is "
warnings.warn("View will forced to False. ExtractDiag property view is "
"set to True but numpy version %s and prior versions of "
"numpy.diagonal() do not return a view. Update "
"numpy to use Diagonal(view=True)" %
"numpy to use ExtractDiag(view=True)" %
numpy.version.version)
self.view = False
if self.view:
......@@ -6061,7 +6061,7 @@ class ExtractDiag(Op):
x = as_tensor_variable(x)
if x.ndim < 2:
raise ValueError('Diagonal needs an input with 2 or more '
raise ValueError('ExtractDiag needs an input with 2 or more '
'dimensions', x)
return Apply(self, [x], [x.type.__class__(
dtype=x.dtype,
......
......@@ -45,7 +45,7 @@ from theano.tensor import (_shared, wvector, bvector,
tile, patternbroadcast, Eye, Shape, Dot, PermuteRowElements,
ScalarFromTensor, TensorFromScalar, dtensor4, Rebroadcast, Alloc,
dtensor3, SpecifyShape, Mean,
itensor3, Tile, switch, Diagonal, Diag,
itensor3, Tile, switch, ExtractDiag, Diag,
nonzero, flatnonzero, nonzero_values,
stacklists, DimShuffle, hessian, ptp, power,
swapaxes, choose, Choose, NoneConst, AllocEmpty,
......@@ -7427,27 +7427,27 @@ class TestInferShape(utt.InferShapeTester):
[Tri()(aiscal, biscal, ciscal)],
[3, 5, 0], Tri)
# Diagonal
# ExtractDiag
atens3 = tensor3()
atens3_val = rand(4, 5, 3)
atens3_diag = Diagonal()(atens3)
atens3_diag = ExtractDiag()(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1)(atens3)
[atens3_val], ExtractDiag)
atens3_diag = ExtractDiag(1)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(-1)(atens3)
[atens3_val], ExtractDiag)
atens3_diag = ExtractDiag(-1)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1, 0, 2)(atens3)
[atens3_val], ExtractDiag)
atens3_diag = ExtractDiag(1, 0, 2)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1, 1, 2)(atens3)
[atens3_val], ExtractDiag)
atens3_diag = ExtractDiag(1, 1, 2)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
atens3_diag = Diagonal(1, 2, 0)(atens3)
[atens3_val], ExtractDiag)
atens3_diag = ExtractDiag(1, 2, 0)(atens3)
self._compile_and_check([atens3], [atens3_diag],
[atens3_val], Diagonal)
[atens3_val], ExtractDiag)
# Diag
advec = dvector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论