提交 59ce6fa9 authored 作者: James Bergstra's avatar James Bergstra

replacing transposed dot patternsub with local opt

The old form (PatternSub) had the problem that it wasn't calling broadcast_like on the return value. It would have been wrong to hack a broadcast_like into gof.opt, so I rewrote these two optimizations as a function. PatternSub is a really cool idea, in the future it might be worth writing a TensorPatternSub that deals with tensor-specific issues, like casting, dimshuffles, and rebroadcasts.
上级 1f72c472
......@@ -334,26 +334,31 @@ def local_dimshuffle_lift(node):
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
## dot(x,y).T -> dot(y.T, x.T)
# These optimizations "lift" (propagate towards the inputs) DimShuffle
# through dot product. It allows to put the graph in a more standard shape,
# and to later merge consecutive DimShuffles.
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False)
# The transformation should be apply whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace, this will be taken care
# of in a later optimization phase.
# First optimization: inplace
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
# Second optimization: not inplace
local_transposed_dot = gof.PatternSub(
(matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
# Register in the canonization phase only
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace')
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
@register_canonicalize
@gof.local_optimizer([])
def local_lift_transpose_through_dot(node):
"""
dot(x,y).T -> dot(y.T, x.T)
These optimizations "lift" (propagate towards the inputs) DimShuffle
through dot product. It allows to put the graph in a more standard shape,
and to later merge consecutive DimShuffles.
The transformation should be apply whether or not the transpose is
inplace. The newly-introduced transpositions are not inplace, this will
be taken care of in a later optimization phase.
"""
if not (isinstance(node.op, T.DimShuffle)
and node.op.new_order == (1, 0)):
return False
if not (node.inputs[0].owner and node.inputs[0].owner.op == T.dot):
return False
x, y = node.inputs[0].owner.inputs
if x.ndim == y.ndim == 2:
return [broadcast_like(T.dot(y.T, x.T), node.outputs[0], node.env)]
@gof.local_optimizer([])
def dimshuffle_as_view(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论