提交 aff13140 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add an optimization to convert set_subtensor() of an add into inc_subtensor().

This did come up in an implementation of adagrad at some point and will be much faster on GPU.
上级 1bafa2d4
...@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice, ...@@ -29,7 +29,8 @@ from theano.tensor.subtensor import (get_idx_list, get_canonical_form_slice,
Subtensor, IncSubtensor, make_constant, Subtensor, IncSubtensor, make_constant,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedSubtensor1) AdvancedSubtensor1,
advanced_inc_subtensor1)
from theano import scalar from theano import scalar
from theano.tensor import basic as T from theano.tensor import basic as T
from theano import compile # to register the optimizer built by this file from theano import compile # to register the optimizer built by this file
...@@ -1345,10 +1346,10 @@ def local_subtensor_make_vector(node): ...@@ -1345,10 +1346,10 @@ def local_subtensor_make_vector(node):
replace all subtensor(make_vector) like: replace all subtensor(make_vector) like:
[a,b,c][0] -> a [a,b,c][0] -> a
[a,b,c][0:2] -> [a,b] [a,b,c][0:2] -> [a,b]
replace all AdvancedSubtensor1(make_vector) like: replace all AdvancedSubtensor1(make_vector) like:
[a,b,c][[0,2]] -> [a,c] [a,b,c][[0,2]] -> [a,c]
we can do this for constant indexes we can do this for constant indexes
""" """
x = node.inputs[0] x = node.inputs[0]
...@@ -1629,13 +1630,13 @@ def local_elemwise_alloc(node): ...@@ -1629,13 +1630,13 @@ def local_elemwise_alloc(node):
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern # Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true. # This is a supposition that I'm not sure is always true.
assert all([o.type.broadcastable == assert all([o.type.broadcastable ==
node.outputs[0].type.broadcastable for o in node.outputs[0].type.broadcastable for o in
node.outputs[1:]]) node.outputs[1:]])
# The broadcast pattern of the ouptut must match the broadcast pattern of # The broadcast pattern of the ouptut must match the broadcast pattern of
# at least one of the inputs. # at least one of the inputs.
if not any([i.type.broadcastable == if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]): node.outputs[0].type.broadcastable for i in node.inputs]):
return False return False
...@@ -1874,6 +1875,34 @@ def local_useless_inc_subtensor(node): ...@@ -1874,6 +1875,34 @@ def local_useless_inc_subtensor(node):
return [Subtensor(node.op.idx_list)(*node.inputs[1:])] return [Subtensor(node.op.idx_list)(*node.inputs[1:])]
@register_canonicalize
@gof.local_optimizer([AdvancedIncSubtensor1])
def local_set_to_inc_subtensor(node):
if (isinstance(node.op, AdvancedIncSubtensor1) and
node.op.set_instead_of_inc == True and
node.inputs[1].owner and
isinstance(node.inputs[1].owner.op, Elemwise) and
isinstance(node.inputs[1].owner.op.scalar_op, scalar.Add)):
addn = node.inputs[1].owner
subn = None
other = None
if (addn.inputs[0].owner and
isinstance(addn.inputs[0].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[0].owner
other = addn.inputs[1]
elif (addn.inputs[1].owner and
isinstance(addn.inputs[1].owner.op, AdvancedSubtensor1)):
subn = addn.inputs[1].owner
other = addn.inputs[0]
else:
return
if (subn.inputs[1] != node.inputs[2] or
subn.inputs[0] != node.inputs[0]):
return
return [advanced_inc_subtensor1(node.inputs[0], other, node.inputs[2])]
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([Subtensor]) @gof.local_optimizer([Subtensor])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论