提交 711843bb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove ScanMethodsMixin Apply argument handling

上级 93eb73fd
...@@ -225,8 +225,6 @@ class ScanMethodsMixin: ...@@ -225,8 +225,6 @@ class ScanMethodsMixin:
return list_inputs[: self.n_seqs] return list_inputs[: self.n_seqs]
def outer_seqs(self, list_inputs): def outer_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
# Given the list of outer inputs this function grabs those # Given the list of outer inputs this function grabs those
# corresponding to sequences # corresponding to sequences
return list_inputs[1 : 1 + self.n_seqs] return list_inputs[1 : 1 + self.n_seqs]
...@@ -236,8 +234,6 @@ class ScanMethodsMixin: ...@@ -236,8 +234,6 @@ class ScanMethodsMixin:
return list_inputs[self.n_seqs : self.n_seqs + n_taps] return list_inputs[self.n_seqs : self.n_seqs + n_taps]
def outer_mitmot(self, list_inputs): def outer_mitmot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
return list_inputs[1 + self.n_seqs : 1 + self.n_seqs + self.n_mit_mot] return list_inputs[1 + self.n_seqs : 1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self, list_outputs): def inner_mitmot_outs(self, list_outputs):
...@@ -245,8 +241,6 @@ class ScanMethodsMixin: ...@@ -245,8 +241,6 @@ class ScanMethodsMixin:
return list_outputs[:n_taps] return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs): def outer_mitmot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[: self.n_mit_mot] return list_outputs[: self.n_mit_mot]
def mitmot_taps(self): def mitmot_taps(self):
...@@ -265,8 +259,6 @@ class ScanMethodsMixin: ...@@ -265,8 +259,6 @@ class ScanMethodsMixin:
] ]
def outer_mitsot(self, list_inputs): def outer_mitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot offset = 1 + self.n_seqs + self.n_mit_mot
return list_inputs[offset : offset + self.n_mit_sot] return list_inputs[offset : offset + self.n_mit_sot]
...@@ -275,8 +267,6 @@ class ScanMethodsMixin: ...@@ -275,8 +267,6 @@ class ScanMethodsMixin:
return list_outputs[n_taps : n_taps + self.n_mit_sot] return list_outputs[n_taps : n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, list_outputs): def outer_mitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot] return list_outputs[self.n_mit_mot : self.n_mit_mot + self.n_mit_sot]
def mitsot_taps(self): def mitsot_taps(self):
...@@ -290,8 +280,6 @@ class ScanMethodsMixin: ...@@ -290,8 +280,6 @@ class ScanMethodsMixin:
return list_inputs[offset : offset + self.n_sit_sot] return list_inputs[offset : offset + self.n_sit_sot]
def outer_sitsot(self, list_inputs): def outer_sitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return list_inputs[offset : offset + self.n_sit_sot] return list_inputs[offset : offset + self.n_sit_sot]
...@@ -301,14 +289,10 @@ class ScanMethodsMixin: ...@@ -301,14 +289,10 @@ class ScanMethodsMixin:
return list_outputs[offset : offset + self.n_sit_sot] return list_outputs[offset : offset + self.n_sit_sot]
def outer_sitsot_outs(self, list_outputs): def outer_sitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot offset = self.n_mit_mot + self.n_mit_sot
return list_outputs[offset : offset + self.n_sit_sot] return list_outputs[offset : offset + self.n_sit_sot]
def outer_nitsot(self, list_inputs): def outer_nitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = ( offset = (
1 1
+ self.n_seqs + self.n_seqs
...@@ -325,8 +309,6 @@ class ScanMethodsMixin: ...@@ -325,8 +309,6 @@ class ScanMethodsMixin:
return list_outputs[offset : offset + self.n_nit_sot] return list_outputs[offset : offset + self.n_nit_sot]
def outer_nitsot_outs(self, list_outputs): def outer_nitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
return list_outputs[offset : offset + self.n_nit_sot] return list_outputs[offset : offset + self.n_nit_sot]
...@@ -338,8 +320,6 @@ class ScanMethodsMixin: ...@@ -338,8 +320,6 @@ class ScanMethodsMixin:
return list_inputs[offset : offset + self.n_shared_outs] return list_inputs[offset : offset + self.n_shared_outs]
def outer_shared(self, list_inputs): def outer_shared(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
return list_inputs[offset : offset + self.n_shared_outs] return list_inputs[offset : offset + self.n_shared_outs]
...@@ -349,8 +329,6 @@ class ScanMethodsMixin: ...@@ -349,8 +329,6 @@ class ScanMethodsMixin:
return list_outputs[offset : offset + self.n_shared_outs] return list_outputs[offset : offset + self.n_shared_outs]
def outer_shared_outs(self, list_outputs): def outer_shared_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot offset = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + self.n_nit_sot
return list_outputs[offset : offset + self.n_shared_outs] return list_outputs[offset : offset + self.n_shared_outs]
...@@ -362,8 +340,6 @@ class ScanMethodsMixin: ...@@ -362,8 +340,6 @@ class ScanMethodsMixin:
return list_inputs[offset:] return list_inputs[offset:]
def outer_non_seqs(self, list_inputs): def outer_non_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = ( offset = (
1 1
+ self.n_seqs + self.n_seqs
......
...@@ -715,7 +715,7 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -715,7 +715,7 @@ class PushOutSeqScan(GlobalOptimizer):
reason="scanOp_pushout_seqs_ops", reason="scanOp_pushout_seqs_ops",
) )
return True return True
elif not to_keep_set and not op.as_while and not op.outer_mitmot(node): elif not to_keep_set and not op.as_while and not op.outer_mitmot(node.inputs):
# Nothing in the inner graph should be kept # Nothing in the inner graph should be kept
replace_with = {} replace_with = {}
for out, idx in to_replace_map.items(): for out, idx in to_replace_map.items():
...@@ -725,12 +725,12 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -725,12 +725,12 @@ class PushOutSeqScan(GlobalOptimizer):
ls = clean_outputs ls = clean_outputs
if out in op.inner_mitsot_outs(ls): if out in op.inner_mitsot_outs(ls):
odx = op.inner_mitsot_outs(ls).index(out) odx = op.inner_mitsot_outs(ls).index(out)
inp = op.outer_mitsot(node)[odx] inp = op.outer_mitsot(node.inputs)[odx]
st = abs(np.min(op.mitsot_taps())) st = abs(np.min(op.mitsot_taps()))
y = set_subtensor(inp[st:], _y) y = set_subtensor(inp[st:], _y)
elif out in op.inner_sitsot_outs(ls): elif out in op.inner_sitsot_outs(ls):
odx = op.inner_sitsot_outs(ls).index(out) odx = op.inner_sitsot_outs(ls).index(out)
inp = op.outer_sitsot(node)[odx] inp = op.outer_sitsot(node.inputs)[odx]
y = set_subtensor(inp[1:], _y) y = set_subtensor(inp[1:], _y)
elif out in op.inner_nitsot_outs(ls): elif out in op.inner_nitsot_outs(ls):
y = _y y = _y
...@@ -2301,7 +2301,7 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2301,7 +2301,7 @@ class PushOutDot1(GlobalOptimizer):
op = node.op op = node.op
sitsot_ins = op.inner_sitsot(op.inputs) sitsot_ins = op.inner_sitsot(op.inputs)
sitsot_outs = op.inner_sitsot_outs(op.outputs) sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_sitsot = op.outer_sitsot_outs(node) outer_sitsot = op.outer_sitsot_outs(node.outputs)
seqs = op.inner_seqs(op.inputs) seqs = op.inner_seqs(op.inputs)
for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot): for inp, out, outer_out in zip(sitsot_ins, sitsot_outs, outer_sitsot):
...@@ -2345,23 +2345,23 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2345,23 +2345,23 @@ class PushOutDot1(GlobalOptimizer):
# corresponding categories # corresponding categories
inner_seqs = op.inner_seqs(op.inputs) inner_seqs = op.inner_seqs(op.inputs)
outer_seqs = op.outer_seqs(node) outer_seqs = op.outer_seqs(node.inputs)
inner_mitmot = op.inner_mitmot(op.inputs) inner_mitmot = op.inner_mitmot(op.inputs)
outer_mitmot = op.outer_mitmot(node) outer_mitmot = op.outer_mitmot(node.inputs)
inner_mitmot_outs = op.inner_mitmot_outs(op.outputs) inner_mitmot_outs = op.inner_mitmot_outs(op.outputs)
inner_mitsot = op.inner_mitsot(op.inputs) inner_mitsot = op.inner_mitsot(op.inputs)
outer_mitsot = op.outer_mitsot(node) outer_mitsot = op.outer_mitsot(node.inputs)
inner_mitsot_outs = op.inner_mitsot_outs(op.outputs) inner_mitsot_outs = op.inner_mitsot_outs(op.outputs)
inner_sitsot = op.inner_sitsot(op.inputs) inner_sitsot = op.inner_sitsot(op.inputs)
outer_sitsot = op.outer_sitsot(node) outer_sitsot = op.outer_sitsot(node.inputs)
inner_sitsot_outs = op.inner_sitsot_outs(op.outputs) inner_sitsot_outs = op.inner_sitsot_outs(op.outputs)
outer_nitsot = op.outer_nitsot(node) outer_nitsot = op.outer_nitsot(node.inputs)
inner_nitsot_outs = op.inner_nitsot_outs(op.outputs) inner_nitsot_outs = op.inner_nitsot_outs(op.outputs)
inner_shared = op.inner_shared(op.inputs) inner_shared = op.inner_shared(op.inputs)
outer_shared = op.outer_shared(node) outer_shared = op.outer_shared(node.inputs)
inner_shared_outs = op.inner_shared_outs(op.outputs) inner_shared_outs = op.inner_shared_outs(op.outputs)
inner_non_seqs = op.inner_non_seqs(op.inputs) inner_non_seqs = op.inner_non_seqs(op.inputs)
outer_non_seqs = op.outer_non_seqs(node) outer_non_seqs = op.outer_non_seqs(node.inputs)
st = len(op.mitmot_taps()) + len(op.mitsot_taps()) st = len(op.mitmot_taps()) + len(op.mitsot_taps())
...@@ -2437,7 +2437,7 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2437,7 +2437,7 @@ class PushOutDot1(GlobalOptimizer):
_val = outer_nitsot_outs[-1] _val = outer_nitsot_outs[-1]
outer_nitsot_outs = outer_nitsot_outs[:-1] outer_nitsot_outs = outer_nitsot_outs[:-1]
if inp1 in seqs: if inp1 in seqs:
_out_seq = op.outer_seqs(node)[seqs.index(inp1)] _out_seq = op.outer_seqs(node.inputs)[seqs.index(inp1)]
# We need to clip the seq to the number of steps # We need to clip the seq to the number of steps
_out_seq = _out_seq[: node.inputs[0]] _out_seq = _out_seq[: node.inputs[0]]
sh0 = _out_seq.shape[0] sh0 = _out_seq.shape[0]
...@@ -2452,7 +2452,7 @@ class PushOutDot1(GlobalOptimizer): ...@@ -2452,7 +2452,7 @@ class PushOutDot1(GlobalOptimizer):
val = _val.reshape((sh0 * sh1, sh2)) val = _val.reshape((sh0 * sh1, sh2))
new_out = dot(out_seq, val) new_out = dot(out_seq, val)
else: else:
_out_seq = op.outer_seqs(node)[seqs.index(inp2)] _out_seq = op.outer_seqs(node.inputs)[seqs.index(inp2)]
out_seq = _out_seq.reshape( out_seq = _out_seq.reshape(
( (
_out_seq.shape[0] * _out_seq.shape[1], _out_seq.shape[0] * _out_seq.shape[1],
......
...@@ -4341,7 +4341,7 @@ class TestScan: ...@@ -4341,7 +4341,7 @@ class TestScan:
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs) inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1 assert len(inp) == 1
assert len(inp) == len(set(inp)) assert len(inp) == len(set(inp))
inp = scan_node.op.outer_non_seqs(scan_node) inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1 assert len(inp) == 1
assert len(inp) == len(set(inp)) assert len(inp) == len(set(inp))
...@@ -4409,11 +4409,11 @@ class TestScan: ...@@ -4409,11 +4409,11 @@ class TestScan:
assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:])) assert len(scan_node.inputs[1:]) == len(set(scan_node.inputs[1:]))
inp = scan_node.op.inner_seqs(scan_node.op.inputs) inp = scan_node.op.inner_seqs(scan_node.op.inputs)
assert len(inp) == 1 assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node) inp = scan_node.op.outer_seqs(scan_node.inputs)
assert len(inp) == 1 assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs) inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1 assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node) inp = scan_node.op.outer_non_seqs(scan_node.inputs)
assert len(inp) == 1 assert len(inp) == 1
@pytest.mark.slow @pytest.mark.slow
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论