提交 b747a0cb authored 作者: bergstrj@iro.umontreal.ca's avatar bergstrj@iro.umontreal.ca

merged

......@@ -210,12 +210,33 @@ class _test_single(unittest.TestCase):
core.pop_mode()
def test_3(self):
a = core.Numpy2(data=numpy.ones((2,2)))
b = core.Numpy2(data=numpy.ones((2,2)))
a = core.Numpy2(data=numpy.random.rand(2,2))
b = core.Numpy2(data=numpy.random.rand(2,2))
c = core.add(a,b)
self.failUnless(c.data is None)
self.failUnless(c.state is Empty)
new_a = numpy.random.rand(2,2)
new_b = numpy.random.rand(2,2)
a.data = new_a
b.data = new_b
p = single(c)
p()
self.failUnless(core._approx_eq(c, numpy.ones((2,2))*2))
self.failUnless(core._approx_eq(c, new_a + new_b))
def test_get_element(self):
a_data = numpy.random.rand(2,2)
a = core.Numpy2(data=a_data)
a_i = a[0,0]
p = single(a_i)
for i in 0,1:
for j in 0,1:
p()
self.failUnless(a_data[i,j] == a_i.data)
if __name__ == '__main__':
unittest.main()
......
......@@ -116,33 +116,33 @@ class Numpy2(ResultBase):
################################
# Numpy2 specific functionality
#
__array__ = property(lambda self: self._data.__array__ )
__array_struct__ = property(lambda self: self._data.__array_struct__ )
__array__ = property(lambda self: self.data.__array__ )
__array_struct__ = property(lambda self: self.data.__array_struct__ )
def data_alloc(self):
return numpy.ndarray(self.shape, self.dtype)
# self._dtype is used when self._data hasn't been set yet
# self._dtype is used when self.data hasn't been set yet
def __dtype_get(self):
if self._data is None:
if self.data is None:
return self._dtype
else:
return self._data.dtype
return self.data.dtype
def __dtype_set(self, dtype):
if self._data is None:
if self.data is None:
self._dtype = dtype
else:
raise StateError('cannot set dtype after data has been set')
dtype = property(__dtype_get, __dtype_set)
# self._shape is used when self._data hasn't been set yet
# self._shape is used when self.data hasn't been set yet
def __shape_get(self):
if self._data is None:
if self.data is None:
return self._shape
else:
return self._data.shape
return self.data.shape
def __shape_set(self, shape):
if self._data is None:
if self.data is None:
self._shape = shape
else:
raise StateError('cannot set shape after data has been set')
......@@ -1729,7 +1729,7 @@ class _testCase_slicing(unittest.TestCase):
wa1 = wrap(a)[0:8:2]
for i in xrange(8): a[i] = i
self.failUnless(wa1.data.shape == (4,))
self.failUnless(wa1.shape == (4,))
for i in xrange(4):
self.failUnless(a[i*2] == wa1.data[i])
def test_getslice_3d_float(self):
......@@ -1737,10 +1737,18 @@ class _testCase_slicing(unittest.TestCase):
a = numpy.asarray(range(4*5*6))
a.resize((4,5,6))
wa1 = wrap(a)[1:3]
wa1.data.shape
self.failUnless(wa1.shape == (2,5,6))
self.failUnless(numpy.all(a[1:3] == wa1.data))
a[1] *= -1.0
self.failUnless(numpy.all(a[1:3] == wa1.data))
def test_getslice_3d_one(self):
"""Test getslice on 3d array"""
a = numpy.asarray(range(4*5*6))
a.resize((4,5,6))
wa = wrap(a)
wa_123 = wa[1,2,3]
self.failUnless(wa_123.shape == (), wa_123.shape)
add = scalar_switch(add_elemwise, add_scalar, add_scalar)
add_inplace = scalar_switch(add_elemwise_inplace, add_scalar_inplace)
......
......@@ -53,7 +53,22 @@ class InconsistencyError(GofError):
"""
pass
def require_set(cls):
"""Return the set of objects named in a __env_require__ field in a base class"""
r = set()
if hasattr(cls, '__class__'):
cls = cls.__class__
bases = utils.all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
for base in bases:
req = base.__env_require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
class Env(graph.Graph):
"""
......@@ -179,7 +194,7 @@ class Env(graph.Graph):
return True
def satisfy(self, x):
for feature_class in x.require():
for feature_class in require_set(x):
self.add_feature(feature_class)
def add_feature(self, feature_class, do_import = True):
......
from copy import copy
from op import Op
from result import is_result, ResultBase
......@@ -57,27 +58,31 @@ def compute(*nodes):
"""Recursively evaluate each node (in a quick & dirty way)."""
compute_from(nodes, set())
class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env):
self.env = env
self.bad = set()
def root_inputs(self, input):
owner = input.owner
def root_inputs(input):
"""Return the leaves of a search through consecutive view_map()s"""
owner = input.owner
if owner:
view_map = owner.view_map()
if input in view_map:
answer = []
for input2 in view_map[input]:
answer += owner.root_inputs(input2)
answer.extend(root_inputs(input2))
return answer
else:
return [input]
else:
return [input]
class ForbidConstantOverwrite(features.Listener, features.Constraint):
def __init__(self, env):
self.env = env
self.bad = set()
def on_import(self, op):
for output, inputs in op.destroy_map().items():
for input in inputs:
for root_input in self.root_inputs(input):
for root_input in root_inputs(input):
if getattr(root_input, 'constant', False):
self.bad.add(op)
return
......@@ -185,9 +190,7 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
self.__detect_cycles_helper__(user, [])
def get_maps(self, op):
vmap = getattr(op, 'view_map',{})
dmap = getattr(op, 'destoy_map', {})
return vmap, dmap
return op.view_map(), op.destroy_map()
def on_import(self, op):
view_map, destroy_map = self.get_maps(op)
......@@ -347,6 +350,8 @@ class DestroyHandler(features.Listener, features.Constraint, features.Orderings)
class NewPythonOp(Op):
__env_require__ = DestroyHandler
def view_map(self):
return {}
......@@ -358,8 +363,6 @@ class PythonOp(NewPythonOp):
__metaclass__ = ClsInit
__require__ = DestroyHandler
nout = 1
@staticmethod
......@@ -389,45 +392,28 @@ class PythonOp(NewPythonOp):
NewPythonOp.__init__(self, inputs, self.gen_outputs())
def __validate__(self):
return all([ is_result(i) for i in self.inputs])
return all([is_result(i) for i in self.inputs])
def gen_outputs(self):
raise AbstractFunctionError()
@staticmethod
def root_inputs(input):
owner = input.owner
if owner:
view_map = owner.view_map()
if input in view_map:
answer = []
for input2 in view_map[input]:
answer += owner.root_inputs(input2)
return answer
else:
return [input]
else:
return [input]
def input_is_up_to_date(self, input):
answer = True
for input in self.root_inputs(input):
answer &= input.up_to_date
return answer
def input_is_constant(self, input):
answer = False
for input in self.root_inputs(input):
answer |= input.constant
return answer
def check_input(self, input):
def input_is_up_to_date(input):
answer = True
for input in root_inputs(input):
answer &= input.up_to_date
return answer
if input.data is None:
raise ValueError("Uncomputed input: %s in %s" % (input, self))
if not self.input_is_up_to_date(input):
if not input_is_up_to_date(input):
raise ValueError("Input is out of date: %s in %s" % (input, self))
def perform(self):
def input_is_constant(input):
answer = False
for input in root_inputs(input):
answer |= input.constant
return answer
exc = set()
for output, inputs in self.destroy_map().items():
exc.update(inputs)
......@@ -518,7 +504,7 @@ class PythonOp(NewPythonOp):
if output not in except_list:
output.alloc()
__require__ = ForbidConstantOverwrite
__env_require__ = ForbidConstantOverwrite
def __copy__(self):
"""
......@@ -580,7 +566,7 @@ class PythonOpt(opt.Optimizer):
class DummyOp(Op):
class DummyOp(NewPythonOp):
def __init__(self, input):
Op.__init__(self, [input], [ResultBase()])
......
......@@ -28,7 +28,6 @@ class Op(object):
"""
__slots__ = ['_inputs', '_outputs']
__require__ = []
#create inputs and outputs as read-only attributes
inputs = property(lambda self: self._inputs, doc = "The list of this Op's input Results.")
......@@ -229,11 +228,10 @@ class Op(object):
"""
r = set()
bases = all_bases(cls, lambda cls: hasattr(cls, '__require__'))
bases.append(cls)
bases = all_bases(cls, lambda cls: hasattr(cls, '__env_require__'))
for base in bases:
req = base.__require__
req = base.__env_require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
......
......@@ -9,8 +9,6 @@ import ext
class Optimizer:
__require__ = ()
def apply(self, env):
pass
......@@ -18,27 +16,6 @@ class Optimizer:
env.satisfy(self)
self.apply(env)
@classmethod
def require(cls):
"""
Returns a list of EnvFeature subclasses that must be used by
any Env manipulating this kind of op. For instance, a
Destroyer requires features.DestroyHandler to guarantee that
various destructive operations don't interfere.
"""
r = set()
bases = utils.all_bases(cls, lambda cls: hasattr(cls, '__require__'))
bases.append(cls)
for base in bases:
req = base.__require__
if isinstance(req, (list, tuple)):
r.update(req)
else:
r.add(req)
return r
def __call__(self, env):
self.optimize(env)
......@@ -88,7 +65,7 @@ class LocalOptimizer(Optimizer):
class OpSpecificOptimizer(LocalOptimizer):
__require__ = features.InstanceFinder
__env_require__ = features.InstanceFinder
opclass = Op
......@@ -100,7 +77,7 @@ class OpSpecificOptimizer(LocalOptimizer):
class OpSubOptimizer(Optimizer):
__require__ = features.InstanceFinder
__env_require__ = features.InstanceFinder
def __init__(self, op1, op2):
if not op1.has_default_output:
......@@ -127,7 +104,7 @@ class OpSubOptimizer(Optimizer):
class OpRemover(Optimizer):
__require__ = features.InstanceFinder
__env_require__ = features.InstanceFinder
def __init__(self, opclass):
self.opclass = opclass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论