提交 761fe6e7 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make the Gemm optimizations work on float16.

上级 035e9e13
...@@ -1097,14 +1097,14 @@ def _as_scalar(res, dtype=None): ...@@ -1097,14 +1097,14 @@ def _as_scalar(res, dtype=None):
def _is_real_matrix(res): def _is_real_matrix(res):
return (res.type.dtype in ('float32', 'float64') and return (res.type.dtype in ('float16', 'float32', 'float64') and
res.type.ndim == 2 and res.type.ndim == 2 and
res.type.broadcastable[0] is False and res.type.broadcastable[0] is False and
res.type.broadcastable[1] is False) # cope with tuple vs. list res.type.broadcastable[1] is False) # cope with tuple vs. list
def _is_real_vector(res): def _is_real_vector(res):
return (res.type.dtype in ('float32', 'float64') and return (res.type.dtype in ('float16', 'float32', 'float64') and
res.type.ndim == 1 and res.type.ndim == 1 and
res.type.broadcastable[0] is False) res.type.broadcastable[0] is False)
...@@ -1199,7 +1199,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients): ...@@ -1199,7 +1199,7 @@ def _gemm_canonicalize(r, scale, rval, maxclients):
return None return None
if ((r.type.ndim not in (1, 2)) or if ((r.type.ndim not in (1, 2)) or
r.type.dtype not in ('float32', 'float64', r.type.dtype not in ('float16', 'float32', 'float64',
'complex64', 'complex128')): 'complex64', 'complex128')):
rval.append(scaled(r)) rval.append(scaled(r))
return rval return rval
...@@ -1532,7 +1532,7 @@ class Dot22(GemmRelated): ...@@ -1532,7 +1532,7 @@ class Dot22(GemmRelated):
""" """
def make_node(self, x, y): def make_node(self, x, y):
dtypes = ('float32', 'float64', 'complex64', 'complex128') dtypes = ('float16', 'float32', 'float64', 'complex64', 'complex128')
if x.type.ndim != 2 or x.type.dtype not in dtypes: if x.type.ndim != 2 or x.type.dtype not in dtypes:
raise TypeError(x) raise TypeError(x)
if y.type.ndim != 2 or y.type.dtype not in dtypes: if y.type.ndim != 2 or y.type.dtype not in dtypes:
...@@ -1625,7 +1625,7 @@ def local_dot_to_dot22(node): ...@@ -1625,7 +1625,7 @@ def local_dot_to_dot22(node):
x, y, x.type, y.type) x, y, x.type, y.type)
return return
if y.type.dtype in ['float32', 'float64', 'complex64', 'complex128']: if y.type.dtype in ['float16', 'float32', 'float64', 'complex64', 'complex128']:
if x.ndim == 2 and y.ndim == 2: if x.ndim == 2 and y.ndim == 2:
# print "local_dot_to_dot22: MM" # print "local_dot_to_dot22: MM"
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论