提交 96ce70fa authored 作者: James Bergstra's avatar James Bergstra

added test for lift_transpose_through_dot

上级 87876db3
......@@ -3088,6 +3088,19 @@ def test_local_div_to_inv():
assert numpy.allclose(out_val, 0.5)
class Test_lift_transpose_through_dot(unittest.TestCase):
def optimize(self, g):
out2in(opt.local_useless_elemwise).optimize(g)
out2in(opt.local_lift_transpose_through_dot).optimize(g)
out2in(opt.local_useless_elemwise).optimize(g)
return g
def test_matrix_matrix(self):
a, b = matrices('ab')
g = self.optimize(Env([a, b], [tensor.dot(a, b).T]))
sg = '[dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a))]'
assert str(g) == sg
if __name__ == '__main__':
# unittest.main()
test_fusion().tes_memory_leak()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论