提交 40ba569b authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed crash in a specific concatenate situation

上级 87c38921
...@@ -79,6 +79,8 @@ Crash Fix ...@@ -79,6 +79,8 @@ Crash Fix
(Pascal L., reported by Simon McGregor) (Pascal L., reported by Simon McGregor)
* Fixed issue with the MaxAndArgmax Op not properly preserving broadcastable * Fixed issue with the MaxAndArgmax Op not properly preserving broadcastable
dimensions, which could typically result in optimization crashes (Olivier D.) dimensions, which could typically result in optimization crashes (Olivier D.)
* Fixed crash when concatenating some arrays with specific broadcasting
patterns (Olivier D.)
============= =============
Release Notes Release Notes
......
...@@ -2078,7 +2078,11 @@ def local_rebroadcast_lift(node): ...@@ -2078,7 +2078,11 @@ def local_rebroadcast_lift(node):
input = node.inputs[0] input = node.inputs[0]
inode = input.owner inode = input.owner
if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1: if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
if len(input.clients) == 1: # It may happen that `input` has no client because this optimization
# is called from `apply_rebroadcast_opt`, which in particular is used
# by the `unbroadcast` function before we are in the actual function
# compilation phase.
if hasattr(input, 'clients') and len(input.clients) == 1:
rval = inode.op.make_node(T.Rebroadcast(*op.axis.items())( rval = inode.op.make_node(T.Rebroadcast(*op.axis.items())(
inode.inputs[0])).outputs inode.inputs[0])).outputs
return rval return rval
...@@ -2098,7 +2102,7 @@ def apply_rebroadcast_opt(rval): ...@@ -2098,7 +2102,7 @@ def apply_rebroadcast_opt(rval):
and local_rebroadcast_lift. and local_rebroadcast_lift.
:param rval: a Variable :param rval: a Variable
:retrun: a Variable. The same if not optimisation can be applied. :return: a Variable (the same if no optimization can be applied)
""" """
changed = True changed = True
......
...@@ -3387,6 +3387,15 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3387,6 +3387,15 @@ class T_Join_and_Split(unittest.TestCase):
else: else:
f(get_mat(3, 4), get_mat(3, 4), get_mat(2, 5)) f(get_mat(3, 4), get_mat(3, 4), get_mat(2, 5))
def test_rebroadcast(self):
"""
Regression test for a crash that used to happen when rebroadcasting.
"""
x = tensor.TensorType(floatX, [False, False, True])()
u = tensor.TensorType(floatX, [False, False, True])()
# This line used to crash.
z = tensor.concatenate([x, -u], axis=2)
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
def test_gt(self): def test_gt(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论