提交 2da94091 authored 作者: Frederic Bastien's avatar Frederic Bastien

make the optimizer local_inplace_setsubtensor work for GpuIncSubtensor too.

上级 b8a75301
...@@ -942,8 +942,11 @@ def local_IncSubtensor_serialize(node): ...@@ -942,8 +942,11 @@ def local_IncSubtensor_serialize(node):
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_setsubtensor(node): def local_inplace_setsubtensor(node):
"""
Also work for GpuIncSubtensor
"""
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace: if isinstance(node.op, T.IncSubtensor) and not node.op.inplace:
new_op = T.IncSubtensor(node.op.idx_list, inplace=True, \ new_op = node.op.__class__(node.op.idx_list, inplace=True, \
set_instead_of_inc=node.op.set_instead_of_inc) set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论