提交 de695bf1 authored 作者: James Bergstra's avatar James Bergstra

test lift_transpose_through_dot with broadcastable

上级 250e0242
...@@ -357,7 +357,7 @@ def local_lift_transpose_through_dot(node): ...@@ -357,7 +357,7 @@ def local_lift_transpose_through_dot(node):
x, y = node.inputs[0].owner.inputs x, y = node.inputs[0].owner.inputs
if x.ndim == y.ndim == 2: if x.ndim == y.ndim == 2:
return [broadcast_like(T.dot(y.T, x.T), node.outputs[0], node.env)] return [T.dot(y.T, x.T)]
@gof.local_optimizer([]) @gof.local_optimizer([])
......
...@@ -53,6 +53,27 @@ mode_opt = theano.compile.mode.get_mode(mode_opt) ...@@ -53,6 +53,27 @@ mode_opt = theano.compile.mode.get_mode(mode_opt)
ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x) ds = lambda x, y: DimShuffle(x.type.broadcastable, y)(x)
dimshuffle_lift = out2in(local_dimshuffle_lift) dimshuffle_lift = out2in(local_dimshuffle_lift)
_optimizer_stabilize = gof.Query(include=['fast_run'])
_optimizer_stabilize.position_cutoff = 1.51
_optimizer_stabilize = compile.optdb.query(_optimizer_stabilize)
_optimizer_specialize = gof.Query(include=['fast_run'])
_optimizer_specialize.position_cutoff = 2.01
_optimizer_specialize = compile.optdb.query(_optimizer_specialize)
_optimizer_fast_run = gof.Query(include=['fast_run'])
_optimizer_fast_run = compile.optdb.query(_optimizer_fast_run)
def optimize(g, level='fast_run'):
if 'fast_run' is level:
_optimizer_fast_run.optimize(g)
elif 'specialize' is level:
_optimizer_specialize.optimize(g)
elif 'stabilize' is level:
_optimizer_stabilize.optimize(g)
else:
raise ValueError(level)
return g
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = TensorType(broadcastable = xbc, dtype = 'float64')('x') x = TensorType(broadcastable = xbc, dtype = 'float64')('x')
...@@ -3089,7 +3110,7 @@ def test_local_div_to_inv(): ...@@ -3089,7 +3110,7 @@ def test_local_div_to_inv():
class Test_lift_transpose_through_dot(unittest.TestCase): class Test_lift_transpose_through_dot(unittest.TestCase):
def optimize(self, g): def simple_optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g) out2in(opt.local_useless_elemwise).optimize(g)
out2in(opt.local_lift_transpose_through_dot).optimize(g) out2in(opt.local_lift_transpose_through_dot).optimize(g)
out2in(opt.local_useless_elemwise).optimize(g) out2in(opt.local_useless_elemwise).optimize(g)
...@@ -3097,10 +3118,31 @@ class Test_lift_transpose_through_dot(unittest.TestCase): ...@@ -3097,10 +3118,31 @@ class Test_lift_transpose_through_dot(unittest.TestCase):
def test_matrix_matrix(self): def test_matrix_matrix(self):
a, b = matrices('ab') a, b = matrices('ab')
g = self.optimize(Env([a, b], [tensor.dot(a, b).T])) g = self.simple_optimize(Env([a, b], [tensor.dot(a, b).T]))
sg = '[dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a))]' sg = '[dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a))]'
assert str(g) == sg assert str(g) == sg
def test_row_matrix(self):
a = vector('a')
b = matrix('b')
g = optimize(Env(
[a, b],
[tensor.dot(a.dimshuffle('x', 0), b).T]),
level='stabilize')
sg = '[dot(DimShuffle{1,0}(b), DimShuffle{0,x}(a))]'
assert str(g) == sg
def test_matrix_col(self):
a = vector('a')
b = matrix('b')
g = optimize(Env(
[a, b],
[tensor.dot(b, a.dimshuffle(0, 'x')).T]),
level='stabilize')
sg = '[dot(DimShuffle{x,0}(a), DimShuffle{1,0}(b))]'
assert str(g) == sg
if __name__ == '__main__': if __name__ == '__main__':
# unittest.main() # unittest.main()
test_fusion().tes_memory_leak() test_fusion().tes_memory_leak()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论