提交 e6914913 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use jax.numpy.vectorize for Elemwise Composite Ops

上级 4fa10665
......@@ -298,22 +298,32 @@ def test_jax_basic():
)
def test_jax_Composite():
@pytest.mark.parametrize(
"x, y, x_val, y_val",
[
(scalar("x"), scalar("y"), np.array(10), np.array(20)),
(scalar("x"), vector("y"), np.array(10), np.arange(10, 20)),
(
matrix("x"),
vector("y"),
np.arange(10 * 20).reshape((20, 10)),
np.arange(10, 20),
),
],
)
def test_jax_Composite(x, y, x_val, y_val):
x_s = aes.float64("x")
y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2]))
x = vector("x")
y = vector("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [
np.arange(10).astype(config.floatX),
np.arange(10, 20).astype(config.floatX),
x_val.astype(config.floatX),
y_val.astype(config.floatX),
]
_ = compare_jax_and_py(out_fg, test_input_vals)
......@@ -354,7 +364,7 @@ def test_jax_FunctionGraph_once():
outputs[i][0] = inp[0]
@jax_funcify.register(TestOp)
def jax_funcify_TestOp(op):
def jax_funcify_TestOp(op, **kwargs):
def func(*args, op=op):
op.called += 1
return list(args)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论