提交 e0348d6f authored 作者: Frederic's avatar Frederic

Add the minimal number of condition

上级 9828e7d1
...@@ -1717,14 +1717,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1717,14 +1717,13 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# Remove Alloc in DimShuffle # Remove Alloc in DimShuffle
elif i.owner and dimshuffled_alloc(i): elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim assert i.type.ndim == cmp_op.type.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if theano.config.experimental.local_alloc_elemwise_assert:
and not all([same_shape(i, cmp_op, idx, idx) assert_cond = [T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]])): if not i.type.broadcastable[idx] and
assert_op = assert_(assert_op, not same_shape(i, cmp_op, idx, idx)]
*[T.eq(i.shape[idx], cmp_op.shape[idx]) if assert_cond:
for idx in xrange(i.type.ndim) assert_op = assert_(assert_op, *assert_cond)
if not i.type.broadcastable[idx]])
alloc_input = i.owner.inputs[0].owner.inputs[0] alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim: if alloc_input.ndim != i.owner.inputs[0].ndim:
# The alloc can add dimension to the value # The alloc can add dimension to the value
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论