提交 fb182b59 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Move and better comment the transpose-dot lifting optimization.

上级 93b9b4ff
...@@ -283,6 +283,27 @@ def local_dimshuffle_lift(node): ...@@ -283,6 +283,27 @@ def local_dimshuffle_lift(node):
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
## dot(x,y).T -> dot(y.T, x.T)
# These optimizations "lift" (propagate towards the inputs) DimShuffle
# through dot product. It allows to put the graph in a more standard shape,
# and to later merge consecutive DimShuffles.
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False)
# The transformation should be apply whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace, this will be taken care
# of in a later optimization phase.
# First optimization: inplace
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
# Second optimization: not inplace
local_transposed_dot = gof.PatternSub(
(matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
# Register in the canonization phase only
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace')
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
@gof.local_optimizer([]) @gof.local_optimizer([])
def dimshuffle_as_view(node): def dimshuffle_as_view(node):
op = node.op op = node.op
...@@ -2824,23 +2845,6 @@ register_canonicalize(constant_folding, 'fast_compile') ...@@ -2824,23 +2845,6 @@ register_canonicalize(constant_folding, 'fast_compile')
register_stabilize(constant_folding) # because register_stabilize(constant_folding) # because
register_specialize(constant_folding) register_specialize(constant_folding)
## dot(x,y).T -> dot(y.T, x.T)
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False)
# The transformation should be apply whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace, this will be taken care
# of in a later optimization phase.
# First optimization: inplace
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace')
# Second optimization: not inplace
local_transposed_dot = gof.PatternSub(
(matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
def _is_1(expr): def _is_1(expr):
"""rtype bool. True iff expr is a constant close to 1 """rtype bool. True iff expr is a constant close to 1
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论