提交 01078889 authored 作者: --global's avatar --global

Revert previous change to replace dot22 with gemm instead of dot22scalar

上级 c7a81516
......@@ -2043,12 +2043,8 @@ def local_dot22_to_dot22scalar(node):
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx],
dtype=d.dtype), d.type.dtype)
assert not a.type.ndim
dot = _dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
dot = gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], zero)
# The other inputs to the original node that were
# neither part of the dot22 or this mul should be
# factors in the returned "mul" node.
......@@ -2083,16 +2079,12 @@ def local_dot22_to_dot22scalar(node):
a = T.cast(i_scalar[scalar_idx], d.type.dtype)
assert not a.type.ndim
if len(o) == 0:
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
return [gemm(z, a, d.owner.inputs[0], d.owner.inputs[1], zero)]
return [_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)]
else:
z = T.AllocEmpty(d.owner.inputs[0].dtype)(d.owner.inputs[0].shape[0],
d.owner.inputs[1].shape[1])
zero = T.as_tensor_variable(numpy.asarray(0, dtype=a.dtype))
return [T.mul(gemm(z, a, d.owner.inputs[0], d.owner.inputs[1],
zero), *o)]
return [T.mul(_dot22scalar(d.owner.inputs[0],
d.owner.inputs[1], a), *o)]
# must happen after gemm as the gemm optimizer don't understant
# dot22scalar and gemm give more speed up then dot22scalar
blas_optdb.register('local_dot22_to_dot22scalar',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论