提交 04c4d86d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert redundant ScanInfo attributes to properties

上级 ecf77a1f
......@@ -688,7 +688,6 @@ def scan(
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0
n_mit_mot_outs = 0
mit_mot_scan_inputs = []
mit_mot_inner_inputs = []
mit_mot_inner_outputs = []
......@@ -1129,13 +1128,9 @@ def scan(
info = ScanInfo(
n_seqs=n_seqs,
mit_mot_in_slices=(),
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_mit_mot=n_mit_mot,
n_mit_mot_outs=n_mit_mot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
n_mit_sot=n_mit_sot,
n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
......
......@@ -203,32 +203,42 @@ def copy_var_format(var, as_var):
@dataclasses.dataclass(frozen=True)
class ScanInfo:
n_seqs: int
mit_mot_in_slices: tuple
mit_mot_out_slices: tuple
mit_sot_in_slices: tuple
sit_sot_in_slices: tuple
n_seqs: int
n_mit_mot: int
n_mit_mot_outs: int
mit_mot_out_slices: tuple
n_mit_sot: int
n_sit_sot: int
n_shared_outs: int
n_nit_sot: int
n_shared_outs: int
n_non_seqs: int
as_while: bool
@property
def n_mit_mot(self):
return len(self.mit_mot_in_slices)
@property
def n_mit_mot_outs(self):
return sum(len(x) for x in self.mit_mot_out_slices)
@property
def n_mit_sot(self):
return len(self.mit_sot_in_slices)
@property
def n_sit_sot(self):
return len(self.sit_sot_in_slices)
@property
def tap_array(self):
return self.mit_mot_in_slices + self.mit_sot_in_slices + self.sit_sot_in_slices
@property
def n_inner_inputs(self):
n_mit_mot_taps = sum(len(x) for x in self.mit_mot_in_slices)
n_mit_sot_taps = sum(len(x) for x in self.mit_sot_in_slices)
return (
self.n_seqs
+ n_mit_mot_taps
+ n_mit_sot_taps
+ sum(len(x) for x in self.mit_mot_in_slices)
+ sum(len(x) for x in self.mit_sot_in_slices)
+ self.n_sit_sot
+ self.n_shared_outs
+ self.n_non_seqs
......@@ -236,9 +246,8 @@ class ScanInfo:
@property
def n_inner_outputs(self):
n_mit_mot_out_taps = sum(len(x) for x in self.mit_mot_out_slices)
return (
n_mit_mot_out_taps
self.n_mit_mot_outs
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
......@@ -262,7 +271,7 @@ class ScanInfo:
@property
def n_outer_outputs(self):
return (
self.n_mit_mot
len(self.mit_mot_out_slices)
+ self.n_mit_sot
+ self.n_sit_sot
+ self.n_nit_sot
......@@ -758,6 +767,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.strict = strict
self.__dict__.update(dataclasses.asdict(info))
self.n_mit_mot = self.info.n_mit_mot
self.n_mit_mot_outs = self.info.n_mit_mot_outs
self.n_mit_sot = self.info.n_mit_sot
self.n_sit_sot = self.info.n_sit_sot
# Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile
if self.name:
......@@ -2956,16 +2970,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info = ScanInfo(
n_seqs=len(outer_inp_seqs),
n_mit_sot=0,
mit_mot_in_slices=tuple(tuple(v) for v in mitmot_inp_taps),
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
mit_sot_in_slices=(),
sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)),
n_mit_mot=len(outer_inp_mitmot),
n_mit_mot_outs=n_mitmot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
n_sit_sot=n_sitsot_outs,
n_shared_outs=0,
n_nit_sot=n_nit_sot,
n_shared_outs=0,
n_non_seqs=len(self.outer_shared(inputs))
+ len(self.outer_non_seqs(inputs)),
as_while=False,
......@@ -3279,15 +3289,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info = ScanInfo(
n_seqs=info.n_seqs * 2,
mit_mot_in_slices=new_mit_mot_in_slices,
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
mit_sot_in_slices=new_mit_sot_in_slices,
sit_sot_in_slices=new_sit_sot_in_slices,
n_mit_sot=info.n_mit_sot * 2,
n_sit_sot=info.n_sit_sot * 2,
n_mit_mot=info.n_mit_mot * 2,
n_nit_sot=info.n_nit_sot * 2,
n_shared_outs=info.n_shared_outs,
n_mit_mot_outs=n_mit_mot_outs * 2,
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
n_non_seqs=len(inner_other),
as_while=info.as_while,
)
......
......@@ -1794,17 +1794,13 @@ class ScanMerge(GlobalOptimizer):
new_inner_outs += inner_outs[idx][gr_idx]
info = ScanInfo(
n_seqs=sum(nd.op.n_seqs for nd in nodes),
mit_mot_in_slices=mit_mot_in_slices,
mit_mot_out_slices=mit_mot_out_slices,
mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices,
n_seqs=sum(nd.op.n_seqs for nd in nodes),
n_mit_mot=sum(nd.op.n_mit_mot for nd in nodes),
n_mit_mot_outs=sum(nd.op.n_mit_mot_outs for nd in nodes),
mit_mot_out_slices=mit_mot_out_slices,
n_mit_sot=sum(nd.op.n_mit_sot for nd in nodes),
n_sit_sot=sum(nd.op.n_sit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_non_seqs=n_non_seqs,
as_while=as_while,
)
......@@ -2218,7 +2214,6 @@ def push_out_dot1_scan(fgraph, node):
op.info,
sit_sot_in_slices=op.info.sit_sot_in_slices[:idx]
+ op.info.sit_sot_in_slices[idx + 1 :],
n_sit_sot=op.info.n_sit_sot - 1,
n_nit_sot=op.info.n_nit_sot + 1,
)
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :]
......
......@@ -401,17 +401,13 @@ def compress_outs(op, not_required, inputs):
from aesara.scan.op import ScanInfo
info = ScanInfo(
n_seqs=op.info.n_seqs,
mit_mot_in_slices=(),
mit_mot_out_slices=(),
mit_sot_in_slices=(),
sit_sot_in_slices=(),
n_seqs=op.info.n_seqs,
n_mit_mot=0,
n_mit_mot_outs=0,
mit_mot_out_slices=(),
n_mit_sot=0,
n_sit_sot=0,
n_shared_outs=0,
n_nit_sot=0,
n_shared_outs=0,
n_non_seqs=0,
as_while=op.info.as_while,
)
......@@ -432,7 +428,6 @@ def compress_outs(op, not_required, inputs):
curr_pos += 1
info = dataclasses.replace(
info,
n_mit_mot=info.n_mit_mot + 1,
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices
+ (op.mit_mot_out_slices[idx],),
......@@ -451,7 +446,6 @@ def compress_outs(op, not_required, inputs):
o_offset += len(op.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx])
info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs))
offset += op.n_mit_mot
ni_offset += op.n_mit_mot
......@@ -461,7 +455,6 @@ def compress_outs(op, not_required, inputs):
curr_pos += 1
info = dataclasses.replace(
info,
n_mit_sot=info.n_mit_sot + 1,
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],),
)
# input taps
......@@ -485,7 +478,6 @@ def compress_outs(op, not_required, inputs):
curr_pos += 1
info = dataclasses.replace(
info,
n_sit_sot=info.n_sit_sot + 1,
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
)
# input taps
......@@ -807,17 +799,13 @@ class ScanArgs:
from aesara.scan.op import ScanInfo
return ScanInfo(
n_seqs=len(self.outer_in_seqs),
mit_mot_in_slices=tuple(tuple(v) for v in self.mit_mot_in_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论