提交 594a136d 作者: Pascal Lamblin

Fix optimization local_subtensor_of_alloc, and failing test.

上级 10abe09a
...@@ -1507,18 +1507,33 @@ def local_subtensor_of_alloc(node): ...@@ -1507,18 +1507,33 @@ def local_subtensor_of_alloc(node):
if not isinstance(u.owner.op, T.Alloc): if not isinstance(u.owner.op, T.Alloc):
return False return False
slices = T.get_idx_list(node.inputs, node.op.idx_list) slices = T.get_idx_list(node.inputs, node.op.idx_list)
val = u.owner.inputs[0]
dims = u.owner.inputs[1:] dims = u.owner.inputs[1:]
assert len(slices) <= len(dims) assert len(slices) <= len(dims)
# Number of dimensions added to val
n_added_dims = u.ndim - val.ndim
# Dimensions of the returned alloc
nw_dims = [] nw_dims = []
for sl, dim in zip(slices, dims): # Slices to take from val
val_slices = []
for i, (sl, dim) in enumerate(zip(slices, dims)):
# If val was not copied over that dim,
# we need to take the appropriate subtensor on it.
if i >= n_added_dims:
val_slices.append(sl)
csl,_ = T.get_canonical_form_slice(sl, dim) csl,_ = T.get_canonical_form_slice(sl, dim)
if type(csl) is not slice: if type(csl) is not slice:
nw_dims+=[T.constant(1)] # That dimension is removed.
pass
else: else:
nw_dims+= [(csl.stop - csl.start)//csl.step] nw_dims+= [(csl.stop - csl.start)//csl.step]
nw_val = val[tuple(val_slices)]
nw_dims += dims[len(slices):] nw_dims += dims[len(slices):]
rval = T.alloc(u.owner.inputs[0], *nw_dims) rval = T.alloc(nw_val, *nw_dims)
if type(rval) not in (list, tuple): if type(rval) not in (list, tuple):
rval = [rval] rval = [rval]
return rval return rval
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论