提交 4a98f1e1 authored 作者: amrithasuresh's avatar amrithasuresh

Updated numpy as np

上级 dc83dda9
...@@ -3,7 +3,7 @@ import logging ...@@ -3,7 +3,7 @@ import logging
import warnings import warnings
from six.moves import xrange from six.moves import xrange
import numpy import numpy as np
try: try:
import scipy.linalg import scipy.linalg
...@@ -145,7 +145,7 @@ class CholeskyGrad(Op): ...@@ -145,7 +145,7 @@ class CholeskyGrad(Op):
dx = outputs[0] dx = outputs[0]
N = x.shape[0] N = x.shape[0]
if self.lower: if self.lower:
F = numpy.tril(dz) F = np.tril(dz)
for k in xrange(N - 1, -1, -1): for k in xrange(N - 1, -1, -1):
for j in xrange(k + 1, N): for j in xrange(k + 1, N):
for i in xrange(j, N): for i in xrange(j, N):
...@@ -156,7 +156,7 @@ class CholeskyGrad(Op): ...@@ -156,7 +156,7 @@ class CholeskyGrad(Op):
F[k, k] -= L[j, k] * F[j, k] F[k, k] -= L[j, k] * F[j, k]
F[k, k] /= (2 * L[k, k]) F[k, k] /= (2 * L[k, k])
else: else:
F = numpy.triu(dz) F = np.triu(dz)
for k in xrange(N - 1, -1, -1): for k in xrange(N - 1, -1, -1):
for j in xrange(k + 1, N): for j in xrange(k + 1, N):
for i in xrange(j, N): for i in xrange(j, N):
...@@ -206,8 +206,8 @@ class Solve(Op): ...@@ -206,8 +206,8 @@ class Solve(Op):
# infer dtype by solving the most simple # infer dtype by solving the most simple
# case with (1, 1) matrices # case with (1, 1) matrices
o_dtype = scipy.linalg.solve( o_dtype = scipy.linalg.solve(
numpy.eye(1).astype(A.dtype), np.eye(1).astype(A.dtype),
numpy.eye(1).astype(b.dtype)).dtype np.eye(1).astype(b.dtype)).dtype
x = tensor.tensor( x = tensor.tensor(
broadcastable=b.broadcastable, broadcastable=b.broadcastable,
dtype=o_dtype) dtype=o_dtype)
...@@ -370,11 +370,11 @@ class EigvalshGrad(Op): ...@@ -370,11 +370,11 @@ class EigvalshGrad(Op):
assert lower in [True, False] assert lower in [True, False]
self.lower = lower self.lower = lower
if lower: if lower:
self.tri0 = numpy.tril self.tri0 = np.tril
self.tri1 = lambda a: numpy.triu(a, 1) self.tri1 = lambda a: np.triu(a, 1)
else: else:
self.tri0 = numpy.triu self.tri0 = np.triu
self.tri1 = lambda a: numpy.tril(a, -1) self.tri1 = lambda a: np.tril(a, -1)
def make_node(self, a, b, gw): def make_node(self, a, b, gw):
assert imported_scipy, ( assert imported_scipy, (
...@@ -394,14 +394,14 @@ class EigvalshGrad(Op): ...@@ -394,14 +394,14 @@ class EigvalshGrad(Op):
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
(a, b, gw) = inputs (a, b, gw) = inputs
w, v = scipy.linalg.eigh(a, b, lower=self.lower) w, v = scipy.linalg.eigh(a, b, lower=self.lower)
gA = v.dot(numpy.diag(gw).dot(v.T)) gA = v.dot(np.diag(gw).dot(v.T))
gB = - v.dot(numpy.diag(gw * w).dot(v.T)) gB = - v.dot(np.diag(gw * w).dot(v.T))
# See EighGrad comments for an explanation of these lines # See EighGrad comments for an explanation of these lines
out1 = self.tri0(gA) + self.tri1(gA).T out1 = self.tri0(gA) + self.tri1(gA).T
out2 = self.tri0(gB) + self.tri1(gB).T out2 = self.tri0(gB) + self.tri1(gB).T
outputs[0][0] = numpy.asarray(out1, dtype=node.outputs[0].dtype) outputs[0][0] = np.asarray(out1, dtype=node.outputs[0].dtype)
outputs[1][0] = numpy.asarray(out2, dtype=node.outputs[1].dtype) outputs[1][0] = np.asarray(out2, dtype=node.outputs[1].dtype)
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
return [shapes[0], shapes[1]] return [shapes[0], shapes[1]]
...@@ -510,13 +510,13 @@ class ExpmGrad(Op): ...@@ -510,13 +510,13 @@ class ExpmGrad(Op):
w, V = scipy.linalg.eig(A, right=True) w, V = scipy.linalg.eig(A, right=True)
U = scipy.linalg.inv(V).T U = scipy.linalg.inv(V).T
exp_w = numpy.exp(w) exp_w = np.exp(w)
X = numpy.subtract.outer(exp_w, exp_w) / numpy.subtract.outer(w, w) X = np.subtract.outer(exp_w, exp_w) / np.subtract.outer(w, w)
numpy.fill_diagonal(X, exp_w) np.fill_diagonal(X, exp_w)
Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore", numpy.ComplexWarning) warnings.simplefilter("ignore", np.ComplexWarning)
out[0] = Y.astype(A.dtype) out[0] = Y.astype(A.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论