提交 b59a0c80 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Fix some floatX issues

上级 0c2a2b08
......@@ -708,17 +708,19 @@ def numba_funcify_Sum(op, node, **kwargs):
np_acc_dtype = np.dtype(acc_dtype)
out_dtype = np.dtype(node.outputs[0].dtype)
if ndim_input == len(axes):
@numba_njit(fastmath=True)
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype)
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
elif len(axes) == 0:
@numba_njit(fastmath=True)
def impl_sum(array):
return array
return np.asarray(array, dtype=out_dtype)
else:
impl_sum = numba_funcify_CAReduce(op, node, **kwargs)
......
......@@ -97,7 +97,7 @@ def test_Clip(v, min, max):
],
)
def test_Composite(inputs, input_values, scalar_fn):
composite_inputs = [aes.float64(i.name) for i in inputs]
composite_inputs = [aes.ScalarType(config.floatX)(name=i.name) for i in inputs]
comp_op = Elemwise(Composite(composite_inputs, [scalar_fn(*composite_inputs)]))
out_fg = FunctionGraph(inputs, [comp_op(*inputs)])
compare_numba_and_py(out_fg, input_values)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论