Commit 660916d6, authored by Ricardo Vieira, committed by Ricardo Vieira

Don't try to create invalid `BatchedDot` in `specialize_matmul_to_batched_dot` rewrite

Parent ad3b4ac7
@@ -916,6 +916,10 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     """
     x, y = node.inputs
 
+    if x.type.ndim < 3:
+        # This doesn't actually have a batch dimension
+        return None
+
     # BatchedDot does not allow implicit broadcasting of the batch dimensions
     # We do not want to explicitly broadcast as it may result in huge arrays
     if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]:
@@ -926,6 +930,7 @@ def specialize_matmul_to_batched_dot(fgraph, node):
     if len(x_shape) > 3:
         # If we have more than one batch dim, ravel it
         x = x.reshape((-1, x_shape[-2], x_shape[-1]))
+    if len(y_shape) > 3:
         y = y.reshape((-1, y_shape[-2], y_shape[-1]))
 
     new_out = _batched_dot(x, y)
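For illustration only, the shape handling in the two hunks above can be sketched on plain tuples. The helper name `batched_dot_shapes` is hypothetical and not part of the library. It assumes static, fully known shapes, and it omits the broadcastable-pattern check that the real rewrite performs before reshaping.

```python
from math import prod


def batched_dot_shapes(x_shape, y_shape):
    """Hypothetical helper mirroring the rewrite's shape logic on plain tuples.

    Returns the pair of shapes that would be passed on to the batched dot,
    or None when the rewrite should not apply.
    """
    if len(x_shape) < 3:
        # No batch dimension at all, so there is nothing to specialize
        return None

    def ravel(shape):
        # Collapse all leading batch dims into one; 3D shapes pass through
        if len(shape) > 3:
            return (prod(shape[:-2]), shape[-2], shape[-1])
        return tuple(shape)

    # Each input is raveled based on its own number of dimensions
    return ravel(x_shape), ravel(y_shape)


print(batched_dot_shapes((4, 5), (5, 6)))
print(batched_dot_shapes((2, 4, 5), (2, 5, 6)))
print(batched_dot_shapes((2, 3, 4, 5), (2, 3, 5, 6)))
```

The first call corresponds to the new early return: a plain matrix product has no batch dimension, so no `BatchedDot` should be built for it.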
......