提交 04c4d86d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Convert redundant ScanInfo attributes to properties

上级 ecf77a1f
...@@ -688,7 +688,6 @@ def scan( ...@@ -688,7 +688,6 @@ def scan(
# MIT_MOT -- not provided by the user only by the grad function # MIT_MOT -- not provided by the user only by the grad function
n_mit_mot = 0 n_mit_mot = 0
n_mit_mot_outs = 0
mit_mot_scan_inputs = [] mit_mot_scan_inputs = []
mit_mot_inner_inputs = [] mit_mot_inner_inputs = []
mit_mot_inner_outputs = [] mit_mot_inner_outputs = []
...@@ -1129,13 +1128,9 @@ def scan( ...@@ -1129,13 +1128,9 @@ def scan(
info = ScanInfo( info = ScanInfo(
n_seqs=n_seqs, n_seqs=n_seqs,
mit_mot_in_slices=(), mit_mot_in_slices=(),
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array), mit_sot_in_slices=tuple(tuple(v) for v in mit_sot_tap_array),
sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)), sit_sot_in_slices=tuple((-1,) for x in range(n_sit_sot)),
n_mit_mot=n_mit_mot,
n_mit_mot_outs=n_mit_mot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mit_mot_out_slices),
n_mit_sot=n_mit_sot,
n_sit_sot=n_sit_sot,
n_shared_outs=n_shared_outs, n_shared_outs=n_shared_outs,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_non_seqs=len(other_shared_inner_args) + len(other_inner_args), n_non_seqs=len(other_shared_inner_args) + len(other_inner_args),
......
...@@ -203,32 +203,42 @@ def copy_var_format(var, as_var): ...@@ -203,32 +203,42 @@ def copy_var_format(var, as_var):
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ScanInfo: class ScanInfo:
n_seqs: int
mit_mot_in_slices: tuple mit_mot_in_slices: tuple
mit_mot_out_slices: tuple
mit_sot_in_slices: tuple mit_sot_in_slices: tuple
sit_sot_in_slices: tuple sit_sot_in_slices: tuple
n_seqs: int
n_mit_mot: int
n_mit_mot_outs: int
mit_mot_out_slices: tuple
n_mit_sot: int
n_sit_sot: int
n_shared_outs: int
n_nit_sot: int n_nit_sot: int
n_shared_outs: int
n_non_seqs: int n_non_seqs: int
as_while: bool as_while: bool
@property
def n_mit_mot(self):
return len(self.mit_mot_in_slices)
@property
def n_mit_mot_outs(self):
return sum(len(x) for x in self.mit_mot_out_slices)
@property
def n_mit_sot(self):
return len(self.mit_sot_in_slices)
@property
def n_sit_sot(self):
return len(self.sit_sot_in_slices)
@property @property
def tap_array(self): def tap_array(self):
return self.mit_mot_in_slices + self.mit_sot_in_slices + self.sit_sot_in_slices return self.mit_mot_in_slices + self.mit_sot_in_slices + self.sit_sot_in_slices
@property @property
def n_inner_inputs(self): def n_inner_inputs(self):
n_mit_mot_taps = sum(len(x) for x in self.mit_mot_in_slices)
n_mit_sot_taps = sum(len(x) for x in self.mit_sot_in_slices)
return ( return (
self.n_seqs self.n_seqs
+ n_mit_mot_taps + sum(len(x) for x in self.mit_mot_in_slices)
+ n_mit_sot_taps + sum(len(x) for x in self.mit_sot_in_slices)
+ self.n_sit_sot + self.n_sit_sot
+ self.n_shared_outs + self.n_shared_outs
+ self.n_non_seqs + self.n_non_seqs
...@@ -236,9 +246,8 @@ class ScanInfo: ...@@ -236,9 +246,8 @@ class ScanInfo:
@property @property
def n_inner_outputs(self): def n_inner_outputs(self):
n_mit_mot_out_taps = sum(len(x) for x in self.mit_mot_out_slices)
return ( return (
n_mit_mot_out_taps self.n_mit_mot_outs
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
...@@ -262,7 +271,7 @@ class ScanInfo: ...@@ -262,7 +271,7 @@ class ScanInfo:
@property @property
def n_outer_outputs(self): def n_outer_outputs(self):
return ( return (
self.n_mit_mot len(self.mit_mot_out_slices)
+ self.n_mit_sot + self.n_mit_sot
+ self.n_sit_sot + self.n_sit_sot
+ self.n_nit_sot + self.n_nit_sot
...@@ -758,6 +767,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -758,6 +767,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self.strict = strict self.strict = strict
self.__dict__.update(dataclasses.asdict(info)) self.__dict__.update(dataclasses.asdict(info))
self.n_mit_mot = self.info.n_mit_mot
self.n_mit_mot_outs = self.info.n_mit_mot_outs
self.n_mit_sot = self.info.n_mit_sot
self.n_sit_sot = self.info.n_sit_sot
# Clone mode_instance, altering "allow_gc" for the linker, # Clone mode_instance, altering "allow_gc" for the linker,
# and adding a message if we profile # and adding a message if we profile
if self.name: if self.name:
...@@ -2956,16 +2970,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2956,16 +2970,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info = ScanInfo( out_info = ScanInfo(
n_seqs=len(outer_inp_seqs), n_seqs=len(outer_inp_seqs),
n_mit_sot=0,
mit_mot_in_slices=tuple(tuple(v) for v in mitmot_inp_taps), mit_mot_in_slices=tuple(tuple(v) for v in mitmot_inp_taps),
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)), sit_sot_in_slices=tuple((-1,) for k in range(n_sitsot_outs)),
n_mit_mot=len(outer_inp_mitmot),
n_mit_mot_outs=n_mitmot_outs,
mit_mot_out_slices=tuple(tuple(v) for v in mitmot_out_taps),
n_sit_sot=n_sitsot_outs,
n_shared_outs=0,
n_nit_sot=n_nit_sot, n_nit_sot=n_nit_sot,
n_shared_outs=0,
n_non_seqs=len(self.outer_shared(inputs)) n_non_seqs=len(self.outer_shared(inputs))
+ len(self.outer_non_seqs(inputs)), + len(self.outer_non_seqs(inputs)),
as_while=False, as_while=False,
...@@ -3279,15 +3289,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -3279,15 +3289,11 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_info = ScanInfo( out_info = ScanInfo(
n_seqs=info.n_seqs * 2, n_seqs=info.n_seqs * 2,
mit_mot_in_slices=new_mit_mot_in_slices, mit_mot_in_slices=new_mit_mot_in_slices,
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
mit_sot_in_slices=new_mit_sot_in_slices, mit_sot_in_slices=new_mit_sot_in_slices,
sit_sot_in_slices=new_sit_sot_in_slices, sit_sot_in_slices=new_sit_sot_in_slices,
n_mit_sot=info.n_mit_sot * 2,
n_sit_sot=info.n_sit_sot * 2,
n_mit_mot=info.n_mit_mot * 2,
n_nit_sot=info.n_nit_sot * 2, n_nit_sot=info.n_nit_sot * 2,
n_shared_outs=info.n_shared_outs, n_shared_outs=info.n_shared_outs,
n_mit_mot_outs=n_mit_mot_outs * 2,
mit_mot_out_slices=tuple(tuple(v) for v in info.mit_mot_out_slices) * 2,
n_non_seqs=len(inner_other), n_non_seqs=len(inner_other),
as_while=info.as_while, as_while=info.as_while,
) )
......
...@@ -1794,17 +1794,13 @@ class ScanMerge(GlobalOptimizer): ...@@ -1794,17 +1794,13 @@ class ScanMerge(GlobalOptimizer):
new_inner_outs += inner_outs[idx][gr_idx] new_inner_outs += inner_outs[idx][gr_idx]
info = ScanInfo( info = ScanInfo(
n_seqs=sum(nd.op.n_seqs for nd in nodes),
mit_mot_in_slices=mit_mot_in_slices, mit_mot_in_slices=mit_mot_in_slices,
mit_mot_out_slices=mit_mot_out_slices,
mit_sot_in_slices=mit_sot_in_slices, mit_sot_in_slices=mit_sot_in_slices,
sit_sot_in_slices=sit_sot_in_slices, sit_sot_in_slices=sit_sot_in_slices,
n_seqs=sum(nd.op.n_seqs for nd in nodes),
n_mit_mot=sum(nd.op.n_mit_mot for nd in nodes),
n_mit_mot_outs=sum(nd.op.n_mit_mot_outs for nd in nodes),
mit_mot_out_slices=mit_mot_out_slices,
n_mit_sot=sum(nd.op.n_mit_sot for nd in nodes),
n_sit_sot=sum(nd.op.n_sit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes), n_nit_sot=sum(nd.op.n_nit_sot for nd in nodes),
n_shared_outs=sum(nd.op.n_shared_outs for nd in nodes),
n_non_seqs=n_non_seqs, n_non_seqs=n_non_seqs,
as_while=as_while, as_while=as_while,
) )
...@@ -2218,7 +2214,6 @@ def push_out_dot1_scan(fgraph, node): ...@@ -2218,7 +2214,6 @@ def push_out_dot1_scan(fgraph, node):
op.info, op.info,
sit_sot_in_slices=op.info.sit_sot_in_slices[:idx] sit_sot_in_slices=op.info.sit_sot_in_slices[:idx]
+ op.info.sit_sot_in_slices[idx + 1 :], + op.info.sit_sot_in_slices[idx + 1 :],
n_sit_sot=op.info.n_sit_sot - 1,
n_nit_sot=op.info.n_nit_sot + 1, n_nit_sot=op.info.n_nit_sot + 1,
) )
inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :] inner_sitsot = inner_sitsot[:idx] + inner_sitsot[idx + 1 :]
......
...@@ -401,17 +401,13 @@ def compress_outs(op, not_required, inputs): ...@@ -401,17 +401,13 @@ def compress_outs(op, not_required, inputs):
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
info = ScanInfo( info = ScanInfo(
n_seqs=op.info.n_seqs,
mit_mot_in_slices=(), mit_mot_in_slices=(),
mit_mot_out_slices=(),
mit_sot_in_slices=(), mit_sot_in_slices=(),
sit_sot_in_slices=(), sit_sot_in_slices=(),
n_seqs=op.info.n_seqs,
n_mit_mot=0,
n_mit_mot_outs=0,
mit_mot_out_slices=(),
n_mit_sot=0,
n_sit_sot=0,
n_shared_outs=0,
n_nit_sot=0, n_nit_sot=0,
n_shared_outs=0,
n_non_seqs=0, n_non_seqs=0,
as_while=op.info.as_while, as_while=op.info.as_while,
) )
...@@ -432,7 +428,6 @@ def compress_outs(op, not_required, inputs): ...@@ -432,7 +428,6 @@ def compress_outs(op, not_required, inputs):
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
n_mit_mot=info.n_mit_mot + 1,
mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],), mit_mot_in_slices=info.mit_mot_in_slices + (op.mit_mot_in_slices[idx],),
mit_mot_out_slices=info.mit_mot_out_slices mit_mot_out_slices=info.mit_mot_out_slices
+ (op.mit_mot_out_slices[idx],), + (op.mit_mot_out_slices[idx],),
...@@ -451,7 +446,6 @@ def compress_outs(op, not_required, inputs): ...@@ -451,7 +446,6 @@ def compress_outs(op, not_required, inputs):
o_offset += len(op.mit_mot_out_slices[idx]) o_offset += len(op.mit_mot_out_slices[idx])
i_offset += len(op.mit_mot_in_slices[idx]) i_offset += len(op.mit_mot_in_slices[idx])
info = dataclasses.replace(info, n_mit_mot_outs=len(op_outputs))
offset += op.n_mit_mot offset += op.n_mit_mot
ni_offset += op.n_mit_mot ni_offset += op.n_mit_mot
...@@ -461,7 +455,6 @@ def compress_outs(op, not_required, inputs): ...@@ -461,7 +455,6 @@ def compress_outs(op, not_required, inputs):
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
n_mit_sot=info.n_mit_sot + 1,
mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],), mit_sot_in_slices=info.mit_sot_in_slices + (op.mit_sot_in_slices[idx],),
) )
# input taps # input taps
...@@ -485,7 +478,6 @@ def compress_outs(op, not_required, inputs): ...@@ -485,7 +478,6 @@ def compress_outs(op, not_required, inputs):
curr_pos += 1 curr_pos += 1
info = dataclasses.replace( info = dataclasses.replace(
info, info,
n_sit_sot=info.n_sit_sot + 1,
sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],), sit_sot_in_slices=info.sit_sot_in_slices + (op.sit_sot_in_slices[idx],),
) )
# input taps # input taps
...@@ -807,17 +799,13 @@ class ScanArgs: ...@@ -807,17 +799,13 @@ class ScanArgs:
from aesara.scan.op import ScanInfo from aesara.scan.op import ScanInfo
return ScanInfo( return ScanInfo(
n_seqs=len(self.outer_in_seqs),
mit_mot_in_slices=tuple(tuple(v) for v in self.mit_mot_in_slices), mit_mot_in_slices=tuple(tuple(v) for v in self.mit_mot_in_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices), mit_sot_in_slices=tuple(tuple(v) for v in self.mit_sot_in_slices),
sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot), sit_sot_in_slices=((-1,),) * len(self.inner_in_sit_sot),
n_seqs=len(self.outer_in_seqs),
n_mit_mot=len(self.outer_in_mit_mot),
n_mit_sot=len(self.outer_in_mit_sot),
n_sit_sot=len(self.outer_in_sit_sot),
n_nit_sot=len(self.outer_in_nit_sot), n_nit_sot=len(self.outer_in_nit_sot),
n_shared_outs=len(self.outer_in_shared), n_shared_outs=len(self.outer_in_shared),
n_mit_mot_outs=sum(len(s) for s in self.mit_mot_out_slices),
mit_mot_out_slices=tuple(self.mit_mot_out_slices),
n_non_seqs=len(self.inner_in_non_seqs), n_non_seqs=len(self.inner_in_non_seqs),
as_while=self.as_while, as_while=self.as_while,
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论