提交 05c589dd authored 作者: Frederic's avatar Frederic

Unfuse AllocEmpty with CGemv when useful

上级 1d491001
...@@ -799,7 +799,14 @@ def use_c_gemv(node): ...@@ -799,7 +799,14 @@ def use_c_gemv(node):
@local_optimizer([CGemv(inplace=False)]) @local_optimizer([CGemv(inplace=False)])
def make_c_gemv_destructive(node): def make_c_gemv_destructive(node):
if isinstance(node.op, CGemv) and not node.op.inplace: if isinstance(node.op, CGemv) and not node.op.inplace:
return [cgemv_inplace(*node.inputs)] inputs = list(node.inputs)
dest = inputs[0]
if (dest.owner and
isinstance(dest.owner.op, T.AllocEmpty) and
len(dest.clients) > 1):
inputs[0] = T.AllocEmpty(dest.dtype)(*dest.owner.inputs)
return [cgemv_inplace(*inputs)]
# ##### ####### ####### # ##### ####### #######
......
...@@ -256,6 +256,21 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -256,6 +256,21 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertRaises(ValueError, f, A_val, ones_3, ones_6) self.assertRaises(ValueError, f, A_val, ones_3, ones_6)
self.assertRaises(ValueError, f, A_val, ones_4, ones_6) self.assertRaises(ValueError, f, A_val, ones_4, ones_6)
def test_multiple_inplace(self):
x = tensor.dmatrix('x')
y = tensor.dvector('y')
z = tensor.dvector('z')
f = theano.function([x, y, z],
[tensor.dot(y, x), tensor.dot(z,x)])
vx = numpy.random.rand(3, 3)
vy = numpy.random.rand(3)
vz = numpy.random.rand(3)
out = f(vx, vy, vz)
assert numpy.allclose(out[0], numpy.dot(vy, vx))
assert numpy.allclose(out[1], numpy.dot(vz, vx))
assert len([n for n in f.maker.fgraph.apply_nodes
if isinstance(n.op, tensor.AllocEmpty)]) == 2
class TestCGemvFloat32(TestCase, BaseGemv, TestOptimizationMixin): class TestCGemvFloat32(TestCase, BaseGemv, TestOptimizationMixin):
mode = mode_blas_opt mode = mode_blas_opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论