提交 4670fdb0 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

__props__ to theano/compile/ops.py, one doubt with parameter of type *axis

上级 0a1e50e7
...@@ -38,16 +38,11 @@ class ViewOp(gof.Op): ...@@ -38,16 +38,11 @@ class ViewOp(gof.Op):
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
__props__ = ()
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def perform(self, node, inp, out): def perform(self, node, inp, out):
x, = inp x, = inp
z, = out z, = out
...@@ -138,6 +133,7 @@ class DeepCopyOp(gof.Op): ...@@ -138,6 +133,7 @@ class DeepCopyOp(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ()
def __init__(self): def __init__(self):
pass pass
...@@ -145,12 +141,6 @@ class DeepCopyOp(gof.Op): ...@@ -145,12 +141,6 @@ class DeepCopyOp(gof.Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def make_node(self, x): def make_node(self, x):
return gof.Apply(self, [x], [x.type()]) return gof.Apply(self, [x], [x.type()])
...@@ -228,12 +218,7 @@ class Shape(gof.Op): ...@@ -228,12 +218,7 @@ class Shape(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ()
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -480,6 +465,8 @@ class FromFunctionOp(gof.Op): ...@@ -480,6 +465,8 @@ class FromFunctionOp(gof.Op):
raise an error if you attempt to get the gradient of a graph raise an error if you attempt to get the gradient of a graph
containing this op. containing this op.
""" """
__props__ = ("fn", "itypes", "otypes", "infer_shape")
def __init__(self, fn, itypes, otypes, infer_shape): def __init__(self, fn, itypes, otypes, infer_shape):
self.__fn = fn self.__fn = fn
self.itypes = itypes self.itypes = itypes
...@@ -488,13 +475,6 @@ class FromFunctionOp(gof.Op): ...@@ -488,13 +475,6 @@ class FromFunctionOp(gof.Op):
if self.__infer_shape is not None: if self.__infer_shape is not None:
self.infer_shape = self._infer_shape self.infer_shape = self._infer_shape
def __eq__(self, other):
return (type(self) == type(other) and
self.__fn == other.__fn)
def __hash__(self):
return hash(type(self)) ^ hash(self.__fn)
def __str__(self): def __str__(self):
return 'FromFunctionOp{%s}' % self.__fn.__name__ return 'FromFunctionOp{%s}' % self.__fn.__name__
...@@ -623,6 +603,7 @@ class Rebroadcast(gof.Op): ...@@ -623,6 +603,7 @@ class Rebroadcast(gof.Op):
c_code_and_version = {} c_code_and_version = {}
check_input = False check_input = False
__props__ = ("axis")
def __init__(self, *axis): def __init__(self, *axis):
self.axis = dict(axis) self.axis = dict(axis)
...@@ -630,14 +611,6 @@ class Rebroadcast(gof.Op): ...@@ -630,14 +611,6 @@ class Rebroadcast(gof.Op):
assert isinstance(axis, (numpy.integer, int)), ( assert isinstance(axis, (numpy.integer, int)), (
"Rebroadcast needs integer axes. Got ", axis) "Rebroadcast needs integer axes. Got ", axis)
def __eq__(self, other):
return type(self) == type(other) and self.axis == other.axis
def __hash__(self):
# no ambiguity because each item key is unique
items = sorted(iteritems(self.axis))
return hash((type(self), tuple(items)))
def __str__(self): def __str__(self):
if len(self.axis) == 0: if len(self.axis) == 0:
broadcast_pattern = [] broadcast_pattern = []
...@@ -768,12 +741,7 @@ class SpecifyShape(gof.Op): ...@@ -768,12 +741,7 @@ class SpecifyShape(gof.Op):
# In the C code, the name of the input variable is %(iname)s, # In the C code, the name of the input variable is %(iname)s,
# the output variable is %(oname)s. # the output variable is %(oname)s.
c_code_and_version = {} c_code_and_version = {}
__props__ = ()
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论