提交 85f5b1ae authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

...@@ -618,8 +618,8 @@ class CLinker(link.Linker): ...@@ -618,8 +618,8 @@ class CLinker(link.Linker):
input_storage, input_storage,
output_storage) output_storage)
return thunk, \ return thunk, \
[link.Filter(input.type, storage) for input, storage in zip(self.env.inputs, input_storage)], \ [link.Filter(input, storage) for input, storage in zip(self.env.inputs, input_storage)], \
[link.Filter(output.type, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \ [link.Filter(output, storage, True) for output, storage in zip(self.env.outputs, output_storage)], \
error_storage error_storage
def make_thunk(self, input_storage = None, output_storage = None): def make_thunk(self, input_storage = None, output_storage = None):
...@@ -821,8 +821,8 @@ class OpWiseCLinker(link.LocalLinker): ...@@ -821,8 +821,8 @@ class OpWiseCLinker(link.LocalLinker):
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [link.Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [link.Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[link.Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [link.Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order thunks, order
......
...@@ -24,6 +24,7 @@ class Apply(utils.object2): ...@@ -24,6 +24,7 @@ class Apply(utils.object2):
""" """
self.op = op self.op = op
self.inputs = [] self.inputs = []
self.tag = utils.scratchpad()
## filter inputs to make sure each element is a Result ## filter inputs to make sure each element is a Result
for input in inputs: for input in inputs:
...@@ -67,7 +68,14 @@ class Apply(utils.object2): ...@@ -67,7 +68,14 @@ class Apply(utils.object2):
def __asapply__(self): def __asapply__(self):
return self return self
def clone(self): def clone(self):
return self.__class__(self.op, self.inputs, [output.clone() for output in self.outputs]) # cp = copy(self)
# cp.outputs = [output.clone() for output in self.outputs]
# for output in cp.outputs:
# output.owner = cp
# return cp
cp = self.__class__(self.op, self.inputs, [output.clone() for output in self.outputs])
cp.tag = copy(self.tag)
return cp
def clone_with_new_inputs(self, inputs, check_type = True): def clone_with_new_inputs(self, inputs, check_type = True):
""" """
Returns an Apply node with the same op but different inputs. Unless Returns an Apply node with the same op but different inputs. Unless
...@@ -96,6 +104,7 @@ class Result(utils.object2): ...@@ -96,6 +104,7 @@ class Result(utils.object2):
""" """
#__slots__ = ['type', 'owner', 'index', 'name'] #__slots__ = ['type', 'owner', 'index', 'name']
def __init__(self, type, owner = None, index = None, name = None): def __init__(self, type, owner = None, index = None, name = None):
self.tag = utils.scratchpad()
self.type = type self.type = type
if owner is not None and not isinstance(owner, Apply): if owner is not None and not isinstance(owner, Apply):
raise TypeError("owner must be an Apply instance", owner) raise TypeError("owner must be an Apply instance", owner)
...@@ -120,7 +129,10 @@ class Result(utils.object2): ...@@ -120,7 +129,10 @@ class Result(utils.object2):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def clone(self): def clone(self):
return self.__class__(self.type, None, None, self.name) #return copy(self)
cp = self.__class__(self.type, None, None, self.name)
cp.tag = copy(self.tag)
return cp
class Value(Result): class Value(Result):
""" """
......
...@@ -35,7 +35,7 @@ def raise_with_op(op, exc_info = None): ...@@ -35,7 +35,7 @@ def raise_with_op(op, exc_info = None):
exc_info = sys.exc_info() exc_info = sys.exc_info()
exc_type, exc_value, exc_trace = exc_info exc_type, exc_value, exc_trace = exc_info
try: try:
trace = op.trace trace = op.tag.trace
except AttributeError: except AttributeError:
trace = () trace = ()
exc_value.__thunk_trace__ = trace exc_value.__thunk_trace__ = trace
...@@ -107,20 +107,24 @@ class Linker: ...@@ -107,20 +107,24 @@ class Linker:
class Filter(object): class Filter(object):
def __init__(self, type, storage, readonly = False, strict = False): def __init__(self, r, storage, readonly = False, strict = False, trace = ()):
self.type = type self.r = r
self.type = r.type
self.storage = storage self.storage = storage
self.readonly = readonly self.readonly = readonly
self.strict = strict self.strict = strict
def __get(self): def __get(self):
return self.storage[0] return self.storage[0]
def __set(self, value): def __set(self, value):
try:
if self.readonly: if self.readonly:
raise Exception("Cannot set readonly storage.") raise Exception("Cannot set readonly storage.")
if self.strict: if self.strict:
self.storage[0] = self.type.filter(value, strict = True) self.storage[0] = self.type.filter(value, strict = True)
else: else:
self.storage[0] = self.type.filter(value) self.storage[0] = self.type.filter(value)
except:
raise_with_op(self.r)
data = property(__get, __set) data = property(__get, __set)
def __str__(self): def __str__(self):
return "<" + str(self.storage[0]) + ">" return "<" + str(self.storage[0]) + ">"
...@@ -245,8 +249,8 @@ class PerformLinker(LocalLinker): ...@@ -245,8 +249,8 @@ class PerformLinker(LocalLinker):
f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler) f = self.streamline(env, thunks, order, no_recycling = no_recycling, profiler = profiler)
return f, [Filter(input.type, storage) for input, storage in zip(env.inputs, input_storage)], \ return f, [Filter(input, storage) for input, storage in zip(env.inputs, input_storage)], \
[Filter(output.type, storage, True) for output, storage in zip(env.outputs, output_storage)], \ [Filter(output, storage, True) for output, storage in zip(env.outputs, output_storage)], \
thunks, order thunks, order
# return f, env.inputs, env.outputs # return f, env.inputs, env.outputs
......
...@@ -4,6 +4,7 @@ compatible with gof's graph manipulation routines. ...@@ -4,6 +4,7 @@ compatible with gof's graph manipulation routines.
""" """
import utils import utils
import traceback
class Op(utils.object2): class Op(utils.object2):
...@@ -31,6 +32,7 @@ class Op(utils.object2): ...@@ -31,6 +32,7 @@ class Op(utils.object2):
self.make_node(*inputs).outputs (if more than one output) self.make_node(*inputs).outputs (if more than one output)
""" """
node = self.make_node(*inputs) node = self.make_node(*inputs)
node.tag.trace = traceback.extract_stack()[:-1]
if self.default_output is not None: if self.default_output is not None:
return node.outputs[self.default_output] return node.outputs[self.default_output]
else: else:
......
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
import utils import utils
from utils import AbstractFunctionError, object2 from utils import AbstractFunctionError, object2
from graph import Result from graph import Result
import traceback
######## ########
...@@ -23,10 +24,13 @@ class Type(object2): ...@@ -23,10 +24,13 @@ class Type(object2):
raise AbstractFunctionError() raise AbstractFunctionError()
def make_result(self, name = None): def make_result(self, name = None):
return Result(self, name = name) r = Result(self, name = name)
return r
def __call__(self, name = None): def __call__(self, name = None):
return self.make_result(name) r = self.make_result(name)
r.tag.trace = traceback.extract_stack()[:-1]
return r
def c_is_simple(self): def c_is_simple(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论