提交 7fd891e9 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4749 from Sentient07/Props

Adding __props__ to all the Ops that has __init__
......@@ -15,6 +15,8 @@ class Minimal(gof.Op):
# If two Apply nodes have the same inputs and the ops compare equal...
# then they will be MERGED so they had better have computed the same thing!
__props__ = ()
def __init__(self):
# If you put things here, think about whether they change the outputs
# computed by # self.perform()
......@@ -25,12 +27,6 @@ class Minimal(gof.Op):
super(Minimal, self).__init__()
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, *args):
# HERE `args` must be THEANO VARIABLES
return gof.Apply(op=self, inputs=args, outputs=[tensor.lscalar()])
......
......@@ -27,11 +27,7 @@ class Solve(gof.Op):
# and keeps a memory workspace from call to call as a non-default Op
# output
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
__props__ = ()
def make_node(self, A, b):
A_ = tensor.as_tensor_variable(A)
......
......@@ -1321,6 +1321,8 @@ class CAReduce(Op):
"""
__props__ = ("scalar_op", "axis")
def __init__(self, scalar_op, axis=None):
if scalar_op.nin not in [-1, 2] or scalar_op.nout != 1:
raise NotImplementedError((
......@@ -1411,17 +1413,6 @@ class CAReduce(Op):
self.__dict__.update(d)
self.set_ufunc(self.scalar_op)
def __eq__(self, other):
return (type(self) == type(other) and
self.scalar_op == other.scalar_op and
self.axis == other.axis)
def __hash__(self):
if self.axis is None:
return hash(self.scalar_op)
else:
return hash(self.scalar_op) ^ hash(tuple(self.axis))
def __str__(self):
if self.axis is not None:
return "Reduce{%s}{%s}" % (
......@@ -1699,6 +1690,7 @@ class All(CAReduce):
specified axis(es).
"""
__props__ = ("axis",)
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.and_, axis)
......@@ -1729,6 +1721,7 @@ class Any(CAReduce):
specified axis(es).
"""
__props__ = ("axis", )
def __init__(self, axis=None):
CAReduce.__init__(self, scalar.or_, axis)
......@@ -1806,20 +1799,13 @@ class CAReduceDtype(CAReduce):
* for complex dtypes, we use at least complex128.
"""
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype")
def __init__(self, scalar_op, axis=None, dtype=None, acc_dtype=None):
CAReduce.__init__(self, scalar_op, axis=axis)
self.dtype = dtype
self.acc_dtype = acc_dtype
def __eq__(self, other):
return (CAReduce.__eq__(self, other) and
self.dtype == other.dtype and
self.acc_dtype == other.acc_dtype)
def __hash__(self):
return CAReduce.__hash__(self) ^ hash((self.dtype, self.acc_dtype))
def __setstate__(self, d):
super(CAReduceDtype, self).__setstate__(d)
if not hasattr(self, "dtype"):
......@@ -1966,10 +1952,24 @@ class Sum(CAReduceDtype):
"""
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, scalar.add, axis=axis,
dtype=dtype, acc_dtype=acc_dtype)
def __str__(self):
name = self.__class__.__name__
axis = ""
if self.axis is not None:
axis = ", ".join(str(x) for x in self.axis)
axis = "axis=[%s], " % axis
return "%s{%sacc_dtype=%s}" % (
name,
axis,
str(self.acc_dtype)
)
def grad(self, inp, grads):
x, = inp
......@@ -2014,6 +2014,7 @@ class Prod(CAReduceDtype):
input.
"""
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None,
no_zeros_in_input=False):
......@@ -2027,14 +2028,6 @@ class Prod(CAReduceDtype):
if 'no_zeros_in_input' not in dct:
self.no_zeros_in_input = False
def __eq__(self, other):
return (CAReduceDtype.__eq__(self, other) and
self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self):
return (CAReduceDtype.__hash__(self) ^
hash(self.no_zeros_in_input))
def grad(self, inp, grads):
"""
The grad of this Op could be very easy, if it is was not for the case
......@@ -2196,6 +2189,9 @@ mul_without_zeros = MulWithoutZeros(scalar.upcast_out,
class ProdWithoutZeros(CAReduceDtype):
__props__ = ("axis", "dtype", "acc_dtype")
def __init__(self, axis=None, dtype=None, acc_dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis,
dtype=dtype, acc_dtype=acc_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论