提交 15f89be2 authored 作者: Melanie Ducoffe's avatar Melanie Ducoffe

adds dtype in props of AllocEmpty and remove debug print

上级 5a09224f
...@@ -2042,8 +2042,6 @@ def local_dot22_to_dot22scalar(node): ...@@ -2042,8 +2042,6 @@ def local_dot22_to_dot22scalar(node):
dtype=d.dtype), d.type.dtype) dtype=d.dtype), d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
# Deprecated :
#dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0], z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1]) d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype)) zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
...@@ -2083,23 +2081,16 @@ def local_dot22_to_dot22scalar(node): ...@@ -2083,23 +2081,16 @@ def local_dot22_to_dot22scalar(node):
a = T.cast(i_scalar[scalar_idx], d.type.dtype) a = T.cast(i_scalar[scalar_idx], d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
if len(o) == 0: if len(o) == 0:
# Deprecated
#return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0], z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1]) d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype)) zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
return [gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], zero)] return [gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], zero)]
else: else:
# Deprecated
#return [T.mul(_dot22scalar(d.owner.inputs[0],
# d.owner.inputs[1], a), *o)]
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0], z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1]) d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype)) zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
return [T.mul(gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], return [T.mul(gemm(z, a, d.owner.inputs[0], d.owner.inputs[1],
zero), *o)] zero), *o)]
# must happen after gemm as the gemm optimizer don't understant # must happen after gemm as the gemm optimizer don't understant
# dot22scalar and gemm give more speed up then dot22scalar # dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar', blas_optdb.register('local_dot22_to_dot22scalar',
......
...@@ -930,6 +930,7 @@ def test_dot22scalar(): ...@@ -930,6 +930,7 @@ def test_dot22scalar():
mode=mode_blas_opt) mode=mode_blas_opt)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
check_dot22scalar(f, 5) check_dot22scalar(f, 5)
#print (av.dtype, bv.dtype, cv.dtype)
f(av, bv, cv) f(av, bv, cv)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论