提交 da367cc3 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #304 from nouiz/fix_gpu_scan_opt

Fix scan crash in recent change. The code seems fine.,
......@@ -550,6 +550,8 @@ class Scan(PureOp):
return list_inputs[:self.n_seqs]
def outer_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
# Given the list of outter inputs this function grabs those
# corresponding to sequences
return list_inputs[1:1 + self.n_seqs]
......@@ -559,6 +561,8 @@ class Scan(PureOp):
return list_inputs[self.n_seqs: self.n_seqs + n_taps]
def outer_mitmot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
return list_inputs[1 + self.n_seqs:1 + self.n_seqs + self.n_mit_mot]
def inner_mitmot_outs(self, list_outputs):
......@@ -566,6 +570,8 @@ class Scan(PureOp):
return list_outputs[:n_taps]
def outer_mitmot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.ouputs
return list_outputs[:self.n_mit_mot]
def mitmot_taps(self):
......@@ -583,6 +589,8 @@ class Scan(PureOp):
self.n_seqs + ntaps_upto_sit_sot]
def outer_mitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot
return list_inputs[offset:offset + self.n_mit_sot]
......@@ -591,6 +599,8 @@ class Scan(PureOp):
return list_outputs[n_taps:n_taps + self.n_mit_sot]
def outer_mitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
return list_outputs[self.n_mit_mot:
self.n_mit_mot + self.n_mit_sot]
......@@ -606,6 +616,8 @@ class Scan(PureOp):
return list_inputs[offset:offset + self.n_sit_sot]
def outer_sitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = 1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot
return list_inputs[offset:offset + self.n_sit_sot]
......@@ -615,10 +627,14 @@ class Scan(PureOp):
return list_outputs[offset:offset + self.n_sit_sot]
def outer_sitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = self.n_mit_mot + self.n_mit_sot
return list_outputs[offset:offset + self.n_sit_sot]
def outer_nitsot(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_shared_outs)
return list_inputs[offset:offset + self.n_nit_sot]
......@@ -629,6 +645,8 @@ class Scan(PureOp):
return list_outputs[offset:offset + self.n_nit_sot]
def outer_nitsot_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot)
return list_outputs[offset:offset + self.n_nit_sot]
......@@ -640,6 +658,8 @@ class Scan(PureOp):
return list_inputs[offset:offset + self.n_shared_outs]
def outer_shared(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot)
return list_inputs[offset:offset + self.n_shared_outs]
......@@ -650,6 +670,8 @@ class Scan(PureOp):
return list_outputs[offset:offset + self.n_shared_outs]
def outer_shared_outs(self, list_outputs):
if isinstance(list_outputs, Apply):
list_outputs = list_outputs.outputs
offset = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
return list_outputs[offset:offset + self.n_shared_outs]
......@@ -663,6 +685,8 @@ class Scan(PureOp):
return list_inputs[offset:]
def outer_non_seqs(self, list_inputs):
if isinstance(list_inputs, Apply):
list_inputs = list_inputs.inputs
offset = (1 + self.n_seqs + self.n_mit_mot + self.n_mit_sot +
self.n_sit_sot + self.n_nit_sot + self.n_shared_outs)
return list_inputs[offset:]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论