提交 e97557c8 authored 作者: nouiz's avatar nouiz

Merge pull request #1283 from mrocklin/matinv-solve

inv_as_solve optimization checks if op is Dot
......@@ -5,7 +5,8 @@ import numpy
from theano.gof import Op, Apply
from theano.tensor import as_tensor_variable, dot, DimShuffle
from theano.tensor import as_tensor_variable, dot, DimShuffle, Dot
from theano.tensor.blas import Dot22
from theano import tensor
import theano.tensor
from theano.tensor.opt import (register_stabilize,
......@@ -227,7 +228,7 @@ def is_positive(v):
def inv_as_solve(node):
if not imported_scipy:
return False
if node.op == dot:
if isinstance(node.op, (Dot, Dot22)):
l, r = node.inputs
if l.owner and l.owner.op == matrix_inverse:
return [solve(l.owner.inputs[0], r)]
......
......@@ -28,6 +28,7 @@ from theano.sandbox.linalg.ops import (cholesky,
spectral_radius_bound,
imported_scipy,
Eig,
inv_as_solve,
)
from theano.sandbox.linalg import eig, eigh
from nose.plugins.skip import SkipTest
......@@ -533,3 +534,10 @@ class test_Eigh(test_Eig):
class test_Eigh_float32(test_Eigh):
dtype = 'float32'
def test_matrix_inverse_solve():
A = theano.tensor.dmatrix('A')
b = theano.tensor.dmatrix('b')
node = matrix_inverse(A).dot(b).owner
[out] = inv_as_solve.transform(node)
assert isinstance(out.owner.op, Solve)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论