Commit 1e507512, authored by Pascal Lamblin

Merge pull request #4 from nouiz/lamblin-fix_inc_set_subtensor1

Remove useless warning and make the graph more optimized in old supported cases.
@@ -1119,8 +1119,13 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
# We also explicitly duplicate y to its broadcasted shape
# before we partially flatten it to inner_x dimension. This is
# not strictly needed in all cases, but it is easier this way.
if y.ndim > 0:
# This if is needed to prevent some useless warning about
# old code bug.
expanded_y = alloc(y, *[x.shape[i] for i in range(x.ndim)])
flattened_y = expanded_y.flatten(inner_x.ndim)
else:
flattened_y = y
# Warn if this code path would have produced wrong results in the past
if config.warn.inc_set_subtensor1:
......
@@ -1035,15 +1035,8 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
m = matrix('m')
i = lmatrix('i')
# That test actually gave correct results, the warning is
# a bit too broad
orig_warn = config.warn.inc_set_subtensor1
try:
config.warn.inc_set_subtensor1 = False
m1 = set_subtensor(m[:, i], 0)
m2 = inc_subtensor(m[:, i], 1)
finally:
config.warn.inc_set_subtensor1 = orig_warn
f = theano.function([m, i], [m1, m2])
@@ -1060,10 +1053,10 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
assert numpy.allclose(m1_val, m1_ref), (m1_val, m1_ref)
assert numpy.allclose(m2_val, m2_ref), (m2_val, m2_ref)
-    def test_adv1_inc_sub_notlastdim_1dval(self):
+    def test_adv1_inc_sub_notlastdim_1_2dval_broadcast(self):
         # Test that taking 1-dimensional advanced indexing
         # over a dimension that's not the first (outer-most),
-        # and incrementing/setting a 1D value works.
+        # and incrementing/setting with broadcast
m = matrix('m')
# Test for both vector and matrix as index
@@ -1097,16 +1090,16 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
finally:
    config.warn.inc_set_subtensor1 = orig_warn
-    def test_adv1_inc_sub_notlastdim_2dval(self):
+    def test_adv1_inc_sub_notlastdim_1_2dval_no_broadcast(self):
         # Test that taking 1-dimensional advanced indexing
         # over a dimension that's not the first (outer-most),
-        # and incrementing/setting a 2D value works.
+        # and incrementing/setting without broadcast
m = matrix('m')
# Test for both vector and matrix as index
sym_i = (lvector('i'), lmatrix('i'))
shape_i = ((4,), (4, 2))
-        shape_val = ((3, 1), (3, 1, 1))
+        shape_val = ((3, 4), (3, 4, 2))
# Disable the warning emitted for that case
orig_warn = config.warn.inc_set_subtensor1
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment