提交 6bc14189 authored 作者: nouiz's avatar nouiz

Merge pull request #147 from jaberg/transpose_dot_opt

Transpose dot opt
......@@ -298,16 +298,24 @@ class Env(utils.object2):
if node == 'output':
r = self.outputs[i]
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r)
raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.",
r, new_r)
self.outputs[i] = new_r
else:
if node.env is not self:
raise Exception("Cannot operate on %s because it does not belong to this Env" % node)
raise Exception("Cannot operate on %s because it does not"
" belong to this Env" % node)
r = node.inputs[i]
if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r)
raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.",
r, new_r)
node.inputs[i] = new_r
if r is new_r:
return
self.__import_r__([new_r])
self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False)
......
......@@ -651,6 +651,7 @@ class PatternSub(LocalOptimizer):
def skip_identities(self, expr):
    """Run the configured identity-skipping callback on `expr`.

    Returns whatever ``self.skip_identities_fn`` returns for `expr`.
    When no callback is configured (the attribute is falsy), nothing is
    applied and the method returns None.
    """
    skip_fn = self.skip_identities_fn
    if not skip_fn:
        return None
    return skip_fn(expr)
def op_key(self):
    """Return the value stored in ``self.op``."""
    key = self.op
    return key
......
......@@ -137,6 +137,7 @@ class DimShuffle(Op):
d = dict(self.__dict__)
del d['_hashval']
return d
def __setstate__(self, d):
    """Restore instance state from the mapping `d`.

    Every entry of `d` is copied into the instance dictionary, then
    ``self._rehash()`` is called so the hash-related state is rebuilt
    from the restored attributes.
    """
    state = vars(self)
    state.update(d)
    self._rehash()
......@@ -218,13 +219,11 @@ class DimShuffle(Op):
rval.insert(augm, 1)
return [rval]
def R_op(self, inputs, eval_points):
    """Build this op's outputs from the evaluation points.

    If any entry of `eval_points` is None, the result is ``[None]``.
    Otherwise a node is built from `eval_points` with ``self.make_node``
    and that node's outputs are returned. `inputs` is not used here.
    """
    has_missing_point = None in eval_points
    if not has_missing_point:
        return self.make_node(*eval_points).outputs
    return [None]
def c_code(self, node, name, inp, out, sub):
input, = inp
res, = out
......
......@@ -87,6 +87,9 @@ def broadcast_like(value, template, env, dtype=None):
filled by broadcasting value through it. `value` will be casted as necessary.
"""
value = T.as_tensor_variable(value)
if value.type == template.type:
return value
shape_of = env.shape_feature.shape_of
if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
......@@ -331,26 +334,31 @@ def local_dimshuffle_lift(node):
else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
## dot(x,y).T -> dot(y.T, x.T)
# These optimizations "lift" (propagate towards the inputs) DimShuffle
# through the dot product. This puts the graph in a more standard form
# and allows consecutive DimShuffles to be merged later.
# Two DimShuffle ops that swap the axes of a matrix: one inplace, one not.
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True)
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False)
# The transformation should be applied whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace; this will be taken care
# of in a later optimization phase.
# First optimization: inplace
# (matches an inplace transpose of dot(x, y), rewrites to dot(y.T, x.T))
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
# Second optimization: not inplace
# (same rewrite, matching the non-inplace transpose instead)
local_transposed_dot = gof.PatternSub(
(matrix_transpose, (T.dot, 'x', 'y')),
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x')))
# Register both rewrites in the canonicalization phase only
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace')
register_canonicalize(local_transposed_dot, name='local_transposed_dot')
@register_canonicalize
@gof.local_optimizer([])
def local_lift_transpose_through_dot(node):
    """
    dot(x,y).T -> dot(y.T, x.T)

    This optimization "lifts" (propagates towards the inputs) DimShuffle
    through the dot product. It puts the graph in a more standard form
    and allows consecutive DimShuffles to be merged later.

    The transformation is applied whether or not the transpose is
    inplace. The newly-introduced transpositions are not inplace; this
    will be taken care of in a later optimization phase.

    Returns a one-element list holding the replacement for the node's
    output when the pattern matches, and False on every no-match path.
    """
    # Only a DimShuffle that swaps the two axes of its input qualifies.
    if not (isinstance(node.op, T.DimShuffle)
            and node.op.new_order == (1, 0)):
        return False
    # The transposed value must itself be produced by a dot product.
    dot_node = node.inputs[0].owner
    if not (dot_node and dot_node.op == T.dot):
        return False
    x, y = dot_node.inputs
    # The rewrite is restricted to matrix-matrix products; report
    # "no change" explicitly instead of falling through to None.
    if not (x.ndim == y.ndim == 2):
        return False
    return [T.dot(y.T, x.T)]
@gof.local_optimizer([])
def dimshuffle_as_view(node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论