提交 761fe6e7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make the Gemm optimizations work on float16.

上级 035e9e13
......@@ -1097,14 +1097,14 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res):
return (res.type.dtype in ('float32', 'float64') and
return (res.type.dtype in ('float16', 'float32', 'float64') and
res.type.ndim == 2 and
res.type.broadcastable[0] is False and
res.type.broadcastable[1] is False) # cope with tuple vs. list
def _is_real_vector(res):
return (res.type.dtype in ('float32', 'float64') and
return (res.type.dtype in ('float16', 'float32', 'float64') and
res.type.ndim == 1 and
res.type.broadcastable[0] is False)
......@@ -1199,7 +1199,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
return None
if ((r.type.ndim not in (1, 2)) or
r.type.dtype not in ('float32', 'float64',
r.type.dtype not in ('float16', 'float32', 'float64',
'complex64', 'complex128')):
rval.append(scaled(r))
return rval
......@@ -1532,7 +1532,7 @@ class Dot22(GemmRelated):
"""
def make_node(self, x, y):
dtypes = ('float32', 'float64', 'complex64', 'complex128')
dtypes = ('float16', 'float32', 'float64', 'complex64', 'complex128')
if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x)
if y.type.ndim != 2 or y.type.dtype not in dtypes:
......@@ -1625,7 +1625,7 @@ def local_dot_to_dot22(node):
x, y, x.type, y.type)
return
if y.type.dtype in ['float32', 'float64', 'complex64', 'complex128']:
if y.type.dtype in ['float16', 'float32', 'float64', 'complex64', 'complex128']:
if x.ndim == 2 and y.ndim == 2:
# print "local_dot_to_dot22: MM"
return [_dot22(*node.inputs)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论