提交 1e481c32 authored 作者: James Bergstra's avatar James Bergstra

gemm - changed handling of aliased input and output.

Made the type of exception raised in that case an InconsistencyError (was ValueError) and made the GemmOptimizer catch that exception when trying to insert gemms into a graph.
上级 db81fe35
...@@ -389,9 +389,9 @@ class Gemm(GemmRelated): ...@@ -389,9 +389,9 @@ class Gemm(GemmRelated):
z, a, x, y, b = inputs z, a, x, y, b = inputs
zr, xr, yr = [set(view_roots(i)) for i in z,x,y] zr, xr, yr = [set(view_roots(i)) for i in z,x,y]
if zr.intersection(xr): if zr.intersection(xr):
raise ValueError(Gemm.E_z_uniq, (z, x)) raise InconsistencyError(Gemm.E_z_uniq, (z, x))
if zr.intersection(yr): if zr.intersection(yr):
raise ValueError(Gemm.E_z_uniq, (z, y)) raise InconsistencyError(Gemm.E_z_uniq, (z, y))
bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs] bz, ba, bx, by, bb = [r.type.broadcastable for r in inputs]
if bz != (False,False): raise ValueError(Gemm.E_rank, bz) if bz != (False,False): raise ValueError(Gemm.E_rank, bz)
if bx != (False,False): raise ValueError(Gemm.E_rank, bx) if bx != (False,False): raise ValueError(Gemm.E_rank, bx)
...@@ -784,7 +784,10 @@ class GemmOptimizer(Optimizer): ...@@ -784,7 +784,10 @@ class GemmOptimizer(Optimizer):
nodelist.reverse() nodelist.reverse()
for node in nodelist: for node in nodelist:
#new_outputs = _gemm_from_node(node) #new_outputs = _gemm_from_node(node)
new_outputs = _gemm_from_node2(node) try:
new_outputs = _gemm_from_node2(node)
except InconsistencyError, e:
continue
if new_outputs: if new_outputs:
assert len(new_outputs) == len(node.outputs) assert len(new_outputs) == len(node.outputs)
try: try:
......
...@@ -139,7 +139,7 @@ class t_gemm(TestCase): ...@@ -139,7 +139,7 @@ class t_gemm(TestCase):
Z = as_tensor_variable(self.rand(2,2)) Z = as_tensor_variable(self.rand(2,2))
try: try:
gemm_inplace(Z, 1.0, Z, Z, 1.0) gemm_inplace(Z, 1.0, Z, Z, 1.0)
except ValueError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
...@@ -149,7 +149,7 @@ class t_gemm(TestCase): ...@@ -149,7 +149,7 @@ class t_gemm(TestCase):
A = as_tensor_variable(self.rand(2,2)) A = as_tensor_variable(self.rand(2,2))
try: try:
gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0) gemm_inplace(Z, 1.0, A, inplace.transpose_inplace(Z), 1.0)
except ValueError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
...@@ -159,7 +159,7 @@ class t_gemm(TestCase): ...@@ -159,7 +159,7 @@ class t_gemm(TestCase):
A = as_tensor_variable(self.rand(2,2)) A = as_tensor_variable(self.rand(2,2))
try: try:
gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0) gemm_inplace(Z, 1.0, inplace.transpose_inplace(Z), A, 1.0)
except ValueError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
...@@ -169,7 +169,7 @@ class t_gemm(TestCase): ...@@ -169,7 +169,7 @@ class t_gemm(TestCase):
A = as_tensor_variable(self.rand(2,2)) A = as_tensor_variable(self.rand(2,2))
try: try:
gemm_inplace(Z, 1.0, Z, A, 1.0) gemm_inplace(Z, 1.0, Z, A, 1.0)
except ValueError, e: except InconsistencyError, e:
if e[0] == Gemm.E_z_uniq: if e[0] == Gemm.E_z_uniq:
return return
self.fail() self.fail()
...@@ -614,3 +614,19 @@ def test_dot22scalar(): ...@@ -614,3 +614,19 @@ def test_dot22scalar():
assert _dot22scalar in [x.op for x in topo] assert _dot22scalar in [x.op for x in topo]
assert len(topo)==2 assert len(topo)==2
f(av,bv,cv) f(av,bv,cv)
def test_dot_w_self():
# This can trigger problems in the optimization because what would normally be a gemm must
# not be because the output is aliased to one of the inputs.
A = shared(value = numpy.ones((2,2)))
B = T.matrix()
p = T.dot(A,A)*B
grad = T.grad(T.mean(p),[A])
f = theano.function([B], p, updates = { A : A - grad[0]} )
# tests correctness in debugmode
f(numpy.asarray([[0,1], [2,3]]))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论