提交 c787e96e authored 作者: Frederic's avatar Frederic

Make ExtractDiag work with not square matrix.

上级 b87e163f
...@@ -417,8 +417,12 @@ class ExtractDiag(Op): ...@@ -417,8 +417,12 @@ class ExtractDiag(Op):
z, = outs z, = outs
#for some reason numpy.diag(x) is really slow #for some reason numpy.diag(x) is really slow
N,M = x.shape N,M = x.shape
assert N==M
if x.shape[0] < x.shape [1]:
rval = x[:,0]
else:
rval = x[0] rval = x[0]
rval.strides = (x.strides[0]+x.strides[1],) rval.strides = (x.strides[0]+x.strides[1],)
if self.view: if self.view:
z[0] = rval z[0] = rval
...@@ -433,7 +437,8 @@ class ExtractDiag(Op): ...@@ -433,7 +437,8 @@ class ExtractDiag(Op):
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
x_s, = shapes x_s, = shapes
return [(x_s[0],)] shp = tensor.min(node.inputs[0].shape)
return [(shp,)]
extract_diag = ExtractDiag() extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True #TODO: optimization to insert ExtractDiag with view=True
......
...@@ -103,20 +103,14 @@ def test_extract_diag(): ...@@ -103,20 +103,14 @@ def test_extract_diag():
g = extract_diag(x) g = extract_diag(x)
f = theano.function([x], g) f = theano.function([x], g)
m = rng.rand(3,3).astype(config.floatX) for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(config.floatX)
v = numpy.diag(m) v = numpy.diag(m)
r = f(m) r = f(m)
# The right diagonal is extracted # The right diagonal is extracted
assert (r == v).all() assert (r == v).all()
m = rng.rand(2, 3).astype(config.floatX) # Test we accept only matrix
ok = False
try:
r = f(m)
except Exception:
ok = True
assert ok
xx = theano.tensor.vector() xx = theano.tensor.vector()
ok = False ok = False
try: try:
...@@ -125,11 +119,14 @@ def test_extract_diag(): ...@@ -125,11 +119,14 @@ def test_extract_diag():
ok = True ok = True
assert ok assert ok
# Test infer_shape
f = theano.function([x], g.shape) f = theano.function([x], g.shape)
topo = f.maker.env.toposort() topo = f.maker.env.toposort()
if config.mode != 'FAST_COMPILE':
assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0 assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0
m = rng.rand(3,3).astype(config.floatX) for shp in [(2, 3), (3, 2), (3, 3)]:
assert f(m) == 3 m = rng.rand(*shp).astype(config.floatX)
assert f(m) == min(shp)
# not testing the view=True case since it is not used anywhere. # not testing the view=True case since it is not used anywhere.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论