提交 48866b3f authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #4968 from abergeron/object2_props

Add a metaclass implementation of __props__ and generalize it to all theano objects
...@@ -500,9 +500,6 @@ and ``b`` are equal. ...@@ -500,9 +500,6 @@ and ``b`` are equal.
super(AXPBOp, self).__init__() super(AXPBOp, self).__init__()
def make_node(self, x): def make_node(self, x):
# check that the theano version has support for __props__.
assert hasattr(self, '_props'), ("Your version of theano is too"
"old to support __props__.")
x = theano.tensor.as_tensor_variable(x) x = theano.tensor.as_tensor_variable(x)
return theano.Apply(self, [x], [x.type()]) return theano.Apply(self, [x], [x.type()])
......
...@@ -790,48 +790,6 @@ class Op(utils.object2, PureOp, CLinkerOp): ...@@ -790,48 +790,6 @@ class Op(utils.object2, PureOp, CLinkerOp):
def __init__(self, use_c_code=theano.config.cxx): def __init__(self, use_c_code=theano.config.cxx):
self._op_use_c_code = use_c_code self._op_use_c_code = use_c_code
def _props(self):
"""
Tuple of properties of all attributes
"""
return tuple(getattr(self, a) for a in self.__props__)
def _props_dict(self):
"""This return a dict of all ``__props__`` key-> value.
This is useful in optimization to swap op that should have the
same props. This help detect error that the new op have at
least all the original props.
"""
return dict([(a, getattr(self, a))
for a in self.__props__])
def __hash__(self):
if hasattr(self, '__props__'):
return hash((type(self), self._props()))
else:
return super(Op, self).__hash__()
def __str__(self):
if hasattr(self, '__props__'):
if len(self.__props__) == 0:
return "%s" % (self.__class__.__name__,)
else:
return "%s{%s}" % (
self.__class__.__name__,
", ".join("%s=%r" % (p, getattr(self, p))
for p in self.__props__))
else:
return super(Op, self).__str__()
def __eq__(self, other):
if hasattr(self, '__props__'):
return (type(self) == type(other) and self._props() ==
other._props())
else:
return NotImplemented
def prepare_node(self, node, storage_map, compute_map): def prepare_node(self, node, storage_map, compute_map):
""" """
Make any special modifications that the Op needs before doing Make any special modifications that the Op needs before doing
......
...@@ -4,7 +4,7 @@ import sys ...@@ -4,7 +4,7 @@ import sys
import traceback import traceback
import numpy import numpy
from six import iteritems, integer_types, string_types from six import iteritems, integer_types, string_types, with_metaclass
from six.moves import StringIO from six.moves import StringIO
from theano import config from theano import config
...@@ -152,22 +152,73 @@ class MethodNotDefined(Exception): ...@@ -152,22 +152,73 @@ class MethodNotDefined(Exception):
function has been left out of an implementation class. function has been left out of an implementation class.
""" """
pass
class MetaObject(type):
def __new__(cls, name, bases, dct):
props = dct.get('__props__', None)
if props is not None:
if not isinstance(props, tuple):
raise TypeError("__props__ has to be a tuple")
if not all(isinstance(p, str) for p in props):
raise TypeError("elements of __props__ have to be strings")
def _props(self):
"""
Tuple of properties of all attributes
"""
return tuple(getattr(self, a) for a in props)
dct['_props'] = _props
def _props_dict(self):
"""This return a dict of all ``__props__`` key-> value.
This is useful in optimization to swap op that should have the
same props. This help detect error that the new op have at
least all the original props.
"""
return dict([(a, getattr(self, a))
for a in props])
dct['_props_dict'] = _props_dict
if '__hash__' not in dct:
def __hash__(self):
return hash((type(self),
tuple(getattr(self, a) for a in props)))
dct['__hash__'] = __hash__
if '__eq__' not in dct:
def __eq__(self, other):
return (type(self) == type(other) and
tuple(getattr(self, a) for a in props) ==
tuple(getattr(other, a) for a in props))
dct['__eq__'] = __eq__
if '__str__' not in dct:
if len(props) == 0:
def __str__(self):
return "%s" % (self.__class__.__name__,)
else:
def __str__(self):
return "%s{%s}" % (
self.__class__.__name__,
", ".join("%s=%r" % (p, getattr(self, p))
for p in props))
dct['__str__'] = __str__
return type.__new__(cls, name, bases, dct)
class object2(object): class object2(with_metaclass(MetaObject, object)):
__slots__ = [] __slots__ = []
if 0:
def __hash__(self):
# this fixes silent-error-prone new-style class behavior
if hasattr(self, '__eq__') or hasattr(self, '__cmp__'):
raise TypeError("unhashable object: %s" % self)
return id(self)
def __ne__(self, other): def __ne__(self, other):
return not self == other return not self == other
class scratchpad: class scratchpad(object):
def clear(self): def clear(self):
self.__dict__.clear() self.__dict__.clear()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论