提交 111400e4 authored 作者: Frederic's avatar Frederic

Fix opt crash with other dtype then float32, the code isn't ready.

上级 047c8913
...@@ -249,7 +249,12 @@ def local_gpua_careduce(node): ...@@ -249,7 +249,12 @@ def local_gpua_careduce(node):
isinstance(node.op.scalar_op, scalar.basic.Mul)): isinstance(node.op.scalar_op, scalar.basic.Mul)):
x, = node.inputs x, = node.inputs
greduce = GpuCAReduce(node.op.scalar_op, axis=node.op.axis) greduce = GpuCAReduce(node.op.scalar_op, axis=node.op.axis)
if greduce.supports_c_code([gpu_from_host(x)]): if x.dtype != "float32":
return
gvar = greduce(x)
#We need to have the make node called, otherwise the mask can
#be None
if gvar.owner.op.supports_c_code([gpu_from_host(x)]):
return greduce return greduce
else: else:
# Try to make a simpler pattern based on reshaping # Try to make a simpler pattern based on reshaping
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论