提交 9c53bf31 authored 作者: Frederic Bastien's avatar Frederic Bastien

[CRASH] Fix gh-6441, fix opt crash due to wrong dtype

上级 07c889c5
......@@ -12,8 +12,6 @@ from theano.tensor import (DimShuffle, get_scalar_constant_value,
from .basic_ops import GpuFromHost, HostFromGpu, GpuAllocEmpty, GpuReshape
from .elemwise import GpuDimShuffle, GpuElemwise
_one = scal.constant(np.asarray(1.0, dtype='float32'))
def grab_cpu_scalar(v, nd):
"""
......@@ -273,7 +271,9 @@ def output_merge(cls, alpha_in, beta_in, out_in):
return None
inputs = list(targ.inputs)
inputs[out_in] = W
inputs[beta_in] = _one.clone()
dtype = inputs[beta_in].dtype
one = scal.constant(np.asarray(1.0, dtype=dtype))
inputs[beta_in] = one
with inherit_stack_trace(node.outputs):
return maker(targ, *inputs)
return opt
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论