提交 45c0a9a9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5518 from abergeron/fix_scan_gpuarray

Fix gpu integers in scan.
...@@ -193,7 +193,8 @@ def op_lifter(OP, cuda_only=False): ...@@ -193,7 +193,8 @@ def op_lifter(OP, cuda_only=False):
context_name = None context_name = None
# We replace if any input is a host_from_gpu # We replace if any input is a host_from_gpu
for i in node.inputs: for i in node.inputs:
if i.owner and i.owner.op == host_from_gpu: if (i.owner and i.owner.op == host_from_gpu and
move_to_gpu(i)):
context_name = i.owner.inputs[0].type.context_name context_name = i.owner.inputs[0].type.context_name
replace = True replace = True
break break
...@@ -922,7 +923,7 @@ def local_gpua_lazy_ifelse(op, context_name, inputs, outputs): ...@@ -922,7 +923,7 @@ def local_gpua_lazy_ifelse(op, context_name, inputs, outputs):
c = inputs[0] c = inputs[0]
inps = [] inps = []
for v in inputs[1:]: for v in inputs[1:]:
if isinstance(v.type, tensor.TensorType): if isinstance(v.type, tensor.TensorType) and move_to_gpu(v):
inps.append(as_gpuarray_variable(v, context_name)) inps.append(as_gpuarray_variable(v, context_name))
else: else:
inps.append(v) inps.append(v)
......
...@@ -619,7 +619,8 @@ class Scan(PureOp): ...@@ -619,7 +619,8 @@ class Scan(PureOp):
# in this case is just a int saying how many steps of this output we # in this case is just a int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same # need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int. # type of tensor as the output, it is always a scalar int.
new_inputs += self.outer_nitsot(inputs) new_inputs += [as_tensor_variable(ons)
for ons in self.outer_nitsot(inputs)]
for inner_nonseq, _outer_nonseq in zip(self.inner_non_seqs(self.inputs), for inner_nonseq, _outer_nonseq in zip(self.inner_non_seqs(self.inputs),
self.outer_non_seqs(inputs)): self.outer_non_seqs(inputs)):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq) outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论