提交 84f0c415 authored 作者: James Bergstra's avatar James Bergstra

adding new destroyhandler mechanism: destroyhandler_tolerate_aliased

上级 30bafbcf
...@@ -470,25 +470,44 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper): ...@@ -470,25 +470,44 @@ class DestroyHandlerHelper2(toolbox.Bookkeeper):
# add_inplace(x, x.T). In some special cases though, the in-place op will # add_inplace(x, x.T). In some special cases though, the in-place op will
# actually be able to work properly with multiple destroyed inputs (e.g, # actually be able to work properly with multiple destroyed inputs (e.g,
# add_inplace(x, x). An Op that can still work in this case should declare # add_inplace(x, x). An Op that can still work in this case should declare
# so via the 'tolerate_same' attribute # so via the 'destroyhandler_tolerate_same' attribute or
# 'destroyhandler_tolerate_aliased' attribute.
# #
# tolerate_same should be a list of pairs of the form # destroyhandler_tolerate_same should be a list of pairs of the form
# [(idx0, idx1), (idx0, idx2), ...] # [(idx0, idx1), (idx0, idx2), ...]
# The first element of each pair is the index of a destroyed # The first element of each pair is the input index of a destroyed
# variable. # variable.
# The second element of each pair is the index of a different input where # The second element of each pair is the index of a different input where
# we will permit exactly the same variable to appear. # we will permit exactly the same variable to appear.
# For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed # For example, add_inplace.tolerate_same might be [(0,1)] if the destroyed
# input is also allowed to appear as the second argument. # input is also allowed to appear as the second argument.
#
# destroyhandler_tolerate_alias is the same sort of list of
# pairs.
# op.destroyhandler_tolerate_alias = [(idx0, idx1)] tells the
# destroyhandler to IGNORE an aliasing between a destroyed
# input idx0 and another input idx1.
# This is generally a bad idea, but it is safe in some
# cases, such as
# - the op reads from the aliased idx1 before modifying idx0
# - the idx0 and idx1 are guaranteed not to overlap (e.g.
# they are pointed at different rows of a matrix).
#
#CHECK FOR INPUT ALIASING #CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'tolerate_same', []) tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = set(idx1 for idx0, idx1 in tolerate_same tolerated = set(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = set(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated #print 'tolerated', tolerated
#print 'ignored', ignored
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
if i in ignored:
continue
if input in root_impact \ if input in root_impact \
and (i not in tolerated or input is not destroyed_variable): and (i not in tolerated or input is not destroyed_variable):
raise InconsistencyError("Input aliasing: %s (%i, %i)" raise InconsistencyError("Input aliasing: %s (%i, %i)"
......
...@@ -40,13 +40,14 @@ def MyValue(data): ...@@ -40,13 +40,14 @@ def MyValue(data):
class MyOp(Op): class MyOp(Op):
def __init__(self, nin, name, vmap = {}, dmap = {}, nout = 1, tolerate_same = []): def __init__(self, nin, name, vmap = {}, dmap = {}, nout = 1,
destroyhandler_tolerate_same = []):
self.nin = nin self.nin = nin
self.nout = nout self.nout = nout
self.name = name self.name = name
self.destroy_map = dmap self.destroy_map = dmap
self.view_map = vmap self.view_map = vmap
self.tolerate_same = tolerate_same self.destroyhandler_tolerate_same = destroyhandler_tolerate_same
def make_node(self, *inputs): def make_node(self, *inputs):
assert len(inputs) == self.nin assert len(inputs) == self.nin
...@@ -65,7 +66,7 @@ sigmoid = MyOp(1, 'Sigmoid') ...@@ -65,7 +66,7 @@ sigmoid = MyOp(1, 'Sigmoid')
transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]}) transpose_view = MyOp(1, 'TransposeView', vmap = {0: [0]})
add = MyOp(2, 'Add') add = MyOp(2, 'Add')
add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]}) add_in_place = MyOp(2, 'AddInPlace', dmap = {0: [0]})
add_in_place_2 = MyOp(2, 'AddInPlace', dmap = {0: [0]}, tolerate_same = [(0, 1)]) add_in_place_2 = MyOp(2, 'AddInPlace', dmap = {0: [0]}, destroyhandler_tolerate_same = [(0, 1)])
dot = MyOp(2, 'Dot') dot = MyOp(2, 'Dot')
......
...@@ -3113,7 +3113,8 @@ class SubtensorPrinter: ...@@ -3113,7 +3113,8 @@ class SubtensorPrinter:
pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter()) pprint.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, Subtensor), SubtensorPrinter())
def set_subtensor(x, y, inplace=False): def set_subtensor(x, y, inplace=False,
tolerate_inplace_aliasing=False):
"""Return x with the given subtensor overwritten by y. """Return x with the given subtensor overwritten by y.
Example: To replicate the numpy expression "r[10:] = 5", type Example: To replicate the numpy expression "r[10:] = 5", type
...@@ -3122,14 +3123,21 @@ def set_subtensor(x, y, inplace=False): ...@@ -3122,14 +3123,21 @@ def set_subtensor(x, y, inplace=False):
:param x: symbolic variable for the lvalue of = operation :param x: symbolic variable for the lvalue of = operation
:param y: symbolic variable for the rvalue of = operation :param y: symbolic variable for the rvalue of = operation
:param tolerate_inplace_aliasing: see inc_subtensor for documentation.
""" """
return inc_subtensor(x, y, inplace, set_instead_of_inc=True) return inc_subtensor(x, y, inplace, set_instead_of_inc=True)
def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False): def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
tolerate_inplace_aliasing=False):
"""Return x with the given subtensor incremented by y. """Return x with the given subtensor incremented by y.
:param x: the symbolic result of a Subtensor operation. :param x: the symbolic result of a Subtensor operation.
:param y: the amount by which to increment ths subtensor in question :param y: the amount by which to increment ths subtensor in question
:param tolerate_inplace_aliasing: allow x and y to be views of a single
underlying array even while working inplace. For correct results,
x and y must not be overlapping views; if they overlap, the result
of this Op will generally be incorrect. This value has no effect if
inplace=False.
Example: To replicate the numpy expression "r[10:] += 5", type Example: To replicate the numpy expression "r[10:] += 5", type
...@@ -3138,7 +3146,12 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False): ...@@ -3138,7 +3146,12 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False):
# retrieve idx_list from x.owner # retrieve idx_list from x.owner
if not isinstance(x.owner.op, Subtensor): if not isinstance(x.owner.op, Subtensor):
raise TypeError('x must be result of a subtensor operation') raise TypeError('x must be result of a subtensor operation')
the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc) if tolerate_inplace_aliasing:
destroyhandler_tolerate_aliased = [[0,1]]
else:
destroyhandler_tolerate_aliased = []
the_op = IncSubtensor(x.owner.op.idx_list, inplace, set_instead_of_inc,
destroyhandler_tolerate_aliased=destroyhandler_tolerate_aliased)
real_x = x.owner.inputs[0] real_x = x.owner.inputs[0]
real_idxargs = x.owner.inputs[1:] real_idxargs = x.owner.inputs[1:]
return the_op(real_x, y, *real_idxargs) return the_op(real_x, y, *real_idxargs)
...@@ -3157,11 +3170,13 @@ class IncSubtensor(Op): ...@@ -3157,11 +3170,13 @@ class IncSubtensor(Op):
of incrementing it by that value. of incrementing it by that value.
""" """
def __init__(self, idx_list, inplace=False, set_instead_of_inc=False): def __init__(self, idx_list, inplace=False, set_instead_of_inc=False,
destroyhandler_tolerate_aliased=[]):
self.idx_list = map(Subtensor.convert, idx_list) self.idx_list = map(Subtensor.convert, idx_list)
self.inplace = inplace self.inplace = inplace
if inplace: if inplace:
self.destroy_map = {0: [0]} self.destroy_map = {0: [0]}
self.destroyhandler_tolerate_aliased = list(destroyhandler_tolerate_aliased)
self.set_instead_of_inc = set_instead_of_inc self.set_instead_of_inc = set_instead_of_inc
def __eq__(self, other): def __eq__(self, other):
...@@ -3211,8 +3226,10 @@ class IncSubtensor(Op): ...@@ -3211,8 +3226,10 @@ class IncSubtensor(Op):
idx_list = list(self.idx_list) idx_list = list(self.idx_list)
if len(idx_list) > x.type.ndim: if len(idx_list) > x.type.ndim:
exception = ValueError(Subtensor.e_invalid%(len(idx_list), exception = ValueError(
x.type.ndim)) Subtensor.e_invalid%(
len(idx_list),
x.type.ndim))
exception.subtensor_invalid = True exception.subtensor_invalid = True
raise exception raise exception
......
...@@ -1518,7 +1518,9 @@ def local_inplace_setsubtensor(node): ...@@ -1518,7 +1518,9 @@ def local_inplace_setsubtensor(node):
""" """
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace: if isinstance(node.op, T.IncSubtensor) and not node.op.inplace:
new_op = node.op.__class__(node.op.idx_list, inplace=True, \ new_op = node.op.__class__(node.op.idx_list, inplace=True, \
set_instead_of_inc=node.op.set_instead_of_inc) set_instead_of_inc=node.op.set_instead_of_inc,
destroyhandler_tolerate_aliased=\
node.op.destroyhandler_tolerate_aliased)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论