Unverified 提交 0c203e99 作者: Brandon T. Willard 提交者: GitHub

Merge pull request #169 from junpenglao/jax_scan

Implement a JAX conversion for the Scan Op
差异被折叠。
......@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def jax_func(*inputs):
func_args = [fn(*inputs) for fn in input_funcs]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return return_func(*func_args)
jax_funcs.append(update_wrapper(jax_func, return_func))
......@@ -420,7 +421,7 @@ def jax_funcify_Scan(op):
def scan(*outer_inputs):
scan_args = ScanArgs(
outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info
list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
)
# `outer_inputs` is a list with the following composite form:
......@@ -435,9 +436,9 @@ def jax_funcify_Scan(op):
n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs
n_non_seqs = len(scan_args.outer_in_non_seqs)
# TODO: mit_mots
mit_mot_in_slices = []
# TODO: sit_sots
mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0]
......@@ -447,7 +448,15 @@ def jax_funcify_Scan(op):
init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice)
init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs]
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared
......@@ -470,15 +479,22 @@ def jax_funcify_Scan(op):
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_scan_inputs = [
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_non_seqs,
]
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[index])
inner_scan_inputs = sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
raise NotImplementedError()
return inner_scan_inputs
def inner_scan_outs_to_jax_outs(
......@@ -486,47 +502,66 @@ def jax_funcify_Scan(op):
old_carry,
inner_scan_outs,
):
# `inner_scan_outs` is a list with the following
# composite form:
# outer_out_mit_mot
# + outer_out_mit_sot
# + outer_out_sit_sot
# + outer_out_nit_sot
# + outer_out_shared
# + cond
(
outer_out_mit_mot,
outer_out_mit_sot,
outer_out_sit_sot,
outer_out_nit_sot,
outer_out_shared,
cond,
) = inner_scan_outs
outer_out_non_seqs = old_carry[:-n_non_seqs]
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = old_carry
def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]
# This should contain all inner-output taps, non_seqs, and shared
# terms
carry = [
outer_out_mit_mot,
outer_out_mit_sot,
outer_out_sit_sot,
outer_out_shared,
outer_out_non_seqs,
]
# This should contain all inner-outputs that produce
# outer-outputs
y = []
if not inner_in_sit_sot:
inner_out_sit_sot = []
else:
inner_out_sit_sot = inner_scan_outs
new_carry = (
inner_in_mit_mot,
inner_out_mit_sot,
inner_out_sit_sot,
inner_in_shared,
inner_in_non_seqs,
)
raise NotImplementedError()
return (carry, y)
return new_carry
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = jax_tt_inner_func(*inner_args)
new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs)
return new_carry, y
inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
return scan
......
......@@ -1075,8 +1075,9 @@ class scan_args:
if k in info:
self.other_info[k] = info[k]
inner_inputs = property(
lambda self: (
@property
def inner_inputs(self):
return (
self.inner_in_seqs
+ sum(self.inner_in_mit_mot, [])
+ sum(self.inner_in_mit_sot, [])
......@@ -1084,10 +1085,10 @@ class scan_args:
+ self.inner_in_shared
+ self.inner_in_non_seqs
)
)
outer_inputs = property(
lambda self: (
@property
def outer_inputs(self):
return (
[self.n_steps]
+ self.outer_in_seqs
+ self.outer_in_mit_mot
......@@ -1097,10 +1098,10 @@ class scan_args:
+ self.outer_in_nit_sot
+ self.outer_in_non_seqs
)
)
inner_outputs = property(
lambda self: (
@property
def inner_outputs(self):
return (
sum(self.inner_out_mit_mot, [])
+ self.inner_out_mit_sot
+ self.inner_out_sit_sot
......@@ -1108,20 +1109,20 @@ class scan_args:
+ self.inner_out_shared
+ self.cond
)
)
outer_outputs = property(
lambda self: (
@property
def outer_outputs(self):
return (
self.outer_out_mit_mot
+ self.outer_out_mit_sot
+ self.outer_out_sit_sot
+ self.outer_out_nit_sot
+ self.outer_out_shared
)
)
info = property(
lambda self: OrderedDict(
@property
def info(self):
return OrderedDict(
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
......@@ -1137,7 +1138,6 @@ class scan_args:
mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info,
)
)
def __copy__(self):
res = object.__new__(type(self))
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论