提交 87d414e4 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #148 from nouiz/fix_extract_diag

Fix extract diag
......@@ -463,6 +463,7 @@ solve = Solve() # general solve
#TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten)
class ExtractDiag(Op):
""" Return the diagonal of matrix """
def __init__(self, view=False):
self.view = view
if self.view:
......@@ -485,8 +486,12 @@ class ExtractDiag(Op):
z, = outs
#for some reason numpy.diag(x) is really slow
N,M = x.shape
assert N==M
rval = x[0]
if x.shape[0] < x.shape [1]:
rval = x[:,0]
else:
rval = x[0]
rval.strides = (x.strides[0]+x.strides[1],)
if self.view:
z[0] = rval
......@@ -501,7 +506,8 @@ class ExtractDiag(Op):
def infer_shape(self, node, shapes):
x_s, = shapes
return [(x_s[0],)]
shp = tensor.min(node.inputs[0].shape)
return [(shp,)]
extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True
......
......@@ -182,20 +182,14 @@ def test_extract_diag():
g = extract_diag(x)
f = theano.function([x], g)
m = rng.rand(3,3).astype(config.floatX)
v = numpy.diag(m)
r = f(m)
# The right diagonal is extracted
assert (r == v).all()
m = rng.rand(2, 3).astype(config.floatX)
ok = False
try:
for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(config.floatX)
v = numpy.diag(m)
r = f(m)
except Exception:
ok = True
assert ok
# The right diagonal is extracted
assert (r == v).all()
# Test we accept only matrix
xx = theano.tensor.vector()
ok = False
try:
......@@ -204,11 +198,14 @@ def test_extract_diag():
ok = True
assert ok
# Test infer_shape
f = theano.function([x], g.shape)
topo = f.maker.env.toposort()
assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0
m = rng.rand(3,3).astype(config.floatX)
assert f(m) == 3
if config.mode != 'FAST_COMPILE':
assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0
for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(config.floatX)
assert f(m) == min(shp)
# not testing the view=True case since it is not used anywhere.
......@@ -219,9 +216,10 @@ def test_trace():
g = trace(x)
f = theano.function([x], g)
m = rng.rand(4, 4).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
xx = theano.tensor.vector()
ok = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论