cleaned up gof.op, added constant, same_properties and mergeable to Scalar

上级 cae06547
import unittest import unittest
from gof import ResultBase from gof import ResultBase, Op, Env, modes
from gof import Op
from gof import Env
from gof import modes
from scalar_ops import * from scalar_ops import *
......
差异被折叠。
差异被折叠。
...@@ -3,9 +3,7 @@ import numpy ...@@ -3,9 +3,7 @@ import numpy
from copy import copy from copy import copy
from gof import ResultBase from gof import ResultBase, GuardedOp, utils
from gof import Op
from gof import utils
def as_scalar(x, name = None): def as_scalar(x, name = None):
...@@ -21,13 +19,32 @@ class Scalar(ResultBase): ...@@ -21,13 +19,32 @@ class Scalar(ResultBase):
def __init__(self, dtype, name=None): def __init__(self, dtype, name=None):
self.dtype = dtype self.dtype = dtype
self.constant = False
ResultBase.__init__(self, role = None, data = None, name = name) ResultBase.__init__(self, role = None, data = None, name = name)
def __get_constant(self):
return self._constant
def __set_constant(self, value):
if value:
self.indestructible = True
self.constant = value
constant = property(__get_constant, __set_constant)
def validate(self, data): def validate(self, data):
py_type = self.py_type() py_type = self.py_type()
if not isinstance(data, py_type): if not isinstance(data, py_type):
raise TypeError("Expected %s instance." % py_type) raise TypeError("Expected %s instance." % py_type)
def same_properties(self, other):
return other.dtype == self.dtype
def mergeable(self, other):
return getattr(self, 'constant', False) \
and getattr(other, 'constant', False) \
and self.data == other.data
def py_type(self): def py_type(self):
return {'float64': float}[self.dtype] return {'float64': float}[self.dtype]
...@@ -74,7 +91,7 @@ class Scalar(ResultBase): ...@@ -74,7 +91,7 @@ class Scalar(ResultBase):
class ScalarMixedOp(Op): class ScalarMixedOp(GuardedOp):
nin = -1 nin = -1
nout = 1 nout = 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论