提交 4bb986f9 authored 作者: wonghang's avatar wonghang

remove nan mode, use dimshuffle for outer

上级 9425ec79
...@@ -419,8 +419,14 @@ class GpuCublasTriangularSolve(Op): ...@@ -419,8 +419,14 @@ class GpuCublasTriangularSolve(Op):
trans_solve_op = GpuCublasTriangularSolve(not self.lower) trans_solve_op = GpuCublasTriangularSolve(not self.lower)
b_bar = trans_solve_op(A.T, c_bar) b_bar = trans_solve_op(A.T, c_bar)
A_bar = -tensor.outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
# FIXME: tensor.outer does not appear to use GPU
def gpu_outer(x, y):
    """Compute the outer product of two vectors as a matrix product.

    ``x`` gets a trailing broadcastable axis via ``dimshuffle(0, 'x')`` and
    ``y`` a leading one via ``dimshuffle('x', 0)``; the two results are then
    multiplied with ``tensor.dot``. This stands in for ``tensor.outer``,
    which (per the FIXME note) does not appear to use the GPU.
    """
    column = x.dimshuffle(0, 'x')
    row = y.dimshuffle('x', 0)
    return tensor.dot(column, row)
A_bar = -gpu_outer(b_bar, c) if c.ndim == 1 else -b_bar.dot(c.T)
# FIXME: tensor.tril / tensor.triu have no GPU implementation
if self.lower: if self.lower:
A_bar = tensor.tril(A_bar) A_bar = tensor.tril(A_bar)
else: else:
...@@ -567,15 +573,20 @@ class GpuCholesky(Op): ...@@ -567,15 +573,20 @@ class GpuCholesky(Op):
dz = gradients[0] dz = gradients[0]
chol_x = outputs[0] chol_x = outputs[0]
ok = ~tensor.any(tensor.isnan(chol_x)) # this is for nan mode
chol_x = tensor.switch(ok, chol_x, 1) #
dz = tensor.switch(ok, dz, 1) # ok = ~tensor.any(tensor.isnan(chol_x))
# chol_x = tensor.switch(ok, chol_x, 1)
# dz = tensor.switch(ok, dz, 1)
# deal with upper triangular by converting to lower triangular # deal with upper triangular by converting to lower triangular
if not self.lower: if not self.lower:
chol_x = chol_x.T chol_x = chol_x.T
dz = dz.T dz = dz.T
# FIXME: tensor.tril / tensor.triu / tensor.diagonal / tensor.diag
# have no GPU implementation
def tril_and_halve_diagonal(mtx): def tril_and_halve_diagonal(mtx):
"""Extracts lower triangle of square matrix and halves diagonal.""" """Extracts lower triangle of square matrix and halves diagonal."""
return tensor.tril(mtx) - tensor.diag(tensor.diagonal(mtx) / 2.) return tensor.tril(mtx) - tensor.diag(tensor.diagonal(mtx) / 2.)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论