提交 befdc27d authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Removed the use of safe_to_cpu

safe_to_cpu was one of the reason that I was importing cuda. It was a misunderstanding on my part, and the way is done now is somewhat more sane.
上级 db9dc771
......@@ -51,7 +51,7 @@ from theano.sandbox import cuda
import scan_op
import scan_utils
from scan_utils import safe_new, safe_to_cpu, traverse
from scan_utils import safe_new, traverse
# Logging function for sending warning or info
_logger = logging.getLogger('theano.scan_module.scan')
......@@ -892,7 +892,7 @@ def scan( fn
##
### Step 8. Compute the outputs using the scan op
##
scan_inputs = ( scan_seqs +
_scan_inputs = ( scan_seqs +
mit_mot_scan_inputs +
mit_sot_scan_inputs +
sit_sot_scan_inputs +
......@@ -901,7 +901,15 @@ def scan( fn
other_shared_scan_args +
other_scan_args )
scan_inputs = [safe_to_cpu(x) for x in ([actual_n_steps] + scan_inputs)]
scan_inputs = []
for arg in [actual_n_steps]+ _scan_inputs:
try:
arg = tensor.as_tensor_variable(arg)
except TypeError:
# This happens for Random States for e.g. but it is a good way
# to make sure no input is a cuda ndarrays
pass
scan_inputs += [arg]
scan_outs = local_op(* scan_inputs )
if type(scan_outs) not in (list,tuple):
scan_outs = [scan_outs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论