提交 69307dda authored 作者: Tanjay94's avatar Tanjay94

Auto generate Op.{hash,eq,str}.

上级 06bf5a35
......@@ -575,21 +575,24 @@ class Op(utils.object2, PureOp, CLinkerOp):
def __init__(self, use_c_code=theano.config.cxx):
self._op_use_c_code = use_c_code
def _props(self):
return (getattr(self, a) for a in self.__props__)
def __hash__(self):
if hasattr(self, 'props'):
return hash((type(self), self.props()))
if hasattr(self, '__props__'):
return hash((type(self), self._props()))
else:
return super(Op, self).__hash__()
def __str__(self):
if hasattr(self, 'props'):
return "%s{%s}" % (self.__class__.__name__, ", ".join(str(p) for p in self.props()))
if hasattr(self, '__props__'):
return "%s{%s}" % (self.__class__.__name__, ", ".join(str(p) for p in self._props()))
else:
return super(Op, self).__str__()
def __eq__(self, other):
if hasattr(self, 'props'):
return (type(self) == type(other) and self.props() == other.props())
if hasattr(self, '__props__'):
return (type(self) == type(other) and self._props() == other._props())
else:
return NotImplemented
......
......@@ -34,14 +34,6 @@ class MatrixPinv(Op):
def __init__(self):
pass
def props(self):
"""Function exposing different properties of each instance of the
op.
For the ``MatrixPinv`` op, there are no properties to be exposed.
"""
return ()
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
......@@ -67,14 +59,6 @@ class MatrixInverse(Op):
def __init__(self):
pass
def props(self):
"""Function exposing different properties of each instance of the
op.
For the ``MatrixInverse`` op, there are no properties to be exposed.
"""
return ()
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
......@@ -298,14 +282,6 @@ class Eig(Op):
"""
_numop = staticmethod(numpy.linalg.eig)
def props(self):
"""Function exposing different properties of each instance of the
op.
For the ``Eig`` op, there are no properties to be exposed.
"""
return ()
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
......@@ -329,14 +305,12 @@ class Eigh(Eig):
"""
_numop = staticmethod(numpy.linalg.eigh)
__props__ = ('UPLO',)
def __init__(self, UPLO='L'):
assert UPLO in ['L', 'U']
self.UPLO = UPLO
def props(self):
return self.UPLO,
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2
......@@ -397,6 +371,8 @@ class EighGrad(Op):
"""Gradient of an eigensystem of a Hermitian matrix.
"""
__props__ = ('UPLO',)
def __init__(self, UPLO='L'):
assert UPLO in ['L', 'U']
self.UPLO = UPLO
......@@ -407,9 +383,6 @@ class EighGrad(Op):
self.tri0 = numpy.triu
self.tri1 = lambda a: numpy.tril(a, -1)
def props(self):
return (self.UPLO,)
def make_node(self, x, w, v, gw, gv):
x, w, v, gw, gv = map(as_tensor_variable, (x, w, v, gw, gv))
assert x.ndim == 2
......@@ -468,6 +441,7 @@ class QRFull(Op):
and r is upper-triangular.
"""
_numop = staticmethod(numpy.linalg.qr)
__props__ = ('mode',)
def __init__(self, mode):
self.mode = mode
......@@ -479,9 +453,6 @@ class QRFull(Op):
r = theano.tensor.matrix(dtype=x.dtype)
return Apply(self, [x], [q, r])
def props(self):
return self.mode
def perform(self, node, (x,), (q, r)):
assert x.ndim == 2, "The input of qr function should be a matrix."
......@@ -496,13 +467,11 @@ class QRIncomplete(Op):
Factor the matrix a as qr and return a single matrix.
"""
_numop = staticmethod(numpy.linalg.qr)
__props__ = ('mode',)
def __init__(self, mode):
self.mode = mode
def props(self):
return self.mode
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of qr function should be a matrix."
......@@ -570,6 +539,7 @@ class SVD(Op):
# See doc in the docstring of the function just after this class.
_numop = staticmethod(numpy.linalg.svd)
__props__ = ('full_matrices', 'compute_uv')
def __init__(self, full_matrices=True, compute_uv=True):
"""
......@@ -587,9 +557,6 @@ class SVD(Op):
self.full_matrices = full_matrices
self.compute_uv = compute_uv
def props(self):
return self.full_matrices, self.compute_uv,
def make_node(self, x):
x = as_tensor_variable(x)
assert x.ndim == 2, "The input of svd function should be a matrix."
......
......@@ -42,14 +42,13 @@ class Cholesky(Op):
#TODO: inplace
#TODO: for specific dtypes
#TODO: LAPACK wrapper with in-place behavior, for solve also
__props__ = ('lower', 'destructive')
def __init__(self, lower=True):
self.lower = lower
self.destructive = False
def props(self):
return (self.lower,
self.destructive)
def infer_shape(self, node, shapes):
return [shapes[0]]
......@@ -75,14 +74,13 @@ cholesky = Cholesky()
class CholeskyGrad(Op):
"""
"""
__props__ = ('lower', 'destructive')
def __init__(self, lower=True):
self.lower = lower
self.destructive = False
def props(self):
return (self.lower,
self.destructive)
def make_node(self, x, l, dz):
x = as_tensor_variable(x)
l = as_tensor_variable(l)
......@@ -141,6 +139,9 @@ class CholeskyGrad(Op):
class Solve(Op):
"""Solve a system of linear equations"""
__props__ = ('A_structure', 'lower', 'overwrite_A', 'overwrite_b')
def __init__(self,
A_structure='general',
lower=False,
......@@ -153,12 +154,6 @@ class Solve(Op):
self.overwrite_A = overwrite_A
self.overwrite_b = overwrite_b
def props(self):
return (self.A_structure,
self.lower,
self.overwrite_A,
self.overwrite_b)
def __repr__(self):
return 'Solve{%s}' % str(self.props())
......@@ -201,13 +196,12 @@ class Eigvalsh(Op):
"""Generalized eigenvalues of a Hermetian positive definite eigensystem
"""
__props__ = ('lower',)
def __init__(self, lower=True):
assert lower in [True, False]
self.lower = lower
def props(self):
return (self.lower,)
def make_node(self, a, b):
assert imported_scipy, (
"Scipy not available. Scipy is needed for the Eigvalsh op")
......@@ -258,6 +252,8 @@ class EigvalshGrad(Op):
# discussion on github at
# https://github.com/Theano/Theano/pull/1846#discussion-diff-12486764
__props__ = ('lower',)
def __init__(self, lower=True):
assert lower in [True, False]
self.lower = lower
......@@ -268,9 +264,6 @@ class EigvalshGrad(Op):
self.tri0 = numpy.triu
self.tri1 = lambda a: numpy.tril(a, -1)
def props(self):
return (self.lower,)
def make_node(self, a, b, gw):
assert imported_scipy, (
"Scipy not available. Scipy is needed for the GEigvalsh op")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论