提交 711843bb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove ScanMethodsMixin Apply argument handling

上级 93eb73fd
......@@ -225,8 +225,6 @@ class ScanMethodsMixin:
return list_inputs[: self.n_seqs]
def outer_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
# Given the list of outer inputs this function grabs those
# corresponding to sequences
return list_inputs[1 : 1 + self.n_seqs]
......@@ -236,8 +234,6 @@ class ScanMethodsMixin:
return list_inputs[self.n_seqs : self.n_seqs + n_taps]
def outer_mitmot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
return list_inputs[1 + self.n_seqs : 1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self, list_outputs):
......@@ -245,8 +241,6 @@ class ScanMethodsMixin:
return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[: self.n_mit_mot]
def mitmot_taps(self):
......@@ -265,8 +259,6 @@ class ScanMethodsMixin:
]
def outer_mitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot
return list_inputs[offset : offset + self.n_mit_sot]
......@@ -275,8 +267,6 @@ class ScanMethodsMixin:
return list_outputs[n_taps : n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self):
......@@ -290,8 +280,6 @@ class ScanMethodsMixin:
return list_inputs[offset : offset + self.n_sit_sot]
def outer_sitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return list_inputs[offset : offset + self.n_sit_sot]
......@@ -301,14 +289,10 @@ class ScanMethodsMixin:
return list_outputs[offset : offset + self.n_sit_sot]
def outer_sitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot
return list_outputs[offset : offset + self.n_sit_sot]
def outer_nitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (
1
+ self.n_seqs
......@@ -325,8 +309,6 @@ class ScanMethodsMixin:
return list_outputs[offset : offset + self.n_nit_sot]
def outer_nitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
return list_outputs[offset : offset + self.n_nit_sot]
......@@ -338,8 +320,6 @@ class ScanMethodsMixin:
return list_inputs[offset : offset + self.n_shared_outs]
def outer_shared(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
return list_inputs[offset : offset + self.n_shared_outs]
......@@ -349,8 +329,6 @@ class ScanMethodsMixin:
return list_outputs[offset : offset + self.n_shared_outs]
def outer_shared_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
return list_outputs[offset : offset + self.n_shared_outs]
......@@ -362,8 +340,6 @@ class ScanMethodsMixin:
return list_inputs[offset:]
def outer_non_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (
1
+ self.n_seqs
......
......@@ -715,7 +715,7 @@ class PushOutSeqScan(GlobalOptimizer):
reason="scanOp_pushout_seqs_ops",
)
return True
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node):
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs):
# Nothing in the inner graph should be kept
replace_with = {}
for out, idx in to_replace_map.items():
......@@ -725,12 +725,12 @@ class PushOutSeqScan(GlobalOptimizer):
ls = clean_outputs
if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx]
inp = op.outer_mitsot(node.inputs)[odx]
st = abs(np.min(op.mitsot_taps()))
y = set_subtensor(inp[st:], _y)
elif out in op.inner_sitsot_outs(ls):
odx = op.inner_sitsot_outs(ls).index(out)
inp = op.outer_sitsot(node)[odx]
inp = op.outer_sitsot(node.inputs)[odx]
y = set_subtensor(inp[1:], _y)
elif out in op.inner_nitsot_outs(ls):
y = _y
......@@ -2301,7 +2301,7 @@ class PushOutDot1(GlobalOptimizer):
op = node.op
sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_sitsot = op.outer_sitsot_outs(node)
outer_sitsot = op.outer_sitsot_outs(node.outputs)
seqs = op.inner_seqs(op.inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
......@@ -2345,23 +2345,23 @@ class PushOutDot1(GlobalOptimizer):
# corresponding categories
inner_seqs = op.inner_seqs(op.inputs)
outer_seqs = op.outer_seqs(node)
outer_seqs = op.outer_seqs(node.inputs)
inner_mitmot = op.inner_mitmot(op.inputs)
outer_mitmot = op.outer_mitmot(node)
outer_mitmot = op.outer_mitmot(node.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs)
outer_mitsot = op.outer_mitsot(node)
outer_mitsot = op.outer_mitsot(node.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs)
outer_sitsot = op.outer_sitsot(node)
outer_sitsot = op.outer_sitsot(node.inputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_nitsot = op.outer_nitsot(node)
outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs)
outer_shared = op.outer_shared(node)
outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node)
outer_non_seqs = op.outer_non_seqs(node.inputs)
st = len(op.mitmot_taps()) + len(op.mitsot_taps())
......@@ -2437,7 +2437,7 @@ class PushOutDot1(GlobalOptimizer):
_val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1]
if inp1 in seqs:
_out_seq = op.outer_seqs(node)[seqs.index(inp1)]
_out_seq = op.outer_seqs(node.inputs)[seqs.index(inp1)]
# We need to clip the seq to the number of steps
_out_seq = _out_seq[: node.inputs[0]]
sh0 = _out_seq.shape[0]
......@@ -2452,7 +2452,7 @@ class PushOutDot1(GlobalOptimizer):
val = _val.reshape((sh0 * sh1, sh2))
new_out = dot(out_seq, val)
else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)]
_out_seq = op.outer_seqs(node.inputs)[seqs.index(inp2)]
out_seq = _out_seq.reshape(
(
_out_seq.shape[0] * _out_seq.shape[1],
......
......@@ -4341,7 +4341,7 @@ class TestScan:
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
assert len(inp) == len(set(inp))
inp = scan_node.op.outer_non_seqs(scan_node)
inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1
assert len(inp) == len(set(inp))
......@@ -4409,11 +4409,11 @@ class TestScan:
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node)
inp = scan_node.op.outer_seqs(scan_node.inputs)
assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node)
inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1
@pytest.mark.slow
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论