Modified IncSubtensor related optimizations to take into account the set_instead_of_inc parameter.

上级 1d58f01f
......@@ -346,7 +346,7 @@ def local_IncSubtensor_serialize(node):
#
# add(x, incsubtensor(b, c), incsubtensor(b, d))
# -> incsubtensor(incsubtensor(add(x,b), c), d)
# -> incsubtensor(incsubtensor(add(x,b,b), c), d)
"""
def movable(i):
......@@ -354,7 +354,8 @@ def local_IncSubtensor_serialize(node):
return i.owner \
and isinstance(i.owner.op, T.IncSubtensor) \
and i.type == o_type \
and len(i.clients) == 1
and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc
if node.op == T.add:
o_type = node.outputs[0].type
......@@ -383,7 +384,8 @@ def local_IncSubtensor_serialize(node):
@gof.local_optimizer([None])
def local_inplace_setsubtensor(node):
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace:
new_op = T.IncSubtensor(node.op.idx_list, inplace=True)
new_op = T.IncSubtensor(node.op.idx_list, inplace=True, \
set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs)
return [new_node]
return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论