提交 660916d6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't try to create invalid `BatchedDot` in `specialize_matmul_to_batched_dot` rewrite

上级 ad3b4ac7
......@@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
"""
x, y = node.inputs
if x.type.ndim < 3:
# This doesn't actually have a batch dimension
return None
# BatchedDot does not allow implicit broadcasting of the batch dimensions
# We do not want to explicitly broadcast as it may result in huge arrays
if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
......@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
if len(x_shape) > 3:
# If we have more than one batch dim, ravel it
x = x.reshape((-1, x_shape[-2], x_shape[-1]))
if len(y_shape) > 3:
y = y.reshape((-1, y_shape[-2], y_shape[-1]))
new_out = _batched_dot(x, y)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论