提交 648631eb authored 作者: Matthew Rocklin's avatar Matthew Rocklin

inv_as_solve optimization checks if op is Dot

Previously it checked if node.op == dot where dot was a dot function I assume that this was a mistake I've changed this to check if isinstance(node.op, (Dot, Dot22)) I've also added a test
上级 517540f5
...@@ -5,7 +5,8 @@ import numpy ...@@ -5,7 +5,8 @@ import numpy
from theano.gof import Op, Apply from theano.gof import Op, Apply
from theano.tensor import as_tensor_variable, dot, DimShuffle from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor.blas import Dot22
from theano import tensor from theano import tensor
import theano.tensor import theano.tensor
from theano.tensor.opt import (register_stabilize, from theano.tensor.opt import (register_stabilize,
...@@ -227,7 +228,7 @@ def is_positive(v): ...@@ -227,7 +228,7 @@ def is_positive(v):
def inv_as_solve(node): def inv_as_solve(node):
if not imported_scipy: if not imported_scipy:
return False return False
if node.op == dot: if isinstance(node.op, (Dot, Dot22)):
l, r = node.inputs l, r = node.inputs
if l.owner and l.owner.op == matrix_inverse: if l.owner and l.owner.op == matrix_inverse:
return [solve(l.owner.inputs[0], r)] return [solve(l.owner.inputs[0], r)]
......
...@@ -28,6 +28,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -28,6 +28,7 @@ from theano.sandbox.linalg.ops import (cholesky,
spectral_radius_bound, spectral_radius_bound,
imported_scipy, imported_scipy,
Eig, Eig,
inv_as_solve,
) )
from theano.sandbox.linalg import eig, eigh from theano.sandbox.linalg import eig, eigh
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -533,3 +534,10 @@ class test_Eigh(test_Eig): ...@@ -533,3 +534,10 @@ class test_Eigh(test_Eig):
class test_Eigh_float32(test_Eigh): class test_Eigh_float32(test_Eigh):
dtype = 'float32' dtype = 'float32'
def test_matrix_inverse_solve():
A = theano.tensor.dmatrix('A')
b = theano.tensor.dmatrix('b')
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(node)
assert isinstance(out.owner.op, Solve)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论