提交 f0217c61 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

add assertion that dtype is float32 into the function

上级 f835b5ac
...@@ -1525,8 +1525,9 @@ def gpuScanOptimization(node): ...@@ -1525,8 +1525,9 @@ def gpuScanOptimization(node):
local_fgraph = gof.FunctionGraph(tmp_in, tmp_out) local_fgraph = gof.FunctionGraph(tmp_in, tmp_out)
_cmodule_key = gof.CLinker.cmodule_key_(local_fgraph, []) _cmodule_key = gof.CLinker.cmodule_key_(local_fgraph, [])
info['gpu_hash'] = hash(_cmodule_key) info['gpu_hash'] = hash(_cmodule_key)
typeConstructor = lambda broadcastable, dtype: CudaNdarrayType( def typeConstructor(broadcastable, dtype):
broadcastable=broadcastable) assert dtype == 'float32'
return CudaNdarrayType(broadcastable=broadcastable)
_outputs = scan_op.Scan( _outputs = scan_op.Scan(
scan_ins, scan_ins,
scan_outs, scan_outs,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论