提交 cb084f7f authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new scan_permutation op

This op takes care of the special case when the entries of scan are not in the righr order. Such a thing can happen once the save mem optimization is implemented, and the memory buffer for an output is much smaller then the number of steps. Then scan will just roll around that memory buffer, and entry 0 might end up somewhere in the middle
上级 971da4ee
......@@ -340,3 +340,53 @@ def allocate_memory(T, y_info, y):
ins_shapes.append(in_shape)
shape = infer_shape([y], inputs, ins_shapes)[0]
return tensor.zeros([T, ] + shape, dtype=y.dtype)
class ScanPermutation(gof.Op):
def __init__(self, inplace=False):
self.inplace = inplace
if inplace:
self.destroy_map = {0: [0]}
def __eq__(self, other):
return type(self) == type(other) and self.inplace == other.inplace
def __hash__(self):
return hash(type(self)) ^ hash(self.inplace)
def __str__(self):
if self.inplace:
return "scan_permutation{inplace}"
else:
return "scan_permutation"
def make_node(self, membuffer, index):
# index has to be a scalar
assert index.ndim = 0
# we neeed at least one dimension
assert membuffer.ndim > 0
return gof.Apply(self, [membuffer, index], [membuffer.type()])
def perform(self, node, inputs, outputs):
membuffer = inputs[0]
index = inputs[0]
if index <= membuffer.shape[0] or index % membuffer.shape[0] == 0:
if self.inplace:
outputs[0] = membuffer
else:
outputs[0][:] = membuffer
else:
pos = index % membuffer.shape[0]
if outputs[0] is membuffer:
membuffer = membuffer.copy()
outputs[0][:membuffer.shape[0] - p] = membuffer[p:]
outputs[0][membuffer.shape[0] - p:] = membuffer[:p]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None]
return self.make_node(eval_points[0], inputs[1]).outputs
def grad(self, inputs, grads):
pos = inputs[0].shape[0] - (inputs[1] % inputs[0].shape[0])
return self.make_node(grads[0], pos).outputs
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论