提交 e6914913 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use jax.numpy.vectorize for Elemwise Composite Ops

上级 4fa10665
......@@ -131,7 +131,7 @@ def jax_funcify(op, **kwargs):
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op):
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
......@@ -139,7 +139,7 @@ def jax_funcify_MakeSlice(op):
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op):
def jax_funcify_ScalarOp(op, **kwargs):
func_name = op.nfunc_spec[0]
if "." in func_name:
......@@ -168,7 +168,7 @@ def jax_funcify_ScalarOp(op):
@jax_funcify.register(Clip)
def jax_funcify_Clip(op):
def jax_funcify_Clip(op, **kwargs):
def clip(x, min, max):
return jnp.where(x < min, min, jnp.where(x > max, max, x))
......@@ -176,7 +176,7 @@ def jax_funcify_Clip(op):
@jax_funcify.register(Identity)
def jax_funcify_Identity(op):
def jax_funcify_Identity(op, **kwargs):
def identity(x):
return x
......@@ -184,7 +184,7 @@ def jax_funcify_Identity(op):
@jax_funcify.register(Softmax)
def jax_funcify_Softmax(op):
def jax_funcify_Softmax(op, **kwargs):
def softmax(x):
return jax.nn.softmax(x)
......@@ -192,7 +192,7 @@ def jax_funcify_Softmax(op):
@jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op):
def jax_funcify_LogSoftmax(op, **kwargs):
def log_softmax(x):
return jax.nn.log_softmax(x)
......@@ -200,7 +200,7 @@ def jax_funcify_LogSoftmax(op):
@jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op):
def jax_funcify_ScalarSoftplus(op, **kwargs):
def scalarsoftplus(x):
return jnp.where(x < -30.0, 0.0, jnp.where(x > 30.0, x, jnp.log1p(jnp.exp(x))))
......@@ -208,7 +208,7 @@ def jax_funcify_ScalarSoftplus(op):
@jax_funcify.register(Second)
def jax_funcify_Second(op):
def jax_funcify_Second(op, **kwargs):
def second(x, y):
return jnp.broadcast_to(y, x.shape)
......@@ -216,7 +216,7 @@ def jax_funcify_Second(op):
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op):
def jax_funcify_AllocDiag(op, **kwargs):
offset = op.offset
def allocdiag(v, offset=offset):
......@@ -226,7 +226,7 @@ def jax_funcify_AllocDiag(op):
@jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op):
def jax_funcify_AllocEmpty(op, **kwargs):
def allocempty(*shape):
return jnp.empty(shape, dtype=op.dtype)
......@@ -234,7 +234,7 @@ def jax_funcify_AllocEmpty(op):
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op):
def jax_funcify_Alloc(op, **kwargs):
def alloc(x, *shape):
res = jnp.broadcast_to(x, shape)
return res
......@@ -243,7 +243,7 @@ def jax_funcify_Alloc(op):
@jax_funcify.register(Dot)
def jax_funcify_Dot(op):
def jax_funcify_Dot(op, **kwargs):
def dot(x, y):
return jnp.dot(x, y)
......@@ -251,7 +251,7 @@ def jax_funcify_Dot(op):
@jax_funcify.register(ARange)
def jax_funcify_ARange(op):
def jax_funcify_ARange(op, **kwargs):
# XXX: This currently requires concrete arguments.
def arange(start, stop, step):
return jnp.arange(start, stop, step, dtype=op.dtype)
......@@ -274,7 +274,7 @@ def jnp_safe_copy(x):
@jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op):
def jax_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return jnp_safe_copy(x)
......@@ -282,7 +282,7 @@ def jax_funcify_DeepCopyOp(op):
@jax_funcify.register(Shape)
def jax_funcify_Shape(op):
def jax_funcify_Shape(op, **kwargs):
def shape(x):
return jnp.shape(x)
......@@ -290,7 +290,7 @@ def jax_funcify_Shape(op):
@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op):
def jax_funcify_Shape_i(op, **kwargs):
i = op.i
def shape_i(x):
......@@ -300,7 +300,7 @@ def jax_funcify_Shape_i(op):
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op):
def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, shape):
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
......@@ -315,7 +315,7 @@ def jax_funcify_SpecifyShape(op):
@jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op):
def jax_funcify_Rebroadcast(op, **kwargs):
op_axis = op.axis
def rebroadcast(x):
......@@ -331,7 +331,7 @@ def jax_funcify_Rebroadcast(op):
@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op):
def jax_funcify_ViewOp(op, **kwargs):
def viewop(x):
return x
......@@ -339,7 +339,7 @@ def jax_funcify_ViewOp(op):
@jax_funcify.register(Cast)
def jax_funcify_Cast(op):
def jax_funcify_Cast(op, **kwargs):
def cast(x):
return jnp.array(x).astype(op.o_type.dtype)
......@@ -347,7 +347,7 @@ def jax_funcify_Cast(op):
@jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op):
def jax_funcify_TensorFromScalar(op, **kwargs):
def tensor_from_scalar(x):
return jnp.array(x)
......@@ -355,7 +355,7 @@ def jax_funcify_TensorFromScalar(op):
@jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op):
def jax_funcify_ScalarFromTensor(op, **kwargs):
def scalar_from_tensor(x):
return jnp.array(x).flatten()[0]
......@@ -363,30 +363,25 @@ def jax_funcify_ScalarFromTensor(op):
@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op):
def jax_funcify_Elemwise(op, **kwargs):
scalar_op = op.scalar_op
return jax_funcify(scalar_op)
return jax_funcify(scalar_op, **kwargs)
@jax_funcify.register(Composite)
def jax_funcify_Composite(op):
# This approach basically gets rid of the fused `Elemwise` by turning each
# `Op` in the `Composite` back into individually broadcasted NumPy-like
# operations.
# TODO: A better approach would involve something like `jax.vmap` or some
# other operation that can perform the broadcasting that `Elemwise` does.
def jax_funcify_Composite(op, vectorize=True, **kwargs):
jax_impl = jax_funcify(op.fgraph)
def composite(*args):
return jax_impl(*args)[0]
return composite
return jnp.vectorize(composite)
@jax_funcify.register(Scan)
def jax_funcify_Scan(op):
def jax_funcify_Scan(op, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_aet_inner_func = jax_funcify(inner_fg)
jax_aet_inner_func = jax_funcify(inner_fg, **kwargs)
def scan(*outer_inputs):
scan_args = ScanArgs(
......@@ -536,7 +531,7 @@ def jax_funcify_Scan(op):
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op):
def jax_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs):
......@@ -549,7 +544,7 @@ def jax_funcify_IfElse(op):
@jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op):
def jax_funcify_Subtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None)
......@@ -568,7 +563,7 @@ def jax_funcify_Subtensor(op):
_ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op):
def jax_funcify_IncSubtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None)
......@@ -591,7 +586,7 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op):
def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update
......@@ -606,7 +601,12 @@ def jax_funcify_AdvancedIncSubtensor(op):
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(
fgraph, order=None, input_storage=None, output_storage=None, storage_map=None
fgraph,
order=None,
input_storage=None,
output_storage=None,
storage_map=None,
**kwargs,
):
if order is None:
......@@ -642,7 +642,7 @@ def jax_funcify_FunctionGraph(
body_assigns = []
for node in order:
jax_func = jax_funcify(node.op)
jax_func = jax_funcify(node.op, node=node, **kwargs)
# Create a local alias with a unique name
local_jax_func_name = unique_name(jax_func)
......@@ -696,7 +696,7 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
@jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op):
def jax_funcify_CAReduce(op, **kwargs):
axis = op.axis
op_nfunc_spec = getattr(op, "nfunc_spec", None)
scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
......@@ -739,7 +739,7 @@ def jax_funcify_CAReduce(op):
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op):
def jax_funcify_MakeVector(op, **kwargs):
def makevector(*x):
return jnp.array(x, dtype=op.dtype)
......@@ -747,7 +747,7 @@ def jax_funcify_MakeVector(op):
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op):
def jax_funcify_Reshape(op, **kwargs):
def reshape(x, shape):
return jnp.reshape(x, shape)
......@@ -755,7 +755,7 @@ def jax_funcify_Reshape(op):
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op):
def jax_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x):
res = jnp.transpose(x, op.shuffle + op.drop)
......@@ -776,7 +776,7 @@ def jax_funcify_DimShuffle(op):
@jax_funcify.register(Join)
def jax_funcify_Join(op):
def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors):
# tensors could also be tuples, and in this case they don't have a ndim
tensors = [jnp.asarray(tensor) for tensor in tensors]
......@@ -802,7 +802,7 @@ def jax_funcify_Join(op):
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op):
def jax_funcify_MaxAndArgmax(op, **kwargs):
axis = op.axis
def maxandargmax(x, axis=axis):
......@@ -840,7 +840,7 @@ def jax_funcify_MaxAndArgmax(op):
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op):
def jax_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
axis1 = op.axis1
axis2 = op.axis2
......@@ -852,7 +852,7 @@ def jax_funcify_ExtractDiag(op):
@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op):
def jax_funcify_Cholesky(op, **kwargs):
lower = op.lower
def cholesky(a, lower=lower):
......@@ -862,7 +862,7 @@ def jax_funcify_Cholesky(op):
@jax_funcify.register(Solve)
def jax_funcify_Solve(op):
def jax_funcify_Solve(op, **kwargs):
if op.A_structure == "lower_triangular":
lower = True
......@@ -876,7 +876,7 @@ def jax_funcify_Solve(op):
@jax_funcify.register(Det)
def jax_funcify_Det(op):
def jax_funcify_Det(op, **kwargs):
def det(x):
return jnp.linalg.det(x)
......@@ -884,7 +884,7 @@ def jax_funcify_Det(op):
@jax_funcify.register(Eig)
def jax_funcify_Eig(op):
def jax_funcify_Eig(op, **kwargs):
def eig(x):
return jnp.linalg.eig(x)
......@@ -892,7 +892,7 @@ def jax_funcify_Eig(op):
@jax_funcify.register(Eigh)
def jax_funcify_Eigh(op):
def jax_funcify_Eigh(op, **kwargs):
uplo = op.UPLO
def eigh(x, uplo=uplo):
......@@ -902,7 +902,7 @@ def jax_funcify_Eigh(op):
@jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op):
def jax_funcify_MatrixInverse(op, **kwargs):
def matrix_inverse(x):
return jnp.linalg.inv(x)
......@@ -910,7 +910,7 @@ def jax_funcify_MatrixInverse(op):
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op):
def jax_funcify_QRFull(op, **kwargs):
mode = op.mode
def qr_full(x, mode=mode):
......@@ -920,7 +920,7 @@ def jax_funcify_QRFull(op):
@jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op):
def jax_funcify_QRIncomplete(op, **kwargs):
mode = op.mode
def qr_incomplete(x, mode=mode):
......@@ -930,7 +930,7 @@ def jax_funcify_QRIncomplete(op):
@jax_funcify.register(SVD)
def jax_funcify_SVD(op):
def jax_funcify_SVD(op, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
......@@ -941,7 +941,7 @@ def jax_funcify_SVD(op):
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op):
def jax_funcify_CumOp(op, **kwargs):
axis = op.axis
mode = op.mode
......@@ -955,7 +955,7 @@ def jax_funcify_CumOp(op):
@jax_funcify.register(DiffOp)
def jax_funcify_DiffOp(op):
def jax_funcify_DiffOp(op, **kwargs):
n = op.n
axis = op.axis
......@@ -966,7 +966,7 @@ def jax_funcify_DiffOp(op):
@jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op):
def jax_funcify_RepeatOp(op, **kwargs):
axis = op.axis
def repeatop(x, repeats, axis=axis):
......@@ -976,7 +976,7 @@ def jax_funcify_RepeatOp(op):
@jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op):
def jax_funcify_Bartlett(op, **kwargs):
def bartlett(x):
return jnp.bartlett(x)
......@@ -984,7 +984,7 @@ def jax_funcify_Bartlett(op):
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op):
def jax_funcify_FillDiagonal(op, **kwargs):
# def filldiagonal(a, val):
# if a.ndim == 2:
......@@ -1002,7 +1002,7 @@ def jax_funcify_FillDiagonal(op):
@jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op):
def jax_funcify_FillDiagonalOffset(op, **kwargs):
# def filldiagonaloffset(a, val, offset):
# height, width = a.shape
......@@ -1026,7 +1026,7 @@ def jax_funcify_FillDiagonalOffset(op):
@jax_funcify.register(Unique)
def jax_funcify_Unique(op):
def jax_funcify_Unique(op, **kwargs):
axis = op.axis
if axis is not None:
......@@ -1055,7 +1055,7 @@ def jax_funcify_Unique(op):
@jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op):
def jax_funcify_UnravelIndex(op, **kwargs):
order = op.order
warn("JAX ignores the `order` parameter in `unravel_index`.")
......@@ -1067,7 +1067,7 @@ def jax_funcify_UnravelIndex(op):
@jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op):
def jax_funcify_RavelMultiIndex(op, **kwargs):
mode = op.mode
order = op.order
......@@ -1079,7 +1079,7 @@ def jax_funcify_RavelMultiIndex(op):
@jax_funcify.register(Eye)
def jax_funcify_Eye(op):
def jax_funcify_Eye(op, **kwargs):
dtype = op.dtype
def eye(N, M, k):
......@@ -1089,7 +1089,7 @@ def jax_funcify_Eye(op):
@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op):
def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension")
......@@ -1101,7 +1101,7 @@ def jax_funcify_BatchedDot(op):
@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op):
def jax_funcify_RandomVariable(op, **kwargs):
name = op.name
if not hasattr(jax.random, name):
......
......@@ -298,22 +298,32 @@ def test_jax_basic():
)
def test_jax_Composite():
@pytest.mark.parametrize(
"x, y, x_val, y_val",
[
(scalar("x"), scalar("y"), np.array(10), np.array(20)),
(scalar("x"), vector("y"), np.array(10), np.arange(10, 20)),
(
matrix("x"),
vector("y"),
np.arange(10 * 20).reshape((20, 10)),
np.arange(10, 20),
),
],
)
def test_jax_Composite(x, y, x_val, y_val):
x_s = aes.float64("x")
y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2]))
x = vector("x")
y = vector("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [
np.arange(10).astype(config.floatX),
np.arange(10, 20).astype(config.floatX),
x_val.astype(config.floatX),
y_val.astype(config.floatX),
]
_ = compare_jax_and_py(out_fg, test_input_vals)
......@@ -354,7 +364,7 @@ def test_jax_FunctionGraph_once():
outputs[i][0] = inp[0]
@jax_funcify.register(TestOp)
def jax_funcify_TestOp(op):
def jax_funcify_TestOp(op, **kwargs):
def func(*args, op=op):
op.called += 1
return list(args)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论