Unverified 提交 90ef59eb authored 作者: Thomas Wiecki's avatar Thomas Wiecki 提交者: GitHub

Add jaxification for linear algebra operations (#59)

上级 4c72bf9e
差异被折叠。
......@@ -74,15 +74,12 @@ class JAXLinker(PerformLinker):
thunk_outputs = [storage_map[n] for n in node.outputs]
# JIT-compile the functions
if len(node.outputs) > 1:
assert len(jax_funcs) == len(node.ouptputs)
jax_impl_jits = [
jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
]
else:
assert not isinstance(jax_funcs, Sequence)
jax_impl_jits = [jax.jit(jax_funcs, static_argnums)]
if not isinstance(jax_funcs, Sequence):
jax_funcs = [jax_funcs]
jax_impl_jits = [
jax.jit(jax_func, static_argnums) for jax_func in jax_funcs
]
def thunk(
node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
......@@ -92,6 +89,14 @@ class JAXLinker(PerformLinker):
for jax_impl_jit in jax_impl_jits
]
if len(jax_impl_jits) < len(node.outputs):
# In this case, the JAX function will output a single
# output that contains the other outputs.
# This happens for multi-output `Op`s that directly
# correspond to multi-output JAX functions (e.g. `SVD` and
# `jax.numpy.linalg.svd`).
outputs = outputs[0]
for o_node, o_storage, o_val in zip(
node.outputs, thunk_outputs, outputs
):
......
......@@ -2,6 +2,7 @@ import theano
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from warnings import warn
from functools import update_wrapper, reduce
......@@ -49,12 +50,36 @@ from theano.tensor.opt import MakeVector
from theano.tensor.nnet.sigm import ScalarSoftplus
from theano.tensor.nlinalg import (
Det,
Eig,
Eigh,
MatrixInverse,
QRFull,
QRIncomplete,
SVD,
ExtractDiag,
AllocDiag,
)
from theano.tensor.slinalg import (
Cholesky,
Solve,
)
if theano.config.floatX == "float64":
jax.config.update("jax_enable_x64", True)
else:
jax.config.update("jax_enable_x64", False)
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
jax.config.disable_omnistaging()
jax.config.update("jax_enable_x64", True)
# Older versions < 0.2.0 do not have this flag so we don't need to set it.
try:
jax.config.disable_omnistaging()
except AttributeError:
pass
subtensor_ops = (Subtensor, AdvancedSubtensor1, BaseAdvancedSubtensor)
incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1, BaseAdvancedIncSubtensor)
......@@ -629,3 +654,112 @@ def jax_funcify_Join(op):
return jnp.concatenate(tensors, axis=axis)
return join
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
offset = op.offset
axis1 = op.axis1
axis2 = op.axis2
def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
return extract_diag
@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op):
lower = op.lower
def cholesky(a, lower=lower):
return jsp.linalg.cholesky(a, lower=lower).astype(a.dtype)
return cholesky
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
def alloc_diag(x):
return jnp.diag(x)
return alloc_diag
@jax_funcify.register(Solve)
def jax_funcify_Solve(op):
if op.A_structure == "lower_triangular":
lower = True
else:
lower = False
def solve(a, b, lower=lower):
return jsp.linalg.solve(a, b, lower=lower)
return solve
@jax_funcify.register(Det)
def jax_funcify_Det(op):
def det(x):
return jnp.linalg.det(x)
return det
@jax_funcify.register(Eig)
def jax_funcify_Eig(op):
def eig(x):
return jnp.linalg.eig(x)
return eig
@jax_funcify.register(Eigh)
def jax_funcify_Eigh(op):
uplo = op.UPLO
def eigh(x, uplo=uplo):
return jnp.linalg.eigh(x, UPLO=uplo)
return eigh
@jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op):
def matrix_inverse(x):
return jnp.linalg.inv(x)
return matrix_inverse
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op):
mode = op.mode
def qr_full(x, mode=mode):
return jnp.linalg.qr(x, mode=mode)
return qr_full
@jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op):
mode = op.mode
def qr_incomplete(x, mode=mode):
return jnp.linalg.qr(x, mode=mode)
return qr_incomplete
@jax_funcify.register(SVD)
def jax_funcify_SVD(op):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
def svd(x, full_matrices=full_matrices, compute_uv=compute_uv):
return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
return svd
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论