提交 ecf77a1f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add inner/outer-input/output counts to ScanInfo

上级 76585fe1
...@@ -221,6 +221,54 @@ class ScanInfo: ...@@ -221,6 +221,54 @@ class ScanInfo:
def tap_array(self): def tap_array(self):
return self.mit_mot_in_slices + self.mit_sot_in_slices + self.sit_sot_in_slices return self.mit_mot_in_slices + self.mit_sot_in_slices + self.sit_sot_in_slices
@property
def n_inner_inputs(self):
n_mit_mot_taps = sum(len(x) for x in self.mit_mot_in_slices)
n_mit_sot_taps = sum(len(x) for x in self.mit_sot_in_slices)
return (
self.n_seqs
+ n_mit_mot_taps
+ n_mit_sot_taps
+ self.n_sit_sot
+ self.n_shared_outs
+ self.n_non_seqs
)
@property
def n_inner_outputs(self):
n_mit_mot_out_taps = sum(len(x) for x in self.mit_mot_out_slices)
return (
n_mit_mot_out_taps
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
+ int(self.as_while)
)
@property
def n_outer_inputs(self):
return (
1
+ self.n_seqs
+ self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
+ self.n_non_seqs
)
@property
def n_outer_outputs(self):
return (
self.n_mit_mot
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
+ self.n_shared_outs
)
TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType] TensorConstructorType = Callable[[List[bool], Union[str, np.generic]], TensorType]
...@@ -794,27 +842,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -794,27 +842,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.mitmots_preallocated, self.mitmots_preallocated,
) = self._mitmot_preallocations() ) = self._mitmot_preallocations()
# The total number of inputs across all multi-input taps self.n_outer_inputs = info.n_outer_inputs
# `tap_array = mit_sot_tap_inputs + (-1,) * n_sit_sot` self.n_outer_outputs = info.n_outer_outputs
# n_mit_mot_sot_inputs = sum(len(x) for x in info.tap_array[: (info.n_mit_mot + info.n_mit_sot)])
n_mit_mot_sot_inputs = info.n_mit_mot + info.n_mit_sot
# [n_steps] + sequences + mit-mots + mit-sots + sit-sots + shared-variables + nit-sots + non-sequences
self.n_outer_inputs = (
1
+ info.n_seqs
+ n_mit_mot_sot_inputs
+ info.n_sit_sot
+ info.n_nit_sot
+ info.n_shared_outs
+ info.n_non_seqs
)
self.n_outer_outputs = (
info.n_mit_mot
+ info.n_mit_sot
+ info.n_sit_sot
+ info.n_nit_sot
+ info.n_shared_outs
)
def _mitmot_preallocations(self): def _mitmot_preallocations(self):
if config.scan__allow_output_prealloc: if config.scan__allow_output_prealloc:
......
...@@ -4087,5 +4087,7 @@ def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_chec ...@@ -4087,5 +4087,7 @@ def test_n_non_seqs(fn, sequences, outputs_info, non_sequences, n_steps, op_chec
_ = op_check(scan_op) _ = op_check(scan_op)
assert scan_op.n_outer_inputs == len(res.owner.inputs) assert scan_op.info.n_outer_inputs == len(res.owner.inputs)
assert scan_op.n_outer_outputs == len(res.owner.outputs) assert scan_op.info.n_outer_outputs == len(res.owner.outputs)
assert scan_op.info.n_inner_inputs == len(res.owner.op.inputs)
assert scan_op.info.n_inner_outputs == len(res.owner.op.outputs)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论