提交 f629d018 authored 作者: Frederic's avatar Frederic

Fix to the fix :) ScalarVariable and random state type don't have a dtype field.

上级 211a3307
......@@ -246,9 +246,9 @@ def local_gpu_dot_to_dot22(node):
# In case the got do input upcast, we much check that we can
# make it run on the gpu.
if node.outputs[0].dtype != 'float32':
return False
if node.op == gpu_from_host:
if node.outputs[0].type.dtype != 'float32':
return False
host_input = node.inputs[0]
if host_input.owner and host_input.owner.op == tensor.basic.dot:
x, y = host_input.owner.inputs
......@@ -269,6 +269,8 @@ def local_gpu_dot_to_dot22(node):
return [GpuReshape(1)(gpu_dot22(gpu_x, gpu_y), shape_out)]
if node.op == tensor.basic.dot:
if node.outputs[0].type.dtype != 'float32':
return False
if numpy.any([(i.owner and i.owner.op == host_from_gpu) for i in node.inputs]):
x, y = node.inputs
if _is_real_vector(x) and _is_real_matrix(y):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论