提交 e6914913 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use jax.numpy.vectorize for Elemwise Composite Ops

上级 4fa10665
...@@ -131,7 +131,7 @@ def jax_funcify(op, **kwargs): ...@@ -131,7 +131,7 @@ def jax_funcify(op, **kwargs):
@jax_funcify.register(MakeSlice) @jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op): def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x): def makeslice(*x):
return slice(*x) return slice(*x)
...@@ -139,7 +139,7 @@ def jax_funcify_MakeSlice(op): ...@@ -139,7 +139,7 @@ def jax_funcify_MakeSlice(op):
@jax_funcify.register(ScalarOp) @jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op): def jax_funcify_ScalarOp(op, **kwargs):
func_name = op.nfunc_spec[0] func_name = op.nfunc_spec[0]
if "." in func_name: if "." in func_name:
...@@ -168,7 +168,7 @@ def jax_funcify_ScalarOp(op): ...@@ -168,7 +168,7 @@ def jax_funcify_ScalarOp(op):
@jax_funcify.register(Clip) @jax_funcify.register(Clip)
def jax_funcify_Clip(op): def jax_funcify_Clip(op, **kwargs):
def clip(x, min, max): def clip(x, min, max):
return jnp.where(x < min, min, jnp.where(x > max, max, x)) return jnp.where(x < min, min, jnp.where(x > max, max, x))
...@@ -176,7 +176,7 @@ def jax_funcify_Clip(op): ...@@ -176,7 +176,7 @@ def jax_funcify_Clip(op):
@jax_funcify.register(Identity) @jax_funcify.register(Identity)
def jax_funcify_Identity(op): def jax_funcify_Identity(op, **kwargs):
def identity(x): def identity(x):
return x return x
...@@ -184,7 +184,7 @@ def jax_funcify_Identity(op): ...@@ -184,7 +184,7 @@ def jax_funcify_Identity(op):
@jax_funcify.register(Softmax) @jax_funcify.register(Softmax)
def jax_funcify_Softmax(op): def jax_funcify_Softmax(op, **kwargs):
def softmax(x): def softmax(x):
return jax.nn.softmax(x) return jax.nn.softmax(x)
...@@ -192,7 +192,7 @@ def jax_funcify_Softmax(op): ...@@ -192,7 +192,7 @@ def jax_funcify_Softmax(op):
@jax_funcify.register(LogSoftmax) @jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op): def jax_funcify_LogSoftmax(op, **kwargs):
def log_softmax(x): def log_softmax(x):
return jax.nn.log_softmax(x) return jax.nn.log_softmax(x)
...@@ -200,7 +200,7 @@ def jax_funcify_LogSoftmax(op): ...@@ -200,7 +200,7 @@ def jax_funcify_LogSoftmax(op):
@jax_funcify.register(ScalarSoftplus) @jax_funcify.register(ScalarSoftplus)
def jax_funcify_ScalarSoftplus(op): def jax_funcify_ScalarSoftplus(op, **kwargs):
def scalarsoftplus(x): def scalarsoftplus(x):
return jnp.where(x < -30.0, 0.0, jnp.where(x > 30.0, x, jnp.log1p(jnp.exp(x)))) return jnp.where(x < -30.0, 0.0, jnp.where(x > 30.0, x, jnp.log1p(jnp.exp(x))))
...@@ -208,7 +208,7 @@ def jax_funcify_ScalarSoftplus(op): ...@@ -208,7 +208,7 @@ def jax_funcify_ScalarSoftplus(op):
@jax_funcify.register(Second) @jax_funcify.register(Second)
def jax_funcify_Second(op): def jax_funcify_Second(op, **kwargs):
def second(x, y): def second(x, y):
return jnp.broadcast_to(y, x.shape) return jnp.broadcast_to(y, x.shape)
...@@ -216,7 +216,7 @@ def jax_funcify_Second(op): ...@@ -216,7 +216,7 @@ def jax_funcify_Second(op):
@jax_funcify.register(AllocDiag) @jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op): def jax_funcify_AllocDiag(op, **kwargs):
offset = op.offset offset = op.offset
def allocdiag(v, offset=offset): def allocdiag(v, offset=offset):
...@@ -226,7 +226,7 @@ def jax_funcify_AllocDiag(op): ...@@ -226,7 +226,7 @@ def jax_funcify_AllocDiag(op):
@jax_funcify.register(AllocEmpty) @jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op): def jax_funcify_AllocEmpty(op, **kwargs):
def allocempty(*shape): def allocempty(*shape):
return jnp.empty(shape, dtype=op.dtype) return jnp.empty(shape, dtype=op.dtype)
...@@ -234,7 +234,7 @@ def jax_funcify_AllocEmpty(op): ...@@ -234,7 +234,7 @@ def jax_funcify_AllocEmpty(op):
@jax_funcify.register(Alloc) @jax_funcify.register(Alloc)
def jax_funcify_Alloc(op): def jax_funcify_Alloc(op, **kwargs):
def alloc(x, *shape): def alloc(x, *shape):
res = jnp.broadcast_to(x, shape) res = jnp.broadcast_to(x, shape)
return res return res
...@@ -243,7 +243,7 @@ def jax_funcify_Alloc(op): ...@@ -243,7 +243,7 @@ def jax_funcify_Alloc(op):
@jax_funcify.register(Dot) @jax_funcify.register(Dot)
def jax_funcify_Dot(op): def jax_funcify_Dot(op, **kwargs):
def dot(x, y): def dot(x, y):
return jnp.dot(x, y) return jnp.dot(x, y)
...@@ -251,7 +251,7 @@ def jax_funcify_Dot(op): ...@@ -251,7 +251,7 @@ def jax_funcify_Dot(op):
@jax_funcify.register(ARange) @jax_funcify.register(ARange)
def jax_funcify_ARange(op): def jax_funcify_ARange(op, **kwargs):
# XXX: This currently requires concrete arguments. # XXX: This currently requires concrete arguments.
def arange(start, stop, step): def arange(start, stop, step):
return jnp.arange(start, stop, step, dtype=op.dtype) return jnp.arange(start, stop, step, dtype=op.dtype)
...@@ -274,7 +274,7 @@ def jnp_safe_copy(x): ...@@ -274,7 +274,7 @@ def jnp_safe_copy(x):
@jax_funcify.register(DeepCopyOp) @jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op): def jax_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x): def deepcopyop(x):
return jnp_safe_copy(x) return jnp_safe_copy(x)
...@@ -282,7 +282,7 @@ def jax_funcify_DeepCopyOp(op): ...@@ -282,7 +282,7 @@ def jax_funcify_DeepCopyOp(op):
@jax_funcify.register(Shape) @jax_funcify.register(Shape)
def jax_funcify_Shape(op): def jax_funcify_Shape(op, **kwargs):
def shape(x): def shape(x):
return jnp.shape(x) return jnp.shape(x)
...@@ -290,7 +290,7 @@ def jax_funcify_Shape(op): ...@@ -290,7 +290,7 @@ def jax_funcify_Shape(op):
@jax_funcify.register(Shape_i) @jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op): def jax_funcify_Shape_i(op, **kwargs):
i = op.i i = op.i
def shape_i(x): def shape_i(x):
...@@ -300,7 +300,7 @@ def jax_funcify_Shape_i(op): ...@@ -300,7 +300,7 @@ def jax_funcify_Shape_i(op):
@jax_funcify.register(SpecifyShape) @jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op): def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, shape): def specifyshape(x, shape):
assert x.ndim == len(shape) assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), ( assert jnp.all(x.shape == tuple(shape)), (
...@@ -315,7 +315,7 @@ def jax_funcify_SpecifyShape(op): ...@@ -315,7 +315,7 @@ def jax_funcify_SpecifyShape(op):
@jax_funcify.register(Rebroadcast) @jax_funcify.register(Rebroadcast)
def jax_funcify_Rebroadcast(op): def jax_funcify_Rebroadcast(op, **kwargs):
op_axis = op.axis op_axis = op.axis
def rebroadcast(x): def rebroadcast(x):
...@@ -331,7 +331,7 @@ def jax_funcify_Rebroadcast(op): ...@@ -331,7 +331,7 @@ def jax_funcify_Rebroadcast(op):
@jax_funcify.register(ViewOp) @jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op): def jax_funcify_ViewOp(op, **kwargs):
def viewop(x): def viewop(x):
return x return x
...@@ -339,7 +339,7 @@ def jax_funcify_ViewOp(op): ...@@ -339,7 +339,7 @@ def jax_funcify_ViewOp(op):
@jax_funcify.register(Cast) @jax_funcify.register(Cast)
def jax_funcify_Cast(op): def jax_funcify_Cast(op, **kwargs):
def cast(x): def cast(x):
return jnp.array(x).astype(op.o_type.dtype) return jnp.array(x).astype(op.o_type.dtype)
...@@ -347,7 +347,7 @@ def jax_funcify_Cast(op): ...@@ -347,7 +347,7 @@ def jax_funcify_Cast(op):
@jax_funcify.register(TensorFromScalar) @jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op): def jax_funcify_TensorFromScalar(op, **kwargs):
def tensor_from_scalar(x): def tensor_from_scalar(x):
return jnp.array(x) return jnp.array(x)
...@@ -355,7 +355,7 @@ def jax_funcify_TensorFromScalar(op): ...@@ -355,7 +355,7 @@ def jax_funcify_TensorFromScalar(op):
@jax_funcify.register(ScalarFromTensor) @jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op): def jax_funcify_ScalarFromTensor(op, **kwargs):
def scalar_from_tensor(x): def scalar_from_tensor(x):
return jnp.array(x).flatten()[0] return jnp.array(x).flatten()[0]
...@@ -363,30 +363,25 @@ def jax_funcify_ScalarFromTensor(op): ...@@ -363,30 +363,25 @@ def jax_funcify_ScalarFromTensor(op):
@jax_funcify.register(Elemwise) @jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op): def jax_funcify_Elemwise(op, **kwargs):
scalar_op = op.scalar_op scalar_op = op.scalar_op
return jax_funcify(scalar_op) return jax_funcify(scalar_op, **kwargs)
@jax_funcify.register(Composite) @jax_funcify.register(Composite)
def jax_funcify_Composite(op): def jax_funcify_Composite(op, vectorize=True, **kwargs):
# This approach basically gets rid of the fused `Elemwise` by turning each
# `Op` in the `Composite` back into individually broadcasted NumPy-like
# operations.
# TODO: A better approach would involve something like `jax.vmap` or some
# other operation that can perform the broadcasting that `Elemwise` does.
jax_impl = jax_funcify(op.fgraph) jax_impl = jax_funcify(op.fgraph)
def composite(*args): def composite(*args):
return jax_impl(*args)[0] return jax_impl(*args)[0]
return composite return jnp.vectorize(composite)
@jax_funcify.register(Scan) @jax_funcify.register(Scan)
def jax_funcify_Scan(op): def jax_funcify_Scan(op, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs) inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_aet_inner_func = jax_funcify(inner_fg) jax_aet_inner_func = jax_funcify(inner_fg, **kwargs)
def scan(*outer_inputs): def scan(*outer_inputs):
scan_args = ScanArgs( scan_args = ScanArgs(
...@@ -536,7 +531,7 @@ def jax_funcify_Scan(op): ...@@ -536,7 +531,7 @@ def jax_funcify_Scan(op):
@jax_funcify.register(IfElse) @jax_funcify.register(IfElse)
def jax_funcify_IfElse(op): def jax_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs): def ifelse(cond, *args, n_outs=n_outs):
...@@ -549,7 +544,7 @@ def jax_funcify_IfElse(op): ...@@ -549,7 +544,7 @@ def jax_funcify_IfElse(op):
@jax_funcify.register(Subtensor) @jax_funcify.register(Subtensor)
def jax_funcify_Subtensor(op): def jax_funcify_Subtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
...@@ -568,7 +563,7 @@ def jax_funcify_Subtensor(op): ...@@ -568,7 +563,7 @@ def jax_funcify_Subtensor(op):
_ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops] _ = [jax_funcify.register(op, jax_funcify_Subtensor) for op in subtensor_ops]
def jax_funcify_IncSubtensor(op): def jax_funcify_IncSubtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
...@@ -591,7 +586,7 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o ...@@ -591,7 +586,7 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o
@jax_funcify.register(AdvancedIncSubtensor) @jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op): def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update jax_fn = jax.ops.index_update
...@@ -606,7 +601,12 @@ def jax_funcify_AdvancedIncSubtensor(op): ...@@ -606,7 +601,12 @@ def jax_funcify_AdvancedIncSubtensor(op):
@jax_funcify.register(FunctionGraph) @jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph( def jax_funcify_FunctionGraph(
fgraph, order=None, input_storage=None, output_storage=None, storage_map=None fgraph,
order=None,
input_storage=None,
output_storage=None,
storage_map=None,
**kwargs,
): ):
if order is None: if order is None:
...@@ -642,7 +642,7 @@ def jax_funcify_FunctionGraph( ...@@ -642,7 +642,7 @@ def jax_funcify_FunctionGraph(
body_assigns = [] body_assigns = []
for node in order: for node in order:
jax_func = jax_funcify(node.op) jax_func = jax_funcify(node.op, node=node, **kwargs)
# Create a local alias with a unique name # Create a local alias with a unique name
local_jax_func_name = unique_name(jax_func) local_jax_func_name = unique_name(jax_func)
...@@ -696,7 +696,7 @@ def {fgraph_name}({", ".join(fgraph_input_names)}): ...@@ -696,7 +696,7 @@ def {fgraph_name}({", ".join(fgraph_input_names)}):
@jax_funcify.register(CAReduce) @jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op): def jax_funcify_CAReduce(op, **kwargs):
axis = op.axis axis = op.axis
op_nfunc_spec = getattr(op, "nfunc_spec", None) op_nfunc_spec = getattr(op, "nfunc_spec", None)
scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None) scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
...@@ -739,7 +739,7 @@ def jax_funcify_CAReduce(op): ...@@ -739,7 +739,7 @@ def jax_funcify_CAReduce(op):
@jax_funcify.register(MakeVector) @jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op): def jax_funcify_MakeVector(op, **kwargs):
def makevector(*x): def makevector(*x):
return jnp.array(x, dtype=op.dtype) return jnp.array(x, dtype=op.dtype)
...@@ -747,7 +747,7 @@ def jax_funcify_MakeVector(op): ...@@ -747,7 +747,7 @@ def jax_funcify_MakeVector(op):
@jax_funcify.register(Reshape) @jax_funcify.register(Reshape)
def jax_funcify_Reshape(op): def jax_funcify_Reshape(op, **kwargs):
def reshape(x, shape): def reshape(x, shape):
return jnp.reshape(x, shape) return jnp.reshape(x, shape)
...@@ -755,7 +755,7 @@ def jax_funcify_Reshape(op): ...@@ -755,7 +755,7 @@ def jax_funcify_Reshape(op):
@jax_funcify.register(DimShuffle) @jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op): def jax_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x): def dimshuffle(x):
res = jnp.transpose(x, op.shuffle + op.drop) res = jnp.transpose(x, op.shuffle + op.drop)
...@@ -776,7 +776,7 @@ def jax_funcify_DimShuffle(op): ...@@ -776,7 +776,7 @@ def jax_funcify_DimShuffle(op):
@jax_funcify.register(Join) @jax_funcify.register(Join)
def jax_funcify_Join(op): def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors): def join(axis, *tensors):
# tensors could also be tuples, and in this case they don't have a ndim # tensors could also be tuples, and in this case they don't have a ndim
tensors = [jnp.asarray(tensor) for tensor in tensors] tensors = [jnp.asarray(tensor) for tensor in tensors]
...@@ -802,7 +802,7 @@ def jax_funcify_Join(op): ...@@ -802,7 +802,7 @@ def jax_funcify_Join(op):
@jax_funcify.register(MaxAndArgmax) @jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op): def jax_funcify_MaxAndArgmax(op, **kwargs):
axis = op.axis axis = op.axis
def maxandargmax(x, axis=axis): def maxandargmax(x, axis=axis):
...@@ -840,7 +840,7 @@ def jax_funcify_MaxAndArgmax(op): ...@@ -840,7 +840,7 @@ def jax_funcify_MaxAndArgmax(op):
@jax_funcify.register(ExtractDiag) @jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op): def jax_funcify_ExtractDiag(op, **kwargs):
offset = op.offset offset = op.offset
axis1 = op.axis1 axis1 = op.axis1
axis2 = op.axis2 axis2 = op.axis2
...@@ -852,7 +852,7 @@ def jax_funcify_ExtractDiag(op): ...@@ -852,7 +852,7 @@ def jax_funcify_ExtractDiag(op):
@jax_funcify.register(Cholesky) @jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op): def jax_funcify_Cholesky(op, **kwargs):
lower = op.lower lower = op.lower
def cholesky(a, lower=lower): def cholesky(a, lower=lower):
...@@ -862,7 +862,7 @@ def jax_funcify_Cholesky(op): ...@@ -862,7 +862,7 @@ def jax_funcify_Cholesky(op):
@jax_funcify.register(Solve) @jax_funcify.register(Solve)
def jax_funcify_Solve(op): def jax_funcify_Solve(op, **kwargs):
if op.A_structure == "lower_triangular": if op.A_structure == "lower_triangular":
lower = True lower = True
...@@ -876,7 +876,7 @@ def jax_funcify_Solve(op): ...@@ -876,7 +876,7 @@ def jax_funcify_Solve(op):
@jax_funcify.register(Det) @jax_funcify.register(Det)
def jax_funcify_Det(op): def jax_funcify_Det(op, **kwargs):
def det(x): def det(x):
return jnp.linalg.det(x) return jnp.linalg.det(x)
...@@ -884,7 +884,7 @@ def jax_funcify_Det(op): ...@@ -884,7 +884,7 @@ def jax_funcify_Det(op):
@jax_funcify.register(Eig) @jax_funcify.register(Eig)
def jax_funcify_Eig(op): def jax_funcify_Eig(op, **kwargs):
def eig(x): def eig(x):
return jnp.linalg.eig(x) return jnp.linalg.eig(x)
...@@ -892,7 +892,7 @@ def jax_funcify_Eig(op): ...@@ -892,7 +892,7 @@ def jax_funcify_Eig(op):
@jax_funcify.register(Eigh) @jax_funcify.register(Eigh)
def jax_funcify_Eigh(op): def jax_funcify_Eigh(op, **kwargs):
uplo = op.UPLO uplo = op.UPLO
def eigh(x, uplo=uplo): def eigh(x, uplo=uplo):
...@@ -902,7 +902,7 @@ def jax_funcify_Eigh(op): ...@@ -902,7 +902,7 @@ def jax_funcify_Eigh(op):
@jax_funcify.register(MatrixInverse) @jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op): def jax_funcify_MatrixInverse(op, **kwargs):
def matrix_inverse(x): def matrix_inverse(x):
return jnp.linalg.inv(x) return jnp.linalg.inv(x)
...@@ -910,7 +910,7 @@ def jax_funcify_MatrixInverse(op): ...@@ -910,7 +910,7 @@ def jax_funcify_MatrixInverse(op):
@jax_funcify.register(QRFull) @jax_funcify.register(QRFull)
def jax_funcify_QRFull(op): def jax_funcify_QRFull(op, **kwargs):
mode = op.mode mode = op.mode
def qr_full(x, mode=mode): def qr_full(x, mode=mode):
...@@ -920,7 +920,7 @@ def jax_funcify_QRFull(op): ...@@ -920,7 +920,7 @@ def jax_funcify_QRFull(op):
@jax_funcify.register(QRIncomplete) @jax_funcify.register(QRIncomplete)
def jax_funcify_QRIncomplete(op): def jax_funcify_QRIncomplete(op, **kwargs):
mode = op.mode mode = op.mode
def qr_incomplete(x, mode=mode): def qr_incomplete(x, mode=mode):
...@@ -930,7 +930,7 @@ def jax_funcify_QRIncomplete(op): ...@@ -930,7 +930,7 @@ def jax_funcify_QRIncomplete(op):
@jax_funcify.register(SVD) @jax_funcify.register(SVD)
def jax_funcify_SVD(op): def jax_funcify_SVD(op, **kwargs):
full_matrices = op.full_matrices full_matrices = op.full_matrices
compute_uv = op.compute_uv compute_uv = op.compute_uv
...@@ -941,7 +941,7 @@ def jax_funcify_SVD(op): ...@@ -941,7 +941,7 @@ def jax_funcify_SVD(op):
@jax_funcify.register(CumOp) @jax_funcify.register(CumOp)
def jax_funcify_CumOp(op): def jax_funcify_CumOp(op, **kwargs):
axis = op.axis axis = op.axis
mode = op.mode mode = op.mode
...@@ -955,7 +955,7 @@ def jax_funcify_CumOp(op): ...@@ -955,7 +955,7 @@ def jax_funcify_CumOp(op):
@jax_funcify.register(DiffOp) @jax_funcify.register(DiffOp)
def jax_funcify_DiffOp(op): def jax_funcify_DiffOp(op, **kwargs):
n = op.n n = op.n
axis = op.axis axis = op.axis
...@@ -966,7 +966,7 @@ def jax_funcify_DiffOp(op): ...@@ -966,7 +966,7 @@ def jax_funcify_DiffOp(op):
@jax_funcify.register(RepeatOp) @jax_funcify.register(RepeatOp)
def jax_funcify_RepeatOp(op): def jax_funcify_RepeatOp(op, **kwargs):
axis = op.axis axis = op.axis
def repeatop(x, repeats, axis=axis): def repeatop(x, repeats, axis=axis):
...@@ -976,7 +976,7 @@ def jax_funcify_RepeatOp(op): ...@@ -976,7 +976,7 @@ def jax_funcify_RepeatOp(op):
@jax_funcify.register(Bartlett) @jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op): def jax_funcify_Bartlett(op, **kwargs):
def bartlett(x): def bartlett(x):
return jnp.bartlett(x) return jnp.bartlett(x)
...@@ -984,7 +984,7 @@ def jax_funcify_Bartlett(op): ...@@ -984,7 +984,7 @@ def jax_funcify_Bartlett(op):
@jax_funcify.register(FillDiagonal) @jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op): def jax_funcify_FillDiagonal(op, **kwargs):
# def filldiagonal(a, val): # def filldiagonal(a, val):
# if a.ndim == 2: # if a.ndim == 2:
...@@ -1002,7 +1002,7 @@ def jax_funcify_FillDiagonal(op): ...@@ -1002,7 +1002,7 @@ def jax_funcify_FillDiagonal(op):
@jax_funcify.register(FillDiagonalOffset) @jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op): def jax_funcify_FillDiagonalOffset(op, **kwargs):
# def filldiagonaloffset(a, val, offset): # def filldiagonaloffset(a, val, offset):
# height, width = a.shape # height, width = a.shape
...@@ -1026,7 +1026,7 @@ def jax_funcify_FillDiagonalOffset(op): ...@@ -1026,7 +1026,7 @@ def jax_funcify_FillDiagonalOffset(op):
@jax_funcify.register(Unique) @jax_funcify.register(Unique)
def jax_funcify_Unique(op): def jax_funcify_Unique(op, **kwargs):
axis = op.axis axis = op.axis
if axis is not None: if axis is not None:
...@@ -1055,7 +1055,7 @@ def jax_funcify_Unique(op): ...@@ -1055,7 +1055,7 @@ def jax_funcify_Unique(op):
@jax_funcify.register(UnravelIndex) @jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op): def jax_funcify_UnravelIndex(op, **kwargs):
order = op.order order = op.order
warn("JAX ignores the `order` parameter in `unravel_index`.") warn("JAX ignores the `order` parameter in `unravel_index`.")
...@@ -1067,7 +1067,7 @@ def jax_funcify_UnravelIndex(op): ...@@ -1067,7 +1067,7 @@ def jax_funcify_UnravelIndex(op):
@jax_funcify.register(RavelMultiIndex) @jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op): def jax_funcify_RavelMultiIndex(op, **kwargs):
mode = op.mode mode = op.mode
order = op.order order = op.order
...@@ -1079,7 +1079,7 @@ def jax_funcify_RavelMultiIndex(op): ...@@ -1079,7 +1079,7 @@ def jax_funcify_RavelMultiIndex(op):
@jax_funcify.register(Eye) @jax_funcify.register(Eye)
def jax_funcify_Eye(op): def jax_funcify_Eye(op, **kwargs):
dtype = op.dtype dtype = op.dtype
def eye(N, M, k): def eye(N, M, k):
...@@ -1089,7 +1089,7 @@ def jax_funcify_Eye(op): ...@@ -1089,7 +1089,7 @@ def jax_funcify_Eye(op):
@jax_funcify.register(BatchedDot) @jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op): def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b): def batched_dot(a, b):
if a.shape[0] != b.shape[0]: if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension") raise TypeError("Shapes must match in the 0-th dimension")
...@@ -1101,7 +1101,7 @@ def jax_funcify_BatchedDot(op): ...@@ -1101,7 +1101,7 @@ def jax_funcify_BatchedDot(op):
@jax_funcify.register(RandomVariable) @jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op): def jax_funcify_RandomVariable(op, **kwargs):
name = op.name name = op.name
if not hasattr(jax.random, name): if not hasattr(jax.random, name):
......
...@@ -298,22 +298,32 @@ def test_jax_basic(): ...@@ -298,22 +298,32 @@ def test_jax_basic():
) )
def test_jax_Composite(): @pytest.mark.parametrize(
"x, y, x_val, y_val",
[
(scalar("x"), scalar("y"), np.array(10), np.array(20)),
(scalar("x"), vector("y"), np.array(10), np.arange(10, 20)),
(
matrix("x"),
vector("y"),
np.arange(10 * 20).reshape((20, 10)),
np.arange(10, 20),
),
],
)
def test_jax_Composite(x, y, x_val, y_val):
x_s = aes.float64("x") x_s = aes.float64("x")
y_s = aes.float64("y") y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2])) comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
x = vector("x")
y = vector("y")
out = comp_op(x, y) out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out]) out_fg = FunctionGraph([x, y], [out])
test_input_vals = [ test_input_vals = [
np.arange(10).astype(config.floatX), x_val.astype(config.floatX),
np.arange(10, 20).astype(config.floatX), y_val.astype(config.floatX),
] ]
_ = compare_jax_and_py(out_fg, test_input_vals) _ = compare_jax_and_py(out_fg, test_input_vals)
...@@ -354,7 +364,7 @@ def test_jax_FunctionGraph_once(): ...@@ -354,7 +364,7 @@ def test_jax_FunctionGraph_once():
outputs[i][0] = inp[0] outputs[i][0] = inp[0]
@jax_funcify.register(TestOp) @jax_funcify.register(TestOp)
def jax_funcify_TestOp(op): def jax_funcify_TestOp(op, **kwargs):
def func(*args, op=op): def func(*args, op=op):
op.called += 1 op.called += 1
return list(args) return list(args)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论