Unverified 提交 0c203e99 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: GitHub

Merge pull request #169 from junpenglao/jax_scan

Implement a JAX conversion for the Scan Op
差异被折叠。
...@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None): ...@@ -177,6 +177,7 @@ def compose_jax_funcs(out_node, fgraph_inputs, memo=None):
def jax_func(*inputs): def jax_func(*inputs):
func_args = [fn(*inputs) for fn in input_funcs] func_args = [fn(*inputs) for fn in input_funcs]
# func_args = jax.tree_map(lambda fn: fn(*inputs), input_funcs)
return return_func(*func_args) return return_func(*func_args)
jax_funcs.append(update_wrapper(jax_func, return_func)) jax_funcs.append(update_wrapper(jax_func, return_func))
...@@ -420,7 +421,7 @@ def jax_funcify_Scan(op): ...@@ -420,7 +421,7 @@ def jax_funcify_Scan(op):
def scan(*outer_inputs): def scan(*outer_inputs):
scan_args = ScanArgs( scan_args = ScanArgs(
outer_inputs, [None] * op.n_outs, op.inputs, op.outputs, op.info list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
) )
# `outer_inputs` is a list with the following composite form: # `outer_inputs` is a list with the following composite form:
...@@ -435,9 +436,9 @@ def jax_funcify_Scan(op): ...@@ -435,9 +436,9 @@ def jax_funcify_Scan(op):
n_steps = scan_args.n_steps n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs seqs = scan_args.outer_in_seqs
n_non_seqs = len(scan_args.outer_in_non_seqs) # TODO: mit_mots
mit_mot_in_slices = []
# TODO: sit_sots
mit_sot_in_slices = [] mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot): for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0] neg_taps = [abs(t) for t in tap if t < 0]
...@@ -447,7 +448,15 @@ def jax_funcify_Scan(op): ...@@ -447,7 +448,15 @@ def jax_funcify_Scan(op):
init_slice = seq[: max_neg + max_pos] init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice) mit_sot_in_slices.append(init_slice)
init_carry = [mit_sot_in_slices, scan_args.outer_in_non_seqs] sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
def jax_args_to_inner_scan(op, carry, x): def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared # `carry` contains all inner-output taps, non_seqs, and shared
...@@ -470,15 +479,22 @@ def jax_funcify_Scan(op): ...@@ -470,15 +479,22 @@ def jax_funcify_Scan(op):
# + inner_in_sit_sot # + inner_in_sit_sot
# + inner_in_shared # + inner_in_shared
# + inner_in_non_seqs # + inner_in_non_seqs
inner_scan_inputs = [ inner_in_mit_sot_flatten = []
inner_in_seqs, for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_mot, inner_in_mit_sot_flatten.extend(array[index])
inner_in_mit_sot,
inner_in_sit_sot, inner_scan_inputs = sum(
inner_in_non_seqs, [
] inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
raise NotImplementedError()
return inner_scan_inputs return inner_scan_inputs
def inner_scan_outs_to_jax_outs( def inner_scan_outs_to_jax_outs(
...@@ -486,47 +502,66 @@ def jax_funcify_Scan(op): ...@@ -486,47 +502,66 @@ def jax_funcify_Scan(op):
old_carry, old_carry,
inner_scan_outs, inner_scan_outs,
): ):
# `inner_scan_outs` is a list with the following
# composite form:
# outer_out_mit_mot
# + outer_out_mit_sot
# + outer_out_sit_sot
# + outer_out_nit_sot
# + outer_out_shared
# + cond
( (
outer_out_mit_mot, inner_in_mit_mot,
outer_out_mit_sot, inner_in_mit_sot,
outer_out_sit_sot, inner_in_sit_sot,
outer_out_nit_sot, inner_in_shared,
outer_out_shared, inner_in_non_seqs,
cond, ) = old_carry
) = inner_scan_outs
outer_out_non_seqs = old_carry[:-n_non_seqs] def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]
# This should contain all inner-output taps, non_seqs, and shared # This should contain all inner-output taps, non_seqs, and shared
# terms # terms
carry = [ if not inner_in_sit_sot:
outer_out_mit_mot, inner_out_sit_sot = []
outer_out_mit_sot, else:
outer_out_sit_sot, inner_out_sit_sot = inner_scan_outs
outer_out_shared, new_carry = (
outer_out_non_seqs, inner_in_mit_mot,
] inner_out_mit_sot,
# This should contain all inner-outputs that produce inner_out_sit_sot,
# outer-outputs inner_in_shared,
y = [] inner_in_non_seqs,
)
raise NotImplementedError() return new_carry
return (carry, y)
def jax_inner_func(carry, x): def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x) inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = jax_tt_inner_func(*inner_args) inner_scan_outs = [fn(*inner_args) for fn in jax_tt_inner_func]
new_carry, y = inner_scan_outs_to_jax_outs(op, inner_scan_outs) new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, y return new_carry, inner_scan_outs
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
return jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps) if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
return scan return scan
......
...@@ -1075,8 +1075,9 @@ class scan_args: ...@@ -1075,8 +1075,9 @@ class scan_args:
if k in info: if k in info:
self.other_info[k] = info[k] self.other_info[k] = info[k]
inner_inputs = property( @property
lambda self: ( def inner_inputs(self):
return (
self.inner_in_seqs self.inner_in_seqs
+ sum(self.inner_in_mit_mot, []) + sum(self.inner_in_mit_mot, [])
+ sum(self.inner_in_mit_sot, []) + sum(self.inner_in_mit_sot, [])
...@@ -1084,10 +1085,10 @@ class scan_args: ...@@ -1084,10 +1085,10 @@ class scan_args:
+ self.inner_in_shared + self.inner_in_shared
+ self.inner_in_non_seqs + self.inner_in_non_seqs
) )
)
outer_inputs = property( @property
lambda self: ( def outer_inputs(self):
return (
[self.n_steps] [self.n_steps]
+ self.outer_in_seqs + self.outer_in_seqs
+ self.outer_in_mit_mot + self.outer_in_mit_mot
...@@ -1097,10 +1098,10 @@ class scan_args: ...@@ -1097,10 +1098,10 @@ class scan_args:
+ self.outer_in_nit_sot + self.outer_in_nit_sot
+ self.outer_in_non_seqs + self.outer_in_non_seqs
) )
)
inner_outputs = property( @property
lambda self: ( def inner_outputs(self):
return (
sum(self.inner_out_mit_mot, []) sum(self.inner_out_mit_mot, [])
+ self.inner_out_mit_sot + self.inner_out_mit_sot
+ self.inner_out_sit_sot + self.inner_out_sit_sot
...@@ -1108,20 +1109,20 @@ class scan_args: ...@@ -1108,20 +1109,20 @@ class scan_args:
+ self.inner_out_shared + self.inner_out_shared
+ self.cond + self.cond
) )
)
outer_outputs = property( @property
lambda self: ( def outer_outputs(self):
return (
self.outer_out_mit_mot self.outer_out_mit_mot
+ self.outer_out_mit_sot + self.outer_out_mit_sot
+ self.outer_out_sit_sot + self.outer_out_sit_sot
+ self.outer_out_nit_sot + self.outer_out_nit_sot
+ self.outer_out_shared + self.outer_out_shared
) )
)
info = property( @property
lambda self: OrderedDict( def info(self):
return OrderedDict(
n_seqs=len(self.outer_in_seqs), n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot), n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot), n_mit_sot=len(self.outer_in_mit_sot),
...@@ -1137,7 +1138,6 @@ class scan_args: ...@@ -1137,7 +1138,6 @@ class scan_args:
mit_mot_out_slices=self.mit_mot_out_slices, mit_mot_out_slices=self.mit_mot_out_slices,
**self.other_info, **self.other_info,
) )
)
def __copy__(self): def __copy__(self):
res = object.__new__(type(self)) res = object.__new__(type(self))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论