提交 a0290484 authored 作者: Frederic's avatar Frederic

Fix Warning in AdvancedSubtensor.infer_shape.

This just remove warning and to not implement the case where the warning was generated.
上级 4c513ba6
......@@ -1893,7 +1893,9 @@ class AdvancedSubtensor(Op):
# Really special case
if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes
if len(xshp) == 2 and len(ind1shp) == 1 and len(ind2shp) == 1:
if (len(xshp) == 2 and
ind1shp is not None and len(ind1shp) == 1 and
ind2shp is not None and len(ind2shp) == 1):
# if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value.
# Try to return the one closest to the graph input.
......
......@@ -1455,3 +1455,24 @@ class TestInferShape(utt.InferShapeTester):
self._compile_and_check([admat, advec],
[set_subtensor(admat[aivec_val, bivec_val], advec)],
[admat_val, advec_val], AdvancedIncSubtensor)
def test_adv_sub(self):
admat = dmatrix()
aivec = lvector()
bivec = lvector()
admat_val = rand(5, 4)
aivec_val = [1, 3, 2]
bivec_val = [0, 3, 3]
self._compile_and_check([admat, aivec, bivec],
[admat[aivec, bivec]],
[admat_val, aivec_val, bivec_val], AdvancedSubtensor)
# Test case that aren't implemented, but make sure they do not crash.
self._compile_and_check([admat, aivec],
[admat[aivec, 1:3]],
[admat_val, aivec_val], AdvancedSubtensor,
check_topo=False)
self._compile_and_check([admat, aivec],
[admat[1:3, aivec]],
[admat_val, aivec_val], AdvancedSubtensor,
check_topo=False)
......@@ -191,7 +191,7 @@ class InferShapeTester(unittest.TestCase):
self.mode = mode.including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls,
excluding=None, warn=True):
excluding=None, warn=True, check_topo=True):
"""This tests the infer_shape method only
When testing with input values with shapes that take the same
......@@ -204,6 +204,9 @@ class InferShapeTester(unittest.TestCase):
matrices will not detect the problem. If warn=True, we emit a
warning when testing with such values.
:param check_topo: If True, we check that the Op where removed
from the graph. False is useful to test not implemented case.
"""
mode = self.mode
if excluding:
......@@ -236,8 +239,9 @@ class InferShapeTester(unittest.TestCase):
mode=mode)
#theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
if check_topo:
topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
topo_out = outputs_function.maker.fgraph.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论