提交 08b447f7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5508 from nouiz/scan_gpuarray

Force scan n_steps on the CPU.
......@@ -66,7 +66,7 @@ from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern
from theano.gof.toolbox import NoOutputFromInplace
from theano.compat import izip
from theano.tensor import TensorType
from theano.tensor import as_tensor_variable, TensorType
from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined, DisconnectedType, NullType
from six import string_types
......@@ -349,7 +349,8 @@ class Scan(PureOp):
assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.")
new_inputs = [inputs[0]]
# Force the inputs to be on the CPU
new_inputs = [as_tensor_variable(inputs[0])]
# assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan (the '
'function called by scan in each of its iterations) '
......
......@@ -834,9 +834,14 @@ class MakeVectorPrinter:
if r.owner is None:
raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector):
return "[%s]" % ", ".join(
pstate.pprinter.process(input, pstate.clone(precedence=1000))
for input in r.owner.inputs)
old_precedence = getattr(pstate, 'precedence', None)
try:
pstate.precedence = 1000
s = [pstate.pprinter.process(input)
for input in r.owner.inputs]
finally:
pstate.precedence = old_precedence
return "[%s]" % ", ".join(s)
else:
raise TypeError("Can only print make_vector.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论