Commit 7fd891e9 authored by Frédéric Bastien, committed by GitHub

Merge pull request #4749 from Sentient07/Props

Adding __props__ to all the Ops that have __init__
...@@ -15,6 +15,8 @@ class Minimal(gof.Op): ...@@ -15,6 +15,8 @@ class Minimal(gof.Op):
# If two Apply nodes have the same inputs and the ops compare equal... # If two Apply nodes have the same inputs and the ops compare equal...
# then they will be MERGED so they had better have computed the same thing! # then they will be MERGED so they had better have computed the same thing!
__props__ = ()
def __init__(self): def __init__(self):
# If you put things here, think about whether they change the outputs # If you put things here, think about whether they change the outputs
# computed by # self.perform() # computed by # self.perform()
...@@ -25,12 +27,6 @@ class Minimal(gof.Op): ...@@ -25,12 +27,6 @@ class Minimal(gof.Op):
super(Minimal, self).__init__() super(Minimal, self).__init__()
def __eq__(self, other):
    # This Op carries no configuration state, so equality reduces to
    # being an instance of the very same class.
    same_class = type(self) == type(other)
    return same_class
def __hash__(self):
    # Consistent with __eq__ above: every instance of a given class
    # hashes to the same value (the hash of the class itself).
    cls = type(self)
    return hash(cls)
def make_node(self, *args): def make_node(self, *args):
# HERE `args` must be THEANO VARIABLES # HERE `args` must be THEANO VARIABLES
return gof.Apply(op=self, inputs=args, outputs=[tensor.lscalar()]) return gof.Apply(op=self, inputs=args, outputs=[tensor.lscalar()])
......
...@@ -27,11 +27,7 @@ class Solve(gof.Op): ...@@ -27,11 +27,7 @@ class Solve(gof.Op):
# and keeps a memory workspace from call to call as a non-default Op # and keeps a memory workspace from call to call as a non-default Op
# output # output
def __eq__(self, other): __props__ = ()
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, A, b): def make_node(self, A, b):
A_ = tensor.as_tensor_variable(A) A_ = tensor.as_tensor_variable(A)
......
...@@ -1321,6 +1321,8 @@ class CAReduce(Op): ...@@ -1321,6 +1321,8 @@ class CAReduce(Op):
""" """
__props__ = ("scalar_op", "axis")
def __init__(self, scalar_op, axis=None): def __init__(self, scalar_op, axis=None):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1: if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
raise NotImplementedError(( raise NotImplementedError((
...@@ -1411,17 +1413,6 @@ class CAReduce(Op): ...@@ -1411,17 +1413,6 @@ class CAReduce(Op):
self.__dict__.update(d) self.__dict__.update(d)
self.set_ufunc(self.scalar_op) self.set_ufunc(self.scalar_op)
def __eq__(self, other):
    # Two CAReduce Ops are interchangeable iff they share the class,
    # the scalar reduction op, and the reduction axis.
    if type(self) != type(other):
        return False
    return self.scalar_op == other.scalar_op and self.axis == other.axis
def __hash__(self):
    # Hash must agree with __eq__: fold the axis in only when it is
    # set (axis may be a list, so convert to a hashable tuple first).
    h = hash(self.scalar_op)
    if self.axis is not None:
        h ^= hash(tuple(self.axis))
    return h
def __str__(self): def __str__(self):
if self.axis is not None: if self.axis is not None:
return "Reduce{%s}{%s}" % ( return "Reduce{%s}{%s}" % (
...@@ -1699,6 +1690,7 @@ class All(CAReduce): ...@@ -1699,6 +1690,7 @@ class All(CAReduce):
specified axis(es). specified axis(es).
""" """
__props__ = ("axis",)
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis) CAReduce.__init__(self, scalar.and_, axis)
...@@ -1729,6 +1721,7 @@ class Any(CAReduce): ...@@ -1729,6 +1721,7 @@ class Any(CAReduce):
specified axis(es). specified axis(es).
""" """
__props__ = ("axis", )
def __init__(self, axis=None): def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis) CAReduce.__init__(self, scalar.or_, axis)
...@@ -1806,20 +1799,13 @@ class CAReduceDtype(CAReduce): ...@@ -1806,20 +1799,13 @@ class CAReduceDtype(CAReduce):
* for complex dtypes, we use at least complex128. * for complex dtypes, we use at least complex128.
""" """
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype")
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None): def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
CAReduce.__init__(self, scalar_op, axis=axis) CAReduce.__init__(self, scalar_op, axis=axis)
self.dtype = dtype self.dtype = dtype
self.acc_dtype = acc_dtype self.acc_dtype = acc_dtype
def __eq__(self, other):
    # Extend the parent's equality (class / scalar_op / axis) with the
    # two dtype attributes this subclass adds.
    if not CAReduce.__eq__(self, other):
        return False
    return (self.dtype, self.acc_dtype) == (other.dtype, other.acc_dtype)
def __hash__(self):
    # Combine the parent hash with the extra (dtype, acc_dtype) state
    # so that the hash stays consistent with __eq__.
    parent_hash = CAReduce.__hash__(self)
    return parent_hash ^ hash((self.dtype, self.acc_dtype))
def __setstate__(self, d): def __setstate__(self, d):
super(CAReduceDtype, self).__setstate__(d) super(CAReduceDtype, self).__setstate__(d)
if not hasattr(self, "dtype"): if not hasattr(self, "dtype"):
...@@ -1966,10 +1952,24 @@ class Sum(CAReduceDtype): ...@@ -1966,10 +1952,24 @@ class Sum(CAReduceDtype):
""" """
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, scalar.add, axis=axis, CAReduceDtype.__init__(self, scalar.add, axis=axis,
dtype=dtype, acc_dtype=acc_dtype) dtype=dtype, acc_dtype=acc_dtype)
def __str__(self):
    # Render e.g. "Sum{axis=[0, 1], acc_dtype=float64}"; the axis part
    # is omitted entirely when reducing over all axes (axis is None).
    axis_part = ""
    if self.axis is not None:
        joined = ", ".join(str(a) for a in self.axis)
        axis_part = "axis=[%s], " % joined
    return "%s{%sacc_dtype=%s}" % (
        self.__class__.__name__,
        axis_part,
        str(self.acc_dtype),
    )
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
...@@ -2014,6 +2014,7 @@ class Prod(CAReduceDtype): ...@@ -2014,6 +2014,7 @@ class Prod(CAReduceDtype):
input. input.
""" """
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None, def __init__(self, axis=None, dtype=None, acc_dtype=None,
no_zeros_in_input=False): no_zeros_in_input=False):
...@@ -2027,14 +2028,6 @@ class Prod(CAReduceDtype): ...@@ -2027,14 +2028,6 @@ class Prod(CAReduceDtype):
if 'no_zeros_in_input' not in dct: if 'no_zeros_in_input' not in dct:
self.no_zeros_in_input = False self.no_zeros_in_input = False
def __eq__(self, other):
    # Parent equality plus the Prod-specific no_zeros_in_input flag.
    parent_equal = CAReduceDtype.__eq__(self, other)
    if not parent_equal:
        return False
    return self.no_zeros_in_input == other.no_zeros_in_input
def __hash__(self):
    # Mix the no_zeros_in_input flag into the parent hash, keeping
    # __hash__ consistent with __eq__.
    base = CAReduceDtype.__hash__(self)
    return base ^ hash(self.no_zeros_in_input)
def grad(self, inp, grads): def grad(self, inp, grads):
""" """
The grad of this Op could be very easy, if it is was not for the case The grad of this Op could be very easy, if it is was not for the case
...@@ -2196,6 +2189,9 @@ mul_without_zeros = MulWithoutZeros(scalar.upcast_out, ...@@ -2196,6 +2189,9 @@ mul_without_zeros = MulWithoutZeros(scalar.upcast_out,
class ProdWithoutZeros(CAReduceDtype): class ProdWithoutZeros(CAReduceDtype):
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis, CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype) dtype=dtype, acc_dtype=acc_dtype)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment