提交 08c2b2e4 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Optimization for subtensor of alloc.

上级 af68ea42
...@@ -1498,6 +1498,37 @@ def local_subtensor_merge(node): ...@@ -1498,6 +1498,37 @@ def local_subtensor_merge(node):
return [ out ] return [ out ]
@register_canonicalize
@register_specialize
@gof.local_optimizer([])
def local_subtensor_of_alloc(node):
"""alloc[x:y] -> alloc"""
if not isinstance(node.op, T.Subtensor):
return False
u = node.inputs[0]
if u.owner is None:
return False
if not isinstance(u.owner.op, T.Alloc):
return False
slices = T.get_idx_list(node.inputs, node.op.idx_list)
dims = u.owner.inputs[1:]
assert len(slices) <= len(dims)
nw_dims = []
for sl, dim in zip(slices, dims):
csl,_ = T.get_canonical_form_slice(sl, dim)
if type(csl) is not slice:
nw_dims+=[T.constant(1)]
else:
nw_dims+= [(csl.stop - csl.start)//csl.step]
nw_dims += dims[len(slices):]
rval = T.alloc(u.owner.inputs[0], *nw_dims)
if type(rval) not in (list, tuple):
rval = [rval]
return rval
@register_canonicalize @register_canonicalize
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_IncSubtensor_serialize(node): def local_IncSubtensor_serialize(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论