提交 08b447f7 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5508 from nouiz/scan_gpuarray

Force scan n_steps on the CPU.
...@@ -66,7 +66,7 @@ from theano.gof import PureOp, Apply ...@@ -66,7 +66,7 @@ from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
from theano.gof.toolbox import NoOutputFromInplace from theano.gof.toolbox import NoOutputFromInplace
from theano.compat import izip from theano.compat import izip
from theano.tensor import TensorType from theano.tensor import as_tensor_variable, TensorType
from theano.tensor.opt import Shape_i from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined, DisconnectedType, NullType from theano.gradient import grad_undefined, DisconnectedType, NullType
from six import string_types from six import string_types
...@@ -349,7 +349,8 @@ class Scan(PureOp): ...@@ -349,7 +349,8 @@ class Scan(PureOp):
assert n_outer_ins == n_inner_ins, \ assert n_outer_ins == n_inner_ins, \
("The number of inputs given to the inner function of scan" ("The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan.") " does not match the number of inputs given to scan.")
new_inputs = [inputs[0]] # Force the inputs to be on the CPU
new_inputs = [as_tensor_variable(inputs[0])]
# assert dtype is consistent # assert dtype is consistent
err_msg1 = ('When compiling the inner function of scan (the ' err_msg1 = ('When compiling the inner function of scan (the '
'function called by scan in each of its iterations) ' 'function called by scan in each of its iterations) '
......
...@@ -834,9 +834,14 @@ class MakeVectorPrinter: ...@@ -834,9 +834,14 @@ class MakeVectorPrinter:
if r.owner is None: if r.owner is None:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector): elif isinstance(r.owner.op, MakeVector):
return "[%s]" % ", ".join( old_precedence = getattr(pstate, 'precedence', None)
pstate.pprinter.process(input, pstate.clone(precedence=1000)) try:
for input in r.owner.inputs) pstate.precedence = 1000
s = [pstate.pprinter.process(input)
for input in r.owner.inputs]
finally:
pstate.precedence = old_precedence
return "[%s]" % ", ".join(s)
else: else:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论