提交 d3bba908 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make local_dot_to_dot22 stop failing for float16.

上级 2c1d697c
...@@ -1728,7 +1728,7 @@ def local_dot_to_dot22(node): ...@@ -1728,7 +1728,7 @@ def local_dot_to_dot22(node):
x, y, x.type, y.type) x, y, x.type, y.type)
return return
if y.type.dtype.startswith('float') or y.type.dtype.startswith('complex'): if y.type.dtype in ['float32', 'float64', 'complex64', 'complex128']:
if x.ndim == 2 and y.ndim == 2: if x.ndim == 2 and y.ndim == 2:
# print "local_dot_to_dot22: MM" # print "local_dot_to_dot22: MM"
return [_dot22(*node.inputs)] return [_dot22(*node.inputs)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论