提交 59ce6fa9 authored 作者: James Bergstra's avatar James Bergstra

replacing transposed dot patternsub with local opt

The old form (PatternSub) had the problem that it wasn't calling broadcast_like on the return value. It would have been wrong to hack a broadcast_like into gof.opt, so I rewrote these two optimizations as a function. PatternSub is a really cool idea, in the future it might be worth writing a TensorPatternSub that deals with tensor-specific issues, like casting, dimshuffles, and rebroadcasts.
上级 1f72c472
...@@ -334,26 +334,31 @@ def local_dimshuffle_lift(node): ...@@ -334,26 +334,31 @@ def local_dimshuffle_lift(node):
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
## dot(x,y).T -> dot(y.T, x.T)
# These optimizations "lift" (propagate towards the inputs) DimShuffle @register_canonicalize
# through dot product. It allows to put the graph in a more standard shape, @gof.local_optimizer([])
# and to later merge consecutive DimShuffles. def local_lift_transpose_through_dot(node):
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True) """
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False) dot(x,y).T -> dot(y.T, x.T)
# The transformation should be apply whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace, this will be taken care These optimizations "lift" (propagate towards the inputs) DimShuffle
# of in a later optimization phase. through dot product. It allows to put the graph in a more standard shape,
# First optimization: inplace and to later merge consecutive DimShuffles.
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')), The transformation should be apply whether or not the transpose is
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x'))) inplace. The newly-introduced transpositions are not inplace, this will
# Second optimization: not inplace be taken care of in a later optimization phase.
local_transposed_dot = gof.PatternSub( """
(matrix_transpose, (T.dot, 'x', 'y')), if not (isinstance(node.op, T.DimShuffle)
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x'))) and node.op.new_order == (1, 0)):
# Register in the canonization phase only return False
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace') if not (node.inputs[0].owner and node.inputs[0].owner.op == T.dot):
register_canonicalize(local_transposed_dot, name='local_transposed_dot') return False
x, y = node.inputs[0].owner.inputs
if x.ndim == y.ndim == 2:
return [broadcast_like(T.dot(y.T, x.T), node.outputs[0], node.env)]
@gof.local_optimizer([]) @gof.local_optimizer([])
def dimshuffle_as_view(node): def dimshuffle_as_view(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论