Modified IncSubtensor related optimizations to take into account the set_instead_of_inc parameter.

上级 1d58f01f
...@@ -346,7 +346,7 @@ def local_IncSubtensor_serialize(node): ...@@ -346,7 +346,7 @@ def local_IncSubtensor_serialize(node):
# #
# add(x, incsubtensor(b, c), incsubtensor(b, d)) # add(x, incsubtensor(b, c), incsubtensor(b, d))
# -> incsubtensor(incsubtensor(add(x,b), c), d) # -> incsubtensor(incsubtensor(add(x,b,b), c), d)
""" """
def movable(i): def movable(i):
...@@ -354,7 +354,8 @@ def local_IncSubtensor_serialize(node): ...@@ -354,7 +354,8 @@ def local_IncSubtensor_serialize(node):
return i.owner \ return i.owner \
and isinstance(i.owner.op, T.IncSubtensor) \ and isinstance(i.owner.op, T.IncSubtensor) \
and i.type == o_type \ and i.type == o_type \
and len(i.clients) == 1 and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc
if node.op == T.add: if node.op == T.add:
o_type = node.outputs[0].type o_type = node.outputs[0].type
...@@ -383,7 +384,8 @@ def local_IncSubtensor_serialize(node): ...@@ -383,7 +384,8 @@ def local_IncSubtensor_serialize(node):
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_setsubtensor(node): def local_inplace_setsubtensor(node):
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace: if isinstance(node.op, T.IncSubtensor) and not node.op.inplace:
new_op = T.IncSubtensor(node.op.idx_list, inplace=True) new_op = T.IncSubtensor(node.op.idx_list, inplace=True, \
set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论