提交 6bc14189 作者: nouiz

Merge pull request #147 from jaberg/transpose_dot_opt

Transpose dot opt
...@@ -298,16 +298,24 @@ class Env(utils.object2): ...@@ -298,16 +298,24 @@ class Env(utils.object2):
if node == 'output': if node == 'output':
r = self.outputs[i] r = self.outputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r) raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.",
r, new_r)
self.outputs[i] = new_r self.outputs[i] = new_r
else: else:
if node.env is not self: if node.env is not self:
raise Exception("Cannot operate on %s because it does not belong to this Env" % node) raise Exception("Cannot operate on %s because it does not"
" belong to this Env" % node)
r = node.inputs[i] r = node.inputs[i]
if not r.type == new_r.type: if not r.type == new_r.type:
raise TypeError("The type of the replacement must be the same as the type of the original Variable.", r, new_r) raise TypeError("The type of the replacement must be the"
" same as the type of the original Variable.",
r, new_r)
node.inputs[i] = new_r node.inputs[i] = new_r
if r is new_r:
return
self.__import_r__([new_r]) self.__import_r__([new_r])
self.__add_clients__(new_r, [(node, i)]) self.__add_clients__(new_r, [(node, i)])
prune = self.__remove_clients__(r, [(node, i)], False) prune = self.__remove_clients__(r, [(node, i)], False)
......
...@@ -651,6 +651,7 @@ class PatternSub(LocalOptimizer): ...@@ -651,6 +651,7 @@ class PatternSub(LocalOptimizer):
def skip_identities(self, expr): def skip_identities(self, expr):
if self.skip_identities_fn: if self.skip_identities_fn:
return self.skip_identities_fn(expr) return self.skip_identities_fn(expr)
def op_key(self): def op_key(self):
return self.op return self.op
......
...@@ -137,6 +137,7 @@ class DimShuffle(Op): ...@@ -137,6 +137,7 @@ class DimShuffle(Op):
d = dict(self.__dict__) d = dict(self.__dict__)
del d['_hashval'] del d['_hashval']
return d return d
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self._rehash() self._rehash()
...@@ -218,13 +219,11 @@ class DimShuffle(Op): ...@@ -218,13 +219,11 @@ class DimShuffle(Op):
rval.insert(augm, 1) rval.insert(augm, 1)
return [rval] return [rval]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points: if None in eval_points:
return [None] return [None]
return self.make_node(*eval_points).outputs return self.make_node(*eval_points).outputs
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
input, = inp input, = inp
res, = out res, = out
......
...@@ -87,6 +87,9 @@ def broadcast_like(value, template, env, dtype=None): ...@@ -87,6 +87,9 @@ def broadcast_like(value, template, env, dtype=None):
filled by broadcasting value through it. `value` will be casted as necessary. filled by broadcasting value through it. `value` will be casted as necessary.
""" """
value = T.as_tensor_variable(value)
if value.type == template.type:
return value
shape_of = env.shape_feature.shape_of shape_of = env.shape_feature.shape_of
if template not in shape_of: if template not in shape_of:
raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already') raise NotImplementedError('broadcast_like currently requires the template Variable to be in the env already')
...@@ -331,26 +334,31 @@ def local_dimshuffle_lift(node): ...@@ -331,26 +334,31 @@ def local_dimshuffle_lift(node):
else: else:
return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs return DimShuffle(iinput.type.broadcastable, new_order, inplace).make_node(iinput).outputs
## dot(x,y).T -> dot(y.T, x.T)
# These optimizations "lift" (propagate towards the inputs) DimShuffle @register_canonicalize
# through dot product. It allows to put the graph in a more standard shape, @gof.local_optimizer([])
# and to later merge consecutive DimShuffles. def local_lift_transpose_through_dot(node):
inplace_matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=True) """
matrix_transpose = T.DimShuffle([False,False], [1,0], inplace=False) dot(x,y).T -> dot(y.T, x.T)
# The transformation should be apply whether or not the transpose is inplace.
# The newly-introduced transpositions are not inplace, this will be taken care These optimizations "lift" (propagate towards the inputs) DimShuffle
# of in a later optimization phase. through dot product. It allows to put the graph in a more standard shape,
# First optimization: inplace and to later merge consecutive DimShuffles.
local_transposed_dot_inplace = gof.PatternSub(
(inplace_matrix_transpose, (T.dot, 'x', 'y')), The transformation should be apply whether or not the transpose is
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x'))) inplace. The newly-introduced transpositions are not inplace, this will
# Second optimization: not inplace be taken care of in a later optimization phase.
local_transposed_dot = gof.PatternSub( """
(matrix_transpose, (T.dot, 'x', 'y')), if not (isinstance(node.op, T.DimShuffle)
(T.dot, (matrix_transpose, 'y'), (matrix_transpose, 'x'))) and node.op.new_order == (1, 0)):
# Register in the canonization phase only return False
register_canonicalize(local_transposed_dot_inplace, name='local_transposed_dot_inplace') if not (node.inputs[0].owner and node.inputs[0].owner.op == T.dot):
register_canonicalize(local_transposed_dot, name='local_transposed_dot') return False
x, y = node.inputs[0].owner.inputs
if x.ndim == y.ndim == 2:
return [T.dot(y.T, x.T)]
@gof.local_optimizer([]) @gof.local_optimizer([])
def dimshuffle_as_view(node): def dimshuffle_as_view(node):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论