提交 8894a3a7 authored 作者: Tanjay94's avatar Tanjay94

Fixed Eigvalsh to support none array as b input.

上级 63b827c6
...@@ -310,7 +310,17 @@ class Eigvalsh(Op): ...@@ -310,7 +310,17 @@ class Eigvalsh(Op):
def make_node(self, a, b): def make_node(self, a, b):
assert imported_scipy, ( assert imported_scipy, (
"Scipy not available. Scipy is needed for the Eigvalsh op") "Scipy not available. Scipy is needed for the Eigvalsh op")
a, b = map(as_tensor_variable, (a, b))
if b == theano.tensor.NoneConst:
a = as_tensor_variable(a)
assert a.ndim == 2
out_dtype = theano.scalar.upcast(a.dtype)
w = theano.tensor.vector(dtype=out_dtype)
return Apply(self, [a], [w])
else:
a = as_tensor_variable(a)
b = as_tensor_variable(b)
assert a.ndim == 2 assert a.ndim == 2
assert b.ndim == 2 assert b.ndim == 2
...@@ -318,8 +328,11 @@ class Eigvalsh(Op): ...@@ -318,8 +328,11 @@ class Eigvalsh(Op):
w = theano.tensor.vector(dtype=out_dtype) w = theano.tensor.vector(dtype=out_dtype)
return Apply(self, [a, b], [w]) return Apply(self, [a, b], [w])
def perform(self, node, (a, b), (w,)): def perform(self, node, inputs, (w,)):
w[0] = scipy.linalg.eigvalsh(a=a, b=b, lower=self.lower) if len(inputs) == 2:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=inputs[1], lower=self.lower)
else:
w[0] = scipy.linalg.eigvalsh(a=inputs[0], b=None, lower=self.lower)
def grad(self, inputs, g_outputs): def grad(self, inputs, g_outputs):
a, b = inputs a, b = inputs
...@@ -366,7 +379,9 @@ class EigvalshGrad(Op): ...@@ -366,7 +379,9 @@ class EigvalshGrad(Op):
def make_node(self, a, b, gw): def make_node(self, a, b, gw):
assert imported_scipy, ( assert imported_scipy, (
"Scipy not available. Scipy is needed for the GEigvalsh op") "Scipy not available. Scipy is needed for the GEigvalsh op")
a, b, gw = map(as_tensor_variable, (a, b, gw)) a = as_tensor_variable(a)
b = as_tensor_variable(b)
gw = as_tensor_variable(gw)
assert a.ndim == 2 assert a.ndim == 2
assert b.ndim == 2 assert b.ndim == 2
assert gw.ndim == 1 assert gw.ndim == 1
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论