提交 e858937f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Other 'shape_lift' optimizer, for advanced indexing by [arange(len(y)), y]

上级 06887a98
...@@ -932,6 +932,22 @@ def _check_rows_is_arange_len_labels(rows, labels): ...@@ -932,6 +932,22 @@ def _check_rows_is_arange_len_labels(rows, labels):
if shape_var.owner and shape_var.owner.op == tensor._shape: if shape_var.owner and shape_var.owner.op == tensor._shape:
return shape_var.owner.inputs[0] is labels return shape_var.owner.inputs[0] is labels
@gof.local_optimizer([tensor._shape])
def local_shape_lift_advanced_indexing_arange(node):
'''shape(a[arange(len(y)), y]) -> shape(y) (conditions apply)'''
if node.op == tensor._shape:
if node.inputs[0].owner and \
isinstance(node.inputs[0].owner.op, tensor.AdvancedSubtensor):
try:
a, rows, labels = node.inputs[0].owner.inputs
except:
return
if _check_rows_is_arange_len_labels(rows, labels):
if labels.ndim == 1 and a.ndim == 2:
return tensor._shape(labels),
opt.register_specialize(local_shape_lift_advanced_indexing_arange, 'shape_lift')
@opt.register_specialize @opt.register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_advanced_indexing_crossentropy_onehot(node): def local_advanced_indexing_crossentropy_onehot(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论