提交 da367cc3 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #304 from nouiz/fix_gpu_scan_opt

Fix scan crash in recent change. The code seems fine.,
...@@ -550,6 +550,8 @@ class Scan(PureOp): ...@@ -550,6 +550,8 @@ class Scan(PureOp):
return list_inputs[:self.n_seqs] return list_inputs[:self.n_seqs]
def outer_seqs(self, list_inputs): def outer_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
# Given the list of outter inputs this function grabs those # Given the list of outter inputs this function grabs those
# corresponding to sequences # corresponding to sequences
return list_inputs[1:1 + self.n_seqs] return list_inputs[1:1 + self.n_seqs]
...@@ -559,6 +561,8 @@ class Scan(PureOp): ...@@ -559,6 +561,8 @@ class Scan(PureOp):
return list_inputs[self.n_seqs: self.n_seqs + n_taps] return list_inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, list_inputs): def outer_mitmot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
return list_inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot] return list_inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self, list_outputs): def inner_mitmot_outs(self, list_outputs):
...@@ -566,6 +570,8 @@ class Scan(PureOp): ...@@ -566,6 +570,8 @@ class Scan(PureOp):
return list_outputs[:n_taps] return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs): def outer_mitmot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.ouputs
return list_outputs[:self.n_mit_mot] return list_outputs[:self.n_mit_mot]
def mitmot_taps(self): def mitmot_taps(self):
...@@ -583,6 +589,8 @@ class Scan(PureOp): ...@@ -583,6 +589,8 @@ class Scan(PureOp):
self.n_seqs + ntaps_upto_sit_sot] self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, list_inputs): def outer_mitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot offset = 1 + self.n_seqs + self.n_mit_mot
return list_inputs[offset:offset + self.n_mit_sot] return list_inputs[offset:offset + self.n_mit_sot]
...@@ -591,6 +599,8 @@ class Scan(PureOp): ...@@ -591,6 +599,8 @@ class Scan(PureOp):
return list_outputs[n_taps:n_taps + self.n_mit_sot] return list_outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, list_outputs): def outer_mitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[self.n_mit_mot: return list_outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot] self.n_mit_mot + self.n_mit_sot]
...@@ -606,6 +616,8 @@ class Scan(PureOp): ...@@ -606,6 +616,8 @@ class Scan(PureOp):
return list_inputs[offset:offset + self.n_sit_sot] return list_inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self, list_inputs): def outer_sitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return list_inputs[offset:offset + self.n_sit_sot] return list_inputs[offset:offset + self.n_sit_sot]
...@@ -615,10 +627,14 @@ class Scan(PureOp): ...@@ -615,10 +627,14 @@ class Scan(PureOp):
return list_outputs[offset:offset + self.n_sit_sot] return list_outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, list_outputs): def outer_sitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot offset = self.n_mit_mot + self.n_mit_sot
return list_outputs[offset:offset + self.n_sit_sot] return list_outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, list_inputs): def outer_nitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs) self.n_sit_sot + self.n_shared_outs)
return list_inputs[offset:offset + self.n_nit_sot] return list_inputs[offset:offset + self.n_nit_sot]
...@@ -629,6 +645,8 @@ class Scan(PureOp): ...@@ -629,6 +645,8 @@ class Scan(PureOp):
return list_outputs[offset:offset + self.n_nit_sot] return list_outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, list_outputs): def outer_nitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot) offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return list_outputs[offset:offset + self.n_nit_sot] return list_outputs[offset:offset + self.n_nit_sot]
...@@ -640,6 +658,8 @@ class Scan(PureOp): ...@@ -640,6 +658,8 @@ class Scan(PureOp):
return list_inputs[offset:offset + self.n_shared_outs] return list_inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, list_inputs): def outer_shared(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot) self.n_sit_sot)
return list_inputs[offset:offset + self.n_shared_outs] return list_inputs[offset:offset + self.n_shared_outs]
...@@ -650,6 +670,8 @@ class Scan(PureOp): ...@@ -650,6 +670,8 @@ class Scan(PureOp):
return list_outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, list_outputs): def outer_shared_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot + offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
return list_outputs[offset:offset + self.n_shared_outs] return list_outputs[offset:offset + self.n_shared_outs]
...@@ -663,6 +685,8 @@ class Scan(PureOp): ...@@ -663,6 +685,8 @@ class Scan(PureOp):
return list_inputs[offset:] return list_inputs[offset:]
def outer_non_seqs(self, list_inputs): def outer_non_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot + offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs) self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return list_inputs[offset:] return list_inputs[offset:]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论