提交 68c4085f authored 作者: abalkin's avatar abalkin

Added .argmin(), .argmax(), .clip(), .conj(), .repeat(), .trace(),

.std(), and .round() methods and .real and .imag attributes to TensorVariable.
上级 b8db858d
...@@ -1758,6 +1758,10 @@ class _tensor_py_operators: ...@@ -1758,6 +1758,10 @@ class _tensor_py_operators:
"""See `theano.tensor.var`""" """See `theano.tensor.var`"""
return var(self, axis, keepdims=keepdims) return var(self, axis, keepdims=keepdims)
def std(self, axis=None, keepdims=False):
"""See `theano.tensor.std`"""
return std(self, axis, keepdims=keepdims)
def min(self, axis=None, keepdims=False): def min(self, axis=None, keepdims=False):
"""See `theano.tensor.min`""" """See `theano.tensor.min`"""
return min(self, axis, keepdims=keepdims) return min(self, axis, keepdims=keepdims)
...@@ -1766,6 +1770,35 @@ class _tensor_py_operators: ...@@ -1766,6 +1770,35 @@ class _tensor_py_operators:
"""See `theano.tensor.max`""" """See `theano.tensor.max`"""
return max(self, axis, keepdims=keepdims) return max(self, axis, keepdims=keepdims)
def argmin(self, axis=None, keepdims=False):
"""See `theano.tensor.argmin`"""
return argmin(self, axis, keepdims=keepdims)
def argmax(self, axis=None, keepdims=False):
"""See `theano.tensor.argmax`"""
return argmax(self, axis, keepdims=keepdims)
def clip(self, a_min, a_max):
"Clip (limit) the values in an array."
return clip(self, a_min, a_max)
def conj(self):
"""See `theano.tensor.conj`"""
return conj(self)
def repeat(self, repeats, axis=None):
"""See `theano.tensor.repeat`"""
from theano.tensor.extra_ops import repeat
return repeat(self, repeats, axis)
def round(self, mode="half_away_from_zero"):
"""See `theano.tensor.round`"""
return round(self, mode)
def trace(self):
from theano.sandbox.linalg import trace
return trace(self)
# TO TRUMP NUMPY OPERATORS # TO TRUMP NUMPY OPERATORS
__array_priority__ = 1000 __array_priority__ = 1000
...@@ -2971,12 +3004,12 @@ def psi(a): ...@@ -2971,12 +3004,12 @@ def psi(a):
@_scal_elemwise_with_nfunc('real', 1, -1) @_scal_elemwise_with_nfunc('real', 1, -1)
def real(z): def real(z):
"""Return real component of complex-valued tensor `z`""" """Return real component of complex-valued tensor `z`"""
_tensor_py_operators.real = property(real)
@_scal_elemwise_with_nfunc('imag', 1, -1) @_scal_elemwise_with_nfunc('imag', 1, -1)
def imag(z): def imag(z):
"""Return imaginary component of complex-valued tensor `z`""" """Return imaginary component of complex-valued tensor `z`"""
_tensor_py_operators.imag = property(imag)
@_scal_elemwise_with_nfunc('angle', 1, -1) @_scal_elemwise_with_nfunc('angle', 1, -1)
def angle(z): def angle(z):
......
...@@ -7006,6 +7006,23 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7006,6 +7006,23 @@ class TestTensorInstanceMethods(unittest.TestCase):
self.vars = matrices('X', 'Y') self.vars = matrices('X', 'Y')
self.vals = [rand(2,2),rand(2,2)] self.vals = [rand(2,2),rand(2,2)]
def test_argmin(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.argmin().eval({X: x}) == x.argmin()))
def test_argmax(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.argmax().eval({X: x}) == x.argmax()))
def test_dot(self):
X, Y = self.vars
x, y = self.vals
Z = X.clip(0.5 - Y, 0.5 + Y)
z = x.clip(0.5 - y, 0.5 + y)
self.assertTrue(numpy.all(Z.eval({X: x, Y: y}) == z))
def test_dot(self): def test_dot(self):
X, Y = self.vars X, Y = self.vars
x, y = self.vals x, y = self.vals
...@@ -7013,7 +7030,43 @@ class TestTensorInstanceMethods(unittest.TestCase): ...@@ -7013,7 +7030,43 @@ class TestTensorInstanceMethods(unittest.TestCase):
Z = X.dot(Y) Z = X.dot(Y)
z = x.dot(y) z = x.dot(y)
self.assertTrue(numpy.all(x.dot(z) == X.dot(Z).eval({X: x, Z: z}))) self.assertTrue(numpy.all(x.dot(z) == X.dot(Z).eval({X: x, Z: z})))
def test_real_imag(self):
X, Y = self.vars
x, y = self.vals
Z = X + Y * 1j
z = x + y * 1j
self.assertTrue(numpy.all(Z.real.eval({Z: z}) == x))
self.assertTrue(numpy.all(Z.imag.eval({Z: z}) == y))
def test_conj(self):
X, Y = self.vars
x, y = self.vals
Z = X + Y * 1j
z = x + y * 1j
self.assertTrue(numpy.all(Z.conj().eval({Z: z}) == z.conj()))
def test_round(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.round().eval({X: x}) == x.round()))
def test_std(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.std().eval({X: x}) == x.std()))
def test_repeat(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.repeat(2).eval({X: x}) == x.repeat(2)))
def test_trace(self):
X, _ = self.vars
x, _ = self.vals
self.assertTrue(numpy.all(X.trace().eval({X: x}) == x.trace()))
if __name__ == '__main__': if __name__ == '__main__':
t = TestInferShape('setUp') t = TestInferShape('setUp')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论