提交 06f7bd5d authored 作者: gdesjardins's avatar gdesjardins

* top-level Variable class raises NotImplementedError when trying using comparison operators

* changed inheritance order of TensorVariable to include _py_operators before Variable, so that TensorVariable can use the comparison operators defined in _py_operators
上级 e0df6d68
...@@ -299,6 +299,15 @@ class Variable(utils.object2): ...@@ -299,6 +299,15 @@ class Variable(utils.object2):
cp = self.__class__(self.type, None, None, self.name) cp = self.__class__(self.type, None, None, self.name)
cp.tag = copy(self.tag) cp.tag = copy(self.tag)
return cp return cp
def __lt__(self,other):
raise NotImplementedError('Subclasses of Variable must implement __lt__')
def __le__(self,other):
raise NotImplementedError('Subclasses of Variable must implement __le__')
def __gt__(self,other):
raise NotImplementedError('Subclasses of Variable must implement __gt__')
def __ge__(self,other):
raise NotImplementedError('Subclasses of Variable must implement __ge__')
class Value(Variable): class Value(Variable):
""" """
......
...@@ -1130,7 +1130,7 @@ class _tensor_py_operators: ...@@ -1130,7 +1130,7 @@ class _tensor_py_operators:
def get_constant_value(self): def get_constant_value(self):
return get_constant_value(self) return get_constant_value(self)
class TensorVariable(Variable, _tensor_py_operators): class TensorVariable(_tensor_py_operators, Variable):
"""Subclass to add the tensor operators to the basic `Variable` class.""" """Subclass to add the tensor operators to the basic `Variable` class."""
TensorType.Variable = TensorVariable TensorType.Variable = TensorVariable
...@@ -1162,7 +1162,7 @@ class TensorConstantSignature(tuple): ...@@ -1162,7 +1162,7 @@ class TensorConstantSignature(tuple):
sum = property(_get_sum) sum = property(_get_sum)
class TensorConstant(Constant, _tensor_py_operators): class TensorConstant(_tensor_py_operators, Constant):
"""Subclass to add the tensor operators to the basic `Constant` class. """Subclass to add the tensor operators to the basic `Constant` class.
To create a TensorConstant, use the `constant` function in this module. To create a TensorConstant, use the `constant` function in this module.
...@@ -1171,7 +1171,7 @@ class TensorConstant(Constant, _tensor_py_operators): ...@@ -1171,7 +1171,7 @@ class TensorConstant(Constant, _tensor_py_operators):
return TensorConstantSignature((self.type, self.data)) return TensorConstantSignature((self.type, self.data))
TensorType.Constant = TensorConstant TensorType.Constant = TensorConstant
class TensorValue(Value, _tensor_py_operators): class TensorValue(_tensor_py_operators, Value):
"""Subclass to add the tensor operators to the basic `Value` class. """Subclass to add the tensor operators to the basic `Value` class.
To create a TensorValue, use the `value` function in this module. To create a TensorValue, use the `value` function in this module.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论