提交 e09e2b1e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Return the correct output for multi-output inputs in a JAXified graph

上级 99f08ee3
...@@ -281,6 +281,14 @@ def test_jax_basic_multiout(): ...@@ -281,6 +281,14 @@ def test_jax_basic_multiout():
out_fg = theano.gof.FunctionGraph([x], outs) out_fg = theano.gof.FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn) compare_jax_and_py(out_fg, [X.astype(tt.config.floatX)], assert_fn=assert_fn)
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = tt.dvector()
mx, amx = theano.tensor.MaxAndArgmax([0])(x)
out = mx * amx
out_fg = theano.gof.FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.skip(reason="Not fully implemented, yet.") @pytest.mark.skip(reason="Not fully implemented, yet.")
def test_jax_scan(): def test_jax_scan():
......
...@@ -93,10 +93,13 @@ incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1) ...@@ -93,10 +93,13 @@ incsubtensor_ops = (IncSubtensor, AdvancedIncSubtensor1)
def compose_jax_funcs(out_node, fgraph_inputs, memo=None): def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
"""Compose JAX implementations of node operations. """Compose JAX implementations of node operations.
This function walks the graph given by the `Apply` node, `out_node`, and
creates JAX JIT-able functions for its input and output variables.
Parameters Parameters
---------- ----------
out_node: Node out_node: theano.gof.graph.Apply
The output node. The node for which we want to construct a JAX JIT-able function.
fgraph_inputs: List[Variable] fgraph_inputs: List[Variable]
The inputs--in a `FunctionGraph` sense--to `out_node`. The inputs--in a `FunctionGraph` sense--to `out_node`.
memo: Mapping (Optional) memo: Mapping (Optional)
...@@ -116,9 +119,13 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None): ...@@ -116,9 +119,13 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
jax_return_func = jax_funcify(out_node.op) jax_return_func = jax_funcify(out_node.op)
# We create a list of JAX-able functions that produce the values of each
# input variable for `out_node`.
input_funcs = [] input_funcs = []
for i in out_node.inputs: for i in out_node.inputs:
if i in fgraph_inputs: if i in fgraph_inputs:
# This input is a top-level input (i.e. an input to the
# `FunctionGraph` in which this `out_node` resides)
idx = fgraph_inputs.index(i) idx = fgraph_inputs.index(i)
i_dtype = getattr(i, "dtype", None) i_dtype = getattr(i, "dtype", None)
...@@ -129,6 +136,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None): ...@@ -129,6 +136,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
input_f = jax_inputs_func input_f = jax_inputs_func
elif i.owner is None: elif i.owner is None:
# This input is something like a `theano.gof.graph.Constant`
i_dtype = getattr(i, "dtype", None) i_dtype = getattr(i, "dtype", None)
i_data = i.data i_data = i.data
...@@ -141,8 +149,24 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None): ...@@ -141,8 +149,24 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
input_f = jax_data_func input_f = jax_data_func
else: else:
# This input is the output of another node, so we need to
# generate a JAX-able function for its subgraph
input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo) input_f = compose_jax_funcs(i.owner, fgraph_inputs, memo)
if i.owner.nout > 1:
# This input is one of multiple outputs from the `i.owner`
# node, and we need to determine exactly which one it is and
# create a JAX-able function that returns only it.
out_idx = i.owner.outputs.index(i)
(out_fn,) = input_f
def jax_multiout_func(*inputs, out_idx=out_idx, out_fn=out_fn):
return out_fn(*inputs)[out_idx]
input_f = jax_multiout_func
assert callable(input_f)
input_funcs.append(input_f) input_funcs.append(input_f)
if not isinstance(jax_return_func, Sequence): if not isinstance(jax_return_func, Sequence):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论