提交 1d19c375 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Allow non-shared untraced SIT-SOT

上级 d10c61ba
...@@ -60,23 +60,23 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -60,23 +60,23 @@ def jax_funcify_Scan(op: Scan, **kwargs):
mit_mot_init, mit_mot_init,
mit_sot_init, mit_sot_init,
sit_sot_init, sit_sot_init,
op.outer_shared(outer_inputs), op.outer_untraced_sit_sot(outer_inputs),
op.outer_non_seqs(outer_inputs), op.outer_non_seqs(outer_inputs),
) # JAX `init` ) # JAX `init`
def jax_args_to_inner_func_args(carry, x): def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func. """Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, shared, non_seqs) scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT, untraced SIT-SOT, non_seqs)
""" """
# `carry` contains all inner taps, shared terms, and non_seqs # `carry` contains all inner taps and non_seqs
( (
i, i,
inner_mit_mot, inner_mit_mot,
inner_mit_sot, inner_mit_sot,
inner_sit_sot, inner_sit_sot,
inner_shared, inner_untraced_sit_sot,
inner_non_seqs, inner_non_seqs,
) = carry ) = carry
...@@ -108,7 +108,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -108,7 +108,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
*mit_mot_flatten, *mit_mot_flatten,
*mit_sot_flatten, *mit_sot_flatten,
*inner_sit_sot, *inner_sit_sot,
*inner_shared, *inner_untraced_sit_sot,
*inner_non_seqs, *inner_non_seqs,
) )
...@@ -118,14 +118,14 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -118,14 +118,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
): ):
"""Convert inner_scan_func outputs into format expected by JAX scan. """Convert inner_scan_func outputs into format expected by JAX scan.
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, shared_outs) -> (new_carry, ys) old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs, untraced_SIT-SOT_outs) -> (new_carry, ys)
""" """
( (
i, i,
old_mit_mot, old_mit_mot,
old_mit_sot, old_mit_sot,
_old_sit_sot, _old_sit_sot,
_old_shared, _old_untraced_sit_sot,
inner_non_seqs, inner_non_seqs,
) = old_carry ) = old_carry
...@@ -133,7 +133,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -133,7 +133,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs) new_mit_sot_vals = op.inner_mitsot_outs(inner_scan_outs)
new_sit_sot = op.inner_sitsot_outs(inner_scan_outs) new_sit_sot = op.inner_sitsot_outs(inner_scan_outs)
new_nit_sot = op.inner_nitsot_outs(inner_scan_outs) new_nit_sot = op.inner_nitsot_outs(inner_scan_outs)
new_shared = op.inner_shared_outs(inner_scan_outs) new_untraced_sit_sot = op.inner_untraced_sit_sot_outs(inner_scan_outs)
# New carry for next step # New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps # Update MIT-MOT buffer at positions indicated by output taps
...@@ -150,14 +150,14 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -150,14 +150,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
old_mit_sot, new_mit_sot_vals, strict=True old_mit_sot, new_mit_sot_vals, strict=True
) )
] ]
# For SIT-SOT, and shared just pass along the new value # For SIT-SOT just pass along the new value
# Non-sequences remain unchanged # Non-sequences remain unchanged
new_carry = ( new_carry = (
i + 1, i + 1,
new_mit_mot, new_mit_mot,
new_mit_sot, new_mit_sot,
new_sit_sot, new_sit_sot,
new_shared, new_untraced_sit_sot,
inner_non_seqs, inner_non_seqs,
) )
...@@ -183,7 +183,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -183,7 +183,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
final_mit_mot, final_mit_mot,
_final_mit_sot, _final_mit_sot,
_final_sit_sot, _final_sit_sot,
final_shared, final_untraced_sit_sot,
_final_non_seqs, _final_non_seqs,
), ),
traces, traces,
...@@ -238,7 +238,7 @@ def jax_funcify_Scan(op: Scan, **kwargs): ...@@ -238,7 +238,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
scan_outs_final = [ scan_outs_final = [
*final_mit_mot, *final_mit_mot,
*get_partial_traces(traces), *get_partial_traces(traces),
*final_shared, *final_untraced_sit_sot,
] ]
if len(scan_outs_final) == 1: if len(scan_outs_final) == 1:
......
...@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names) outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
outer_in_sit_sot_names = op.outer_sitsot(outer_in_names) outer_in_sit_sot_names = op.outer_sitsot(outer_in_names)
outer_in_nit_sot_names = op.outer_nitsot(outer_in_names) outer_in_nit_sot_names = op.outer_nitsot(outer_in_names)
outer_in_shared_names = op.outer_shared(outer_in_names) outer_in_untraced_sit_sot_names = op.outer_untraced_sit_sot(outer_in_names)
outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names) outer_in_non_seqs_names = op.outer_non_seqs(outer_in_names)
# These are all the outer-input names that have produce outputs/have output # These are all the outer-input names that have produce outputs/have output
# taps (i.e. they have inner-outputs and corresponding outer-outputs). # taps (i.e. they have inner-outputs and corresponding outer-outputs).
# Outer-outputs are ordered as follows: # Outer-outputs are ordered as follows:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + shared-outputs # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + untraced-sit-sot-outputs
outer_in_outtap_names = ( outer_in_outtap_names = (
outer_in_mit_mot_names outer_in_mit_mot_names
+ outer_in_mit_sot_names + outer_in_mit_sot_names
+ outer_in_sit_sot_names + outer_in_sit_sot_names
+ outer_in_nit_sot_names + outer_in_nit_sot_names
+ outer_in_shared_names + outer_in_untraced_sit_sot_names
) )
# We create distinct variables for/references to the storage arrays for # We create distinct variables for/references to the storage arrays for
...@@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
for outer_in_name in outer_in_nit_sot_names: for outer_in_name in outer_in_nit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage" outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_nitsot_storage"
for outer_in_name in outer_in_shared_names: for outer_in_name in outer_in_untraced_sit_sot_names:
outer_in_to_storage_name[outer_in_name] = f"{outer_in_name}_shared_storage" outer_in_to_storage_name[outer_in_name] = (
f"{outer_in_name}_untraced_sit_sot_storage"
)
outer_output_names = list(outer_in_to_storage_name.values()) outer_output_names = list(outer_in_to_storage_name.values())
assert len(outer_output_names) == len(node.outputs) assert len(outer_output_names) == len(node.outputs)
...@@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Construct the inner-input expressions (e.g. indexed storage expressions) # Construct the inner-input expressions (e.g. indexed storage expressions)
# Inner-inputs are ordered as follows: # Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + # sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
# shared-inputs + non-sequences. # untraced-sit-sot-inputs + non-sequences.
temp_scalar_storage_alloc_stmts: list[str] = [] temp_scalar_storage_alloc_stmts: list[str] = []
inner_in_exprs_scalar: list[str] = [] inner_in_exprs_scalar: list[str] = []
inner_in_exprs: list[str] = [] inner_in_exprs: list[str] = []
...@@ -204,11 +206,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -204,11 +206,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-outputs consist of: # Inner-outputs consist of:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots + # mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
# shared-outputs [+ while-condition] # untraced-sit-sot-outputs [+ while-condition]
inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))] inner_output_names = [f"inner_out_{i}" for i in range(len(op.inner_outputs))]
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
# The assignment statements that copy inner-outputs into the outer-outputs # The assignment statements that copy inner-outputs into the outer-outputs
# storage # storage
inner_out_to_outer_in_stmts: list[str] = [] inner_out_to_outer_in_stmts: list[str] = []
......
差异被折叠。
差异被折叠。
...@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices)) sum(len(x) for x in chain(op_info.mit_mot_in_slices, op_info.mit_sot_in_slices))
) )
st += op_info.n_sit_sot st += op_info.n_sit_sot
st += op_info.n_shared_outs st += op_info.n_untraced_sit_sot_outs
op_ins = op.inner_inputs op_ins = op.inner_inputs
op_outs = op.inner_outputs op_outs = op.inner_outputs
...@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
+ op_info.n_mit_sot + op_info.n_mit_sot
+ op_info.n_sit_sot + op_info.n_sit_sot
+ op_info.n_nit_sot + op_info.n_nit_sot
+ op_info.n_shared_outs + op_info.n_untraced_sit_sot_outs
+ 1 + 1
) )
outer_non_seqs = node.inputs[st:] outer_non_seqs = node.inputs[st:]
...@@ -983,7 +983,7 @@ class ScanInplaceOptimizer(GraphRewriter): ...@@ -983,7 +983,7 @@ class ScanInplaceOptimizer(GraphRewriter):
ls = op.outer_mitmot(node.inputs) ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs) ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs) ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs) ls_end = op.outer_untraced_sit_sot(node.inputs)
ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
...@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+ idx + idx
+ op_info.n_seqs + op_info.n_seqs
+ 1 + 1
+ op_info.n_shared_outs + op_info.n_untraced_sit_sot_outs
) )
if nw_inputs[pos] == node.inputs[0]: if nw_inputs[pos] == node.inputs[0]:
nw_inputs[pos] = 1 if required_orphan else val nw_inputs[pos] = 1 if required_orphan else val
...@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: ...@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
elif ( elif (
idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot
): ):
in_idx = offset + idx + op_info.n_shared_outs in_idx = offset + idx + op_info.n_untraced_sit_sot_outs
if nw_inputs[in_idx] == node.inputs[0]: if nw_inputs[in_idx] == node.inputs[0]:
nw_inputs[in_idx] = nw_steps nw_inputs[in_idx] = nw_steps
...@@ -1886,8 +1886,8 @@ class ScanMerge(GraphRewriter): ...@@ -1886,8 +1886,8 @@ class ScanMerge(GraphRewriter):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
inner_ins[idx].append(nd.op.inner_shared(nd.op.inner_inputs)) inner_ins[idx].append(nd.op.inner_untraced_sit_sot(nd.op.inner_inputs))
outer_ins += nd.op.outer_shared(nd.inputs) outer_ins += nd.op.outer_untraced_sit_sot(nd.inputs)
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# NitSot # NitSot
...@@ -1897,8 +1897,10 @@ class ScanMerge(GraphRewriter): ...@@ -1897,8 +1897,10 @@ class ScanMerge(GraphRewriter):
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
# Shared # Shared
outer_outs += nd.op.outer_shared_outs(nd.outputs) outer_outs += nd.op.outer_untraced_sit_sot_outs(nd.outputs)
inner_outs[idx].append(nd.op.inner_shared_outs(nd.op.inner_outputs)) inner_outs[idx].append(
nd.op.inner_untraced_sit_sot_outs(nd.op.inner_outputs)
)
n_non_seqs = 0 n_non_seqs = 0
for idx, nd in enumerate(nodes): for idx, nd in enumerate(nodes):
...@@ -1978,7 +1980,9 @@ class ScanMerge(GraphRewriter): ...@@ -1978,7 +1980,9 @@ class ScanMerge(GraphRewriter):
mit_sot_in_slices=mit_sot_in_slices, mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices, sit_sot_in_slices=sit_sot_in_slices,
n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.info.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.info.n_shared_outs for nd in nodes), n_untraced_sit_sot_outs=sum(
nd.op.info.n_untraced_sit_sot_outs for nd in nodes
),
n_non_seqs=n_non_seqs, n_non_seqs=n_non_seqs,
as_while=as_while, as_while=as_while,
) )
...@@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node):
# When seq[t] is a vector/matrix and `value` is a matrix # When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works when only you need X[-1] in the end # Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot # and assumes dimshuffle are applied to vectors before calling dot
op = node.op op: Scan = node.op
sitsot_ins = op.inner_sitsot(op.inner_inputs) sitsot_ins = op.inner_sitsot(op.inner_inputs)
sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_sitsot = op.outer_sitsot_outs(node.outputs) outer_sitsot = op.outer_sitsot_outs(node.outputs)
...@@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node):
inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs) inner_sitsot_outs = op.inner_sitsot_outs(op.inner_outputs)
outer_nitsot = op.outer_nitsot(node.inputs) outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs) inner_nitsot_outs = op.inner_nitsot_outs(op.inner_outputs)
inner_shared = op.inner_shared(op.inner_inputs) inner_untraced_sitsot = op.inner_untraced_sitsot(op.inner_inputs)
outer_shared = op.outer_shared(node.inputs) outer_untraced_sitsot_outs = op.outer_untraced_sitsot_outs(
inner_shared_outs = op.inner_shared_outs(op.inner_outputs) node.inputs
)
inner_untraced_sitsot_outs = op.inner_untraced_sitsot_outs(
op.inner_outputs
)
inner_non_seqs = op.inner_non_seqs(op.inner_inputs) inner_non_seqs = op.inner_non_seqs(op.inner_inputs)
outer_non_seqs = op.outer_non_seqs(node.inputs) outer_non_seqs = op.outer_non_seqs(node.inputs)
...@@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node):
+ inner_mitmot + inner_mitmot
+ inner_mitsot + inner_mitsot
+ inner_sitsot + inner_sitsot
+ inner_shared + inner_untraced_sitsot
+ inner_non_seqs + inner_non_seqs
) )
_new_inner_outs = ( _new_inner_outs = (
...@@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node):
+ inner_mitsot_outs + inner_mitsot_outs
+ inner_sitsot_outs + inner_sitsot_outs
+ inner_nitsot_outs + inner_nitsot_outs
+ inner_shared_outs + inner_untraced_sitsot_outs
) )
new_inner_inps, new_inner_outs = reconstruct_graph( new_inner_inps, new_inner_outs = reconstruct_graph(
_new_inner_inps, _new_inner_outs _new_inner_inps, _new_inner_outs
...@@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node): ...@@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node):
*outer_mitmot, *outer_mitmot,
*outer_mitsot, *outer_mitsot,
*outer_sitsot, *outer_sitsot,
*outer_shared, *outer_untraced_sitsot_outs,
*outer_nitsot, *outer_nitsot,
node.inputs[0], node.inputs[0],
*outer_non_seqs, *outer_non_seqs,
......
...@@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs): ...@@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs):
out_ins += [op.inner_inputs[offset : offset + n_ins]] out_ins += [op.inner_inputs[offset : offset + n_ins]]
offset += n_ins offset += n_ins
out_ins += [[] for k in range(op.info.n_nit_sot)] out_ins += [[] for k in range(op.info.n_nit_sot)]
out_ins += [[op.inner_inputs[offset + k]] for k in range(op.info.n_shared_outs)] out_ins += [
[op.inner_inputs[offset + k]] for k in range(op.info.n_untraced_sit_sot_outs)
]
added = True added = True
out_idxs_mask = [1 for idx in out_idxs] out_idxs_mask = [1 for idx in out_idxs]
...@@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs): ...@@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs):
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=(), sit_sot_in_slices=(),
n_nit_sot=0, n_nit_sot=0,
n_shared_outs=0, n_untraced_sit_sot_outs=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op_info.as_while, as_while=op_info.as_while,
) )
...@@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs): ...@@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs):
info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1) info = dataclasses.replace(info, n_nit_sot=info.n_nit_sot + 1)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
nit_sot_ins += [inputs[ni_offset + idx + op_info.n_shared_outs]] nit_sot_ins += [inputs[ni_offset + idx + op_info.n_untraced_sit_sot_outs]]
else: else:
o_offset += 1 o_offset += 1
offset += op_info.n_nit_sot offset += op_info.n_nit_sot
shared_ins = [] shared_ins = []
for idx in range(op_info.n_shared_outs): for idx in range(op_info.n_untraced_sit_sot_outs):
if offset + idx not in not_required: if offset + idx not in not_required:
map_old_new[offset + idx] = curr_pos map_old_new[offset + idx] = curr_pos
curr_pos += 1 curr_pos += 1
info = dataclasses.replace(info, n_shared_outs=info.n_shared_outs + 1) info = dataclasses.replace(
info, n_untraced_sit_sot_outs=info.n_untraced_sit_sot_outs + 1
)
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
o_offset += 1 o_offset += 1
op_inputs += [op.inner_inputs[i_offset]] op_inputs += [op.inner_inputs[i_offset]]
...@@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs): ...@@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs):
# other stuff # other stuff
op_inputs += op.inner_inputs[i_offset:] op_inputs += op.inner_inputs[i_offset:]
info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:])) info = dataclasses.replace(info, n_non_seqs=len(op.inner_inputs[i_offset:]))
node_inputs += inputs[ni_offset + op_info.n_shared_outs + op_info.n_nit_sot :] node_inputs += inputs[
ni_offset + op_info.n_untraced_sit_sot_outs + op_info.n_nit_sot :
]
if op_info.as_while: if op_info.as_while:
op_outputs += [op.inner_outputs[o_offset]] op_outputs += [op.inner_outputs[o_offset]]
map_old_new[o_offset] = len(op_outputs) - 1 map_old_new[o_offset] = len(op_outputs) - 1
...@@ -658,11 +664,11 @@ class ScanArgs: ...@@ -658,11 +664,11 @@ class ScanArgs:
p += n_sit_sot p += n_sit_sot
q += n_sit_sot q += n_sit_sot
n_shared_outs = info.n_shared_outs n_untraced_sit_sot_outs = info.n_untraced_sit_sot_outs
self.outer_in_shared = list(outer_inputs[p : p + n_shared_outs]) self.outer_in_shared = list(outer_inputs[p : p + n_untraced_sit_sot_outs])
self.inner_in_shared = list(inner_inputs[q : q + n_shared_outs]) self.inner_in_shared = list(inner_inputs[q : q + n_untraced_sit_sot_outs])
p += n_shared_outs p += n_untraced_sit_sot_outs
q += n_shared_outs q += n_untraced_sit_sot_outs
n_nit_sot = info.n_nit_sot n_nit_sot = info.n_nit_sot
self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot]) self.outer_in_nit_sot = list(outer_inputs[p : p + n_nit_sot])
...@@ -702,10 +708,10 @@ class ScanArgs: ...@@ -702,10 +708,10 @@ class ScanArgs:
p += n_nit_sot p += n_nit_sot
q += n_nit_sot q += n_nit_sot
self.outer_out_shared = list(outer_outputs[p : p + n_shared_outs]) self.outer_out_shared = list(outer_outputs[p : p + n_untraced_sit_sot_outs])
self.inner_out_shared = list(inner_outputs[q : q + n_shared_outs]) self.inner_out_shared = list(inner_outputs[q : q + n_untraced_sit_sot_outs])
p += n_shared_outs p += n_untraced_sit_sot_outs
q += n_shared_outs q += n_untraced_sit_sot_outs
assert p == len(outer_outputs) assert p == len(outer_outputs)
assert q == len(inner_outputs) assert q == len(inner_outputs)
...@@ -816,7 +822,7 @@ class ScanArgs: ...@@ -816,7 +822,7 @@ class ScanArgs:
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices), mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot), sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared), n_untraced_sit_sot_outs=len(self.outer_in_shared),
n_non_seqs=len(self.inner_in_non_seqs), n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while, as_while=self.as_while,
) )
......
...@@ -85,7 +85,7 @@ from tests.link.numba.test_basic import compare_numba_and_py ...@@ -85,7 +85,7 @@ from tests.link.numba.test_basic import compare_numba_and_py
3, 3,
[], [],
[np.array([0.50100236, 2.16822932, 1.36326596])], [np.array([0.50100236, 2.16822932, 1.36326596])],
lambda op: op.info.n_shared_outs > 0, lambda op: op.info.n_untraced_sit_sot_outs > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
( (
......
...@@ -42,11 +42,13 @@ from pytensor.tensor.math import all as pt_all ...@@ -42,11 +42,13 @@ from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh from pytensor.tensor.math import dot, exp, mean, sigmoid, tanh
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.random import normal from pytensor.tensor.random import normal
from pytensor.tensor.random.type import RandomGeneratorType, random_generator_type
from pytensor.tensor.random.utils import RandomStream from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.shape import Shape_i, reshape, specify_shape from pytensor.tensor.shape import Shape_i, reshape, specify_shape
from pytensor.tensor.sharedvar import SharedVariable from pytensor.tensor.sharedvar import SharedVariable
from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.subtensor import Subtensor
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType,
dcol, dcol,
dmatrix, dmatrix,
dscalar, dscalar,
...@@ -4007,7 +4009,7 @@ class TestExamples: ...@@ -4007,7 +4009,7 @@ class TestExamples:
[{}], [{}],
[], [],
3, 3,
lambda op: op.info.n_shared_outs > 0, lambda op: op.info.n_untraced_sit_sot_outs > 0,
), ),
# mit-sot (that's also a type of sit-sot) # mit-sot (that's also a type of sit-sot)
( (
...@@ -4106,3 +4108,34 @@ def test_output_storage_reuse(linker_mode): ...@@ -4106,3 +4108,34 @@ def test_output_storage_reuse(linker_mode):
res = f_cvm() res = f_cvm()
assert np.array_equal(res, np.array([3, 1, 0])) assert np.array_equal(res, np.array([3, 1, 0]))
def test_rng_outputs_info():
rng_init = random_generator_type("rng")
rng_x0, x0 = pt.random.normal(0, rng=rng_init, dtype="float64").owner.outputs
def step(prev_x, prev_rng):
next_rng, next_x = pt.random.normal(
prev_x, rng=prev_rng, dtype="float64"
).owner.outputs
return next_x, next_rng
[xs, rng_final], updates = scan(
fn=step,
outputs_info=[x0, rng_x0],
n_steps=10,
)
assert isinstance(xs.type, TensorType)
assert isinstance(rng_final.type, RandomGeneratorType)
assert not updates
fn = function([rng_init], [xs, rng_final])
xs_eval, rng_final_eval = fn(np.random.default_rng(0))
rng_ref = np.random.default_rng(0)
assert not random_generator_type.values_eq(rng_ref, rng_final_eval)
xs_ref = [rng_ref.normal(0)]
for i in range(10):
xs_ref.append(rng_ref.normal(xs_ref[-1]))
assert random_generator_type.values_eq(rng_ref, rng_final_eval)
np.testing.assert_allclose(xs_eval, xs_ref[1:])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论