提交 a0290484 authored 作者: Frederic's avatar Frederic

Fix Warning in AdvancedSubtensor.infer_shape.

This just remove warning and to not implement the case where the warning was generated.
上级 4c513ba6
...@@ -1893,7 +1893,9 @@ class AdvancedSubtensor(Op): ...@@ -1893,7 +1893,9 @@ class AdvancedSubtensor(Op):
# Really special case # Really special case
if len(ishapes) == 3: if len(ishapes) == 3:
xshp, ind1shp, ind2shp = ishapes xshp, ind1shp, ind2shp = ishapes
if len(xshp) == 2 and len(ind1shp) == 1 and len(ind2shp) == 1: if (len(xshp) == 2 and
ind1shp is not None and len(ind1shp) == 1 and
ind2shp is not None and len(ind2shp) == 1):
# if the graph is correct, we can assume ind1shp[0] and # if the graph is correct, we can assume ind1shp[0] and
# ind2shp[0] will have the same value. # ind2shp[0] will have the same value.
# Try to return the one closest to the graph input. # Try to return the one closest to the graph input.
......
...@@ -1455,3 +1455,24 @@ class TestInferShape(utt.InferShapeTester): ...@@ -1455,3 +1455,24 @@ class TestInferShape(utt.InferShapeTester):
self._compile_and_check([admat, advec], self._compile_and_check([admat, advec],
[set_subtensor(admat[aivec_val, bivec_val], advec)], [set_subtensor(admat[aivec_val, bivec_val], advec)],
[admat_val, advec_val], AdvancedIncSubtensor) [admat_val, advec_val], AdvancedIncSubtensor)
def test_adv_sub(self):
admat = dmatrix()
aivec = lvector()
bivec = lvector()
admat_val = rand(5, 4)
aivec_val = [1, 3, 2]
bivec_val = [0, 3, 3]
self._compile_and_check([admat, aivec, bivec],
[admat[aivec, bivec]],
[admat_val, aivec_val, bivec_val], AdvancedSubtensor)
# Test case that aren't implemented, but make sure they do not crash.
self._compile_and_check([admat, aivec],
[admat[aivec, 1:3]],
[admat_val, aivec_val], AdvancedSubtensor,
check_topo=False)
self._compile_and_check([admat, aivec],
[admat[1:3, aivec]],
[admat_val, aivec_val], AdvancedSubtensor,
check_topo=False)
...@@ -191,7 +191,7 @@ class InferShapeTester(unittest.TestCase): ...@@ -191,7 +191,7 @@ class InferShapeTester(unittest.TestCase):
self.mode = mode.including("canonicalize") self.mode = mode.including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls, def _compile_and_check(self, inputs, outputs, numeric_inputs, cls,
excluding=None, warn=True): excluding=None, warn=True, check_topo=True):
"""This tests the infer_shape method only """This tests the infer_shape method only
When testing with input values with shapes that take the same When testing with input values with shapes that take the same
...@@ -204,6 +204,9 @@ class InferShapeTester(unittest.TestCase): ...@@ -204,6 +204,9 @@ class InferShapeTester(unittest.TestCase):
matrices will not detect the problem. If warn=True, we emit a matrices will not detect the problem. If warn=True, we emit a
warning when testing with such values. warning when testing with such values.
:param check_topo: If True, we check that the Op where removed
from the graph. False is useful to test not implemented case.
""" """
mode = self.mode mode = self.mode
if excluding: if excluding:
...@@ -236,8 +239,9 @@ class InferShapeTester(unittest.TestCase): ...@@ -236,8 +239,9 @@ class InferShapeTester(unittest.TestCase):
mode=mode) mode=mode)
#theano.printing.debugprint(shapes_function) #theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function. # Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.fgraph.toposort() if check_topo:
assert not any(isinstance(t.op, cls) for t in topo_shape) topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
topo_out = outputs_function.maker.fgraph.toposort() topo_out = outputs_function.maker.fgraph.toposort()
assert any(isinstance(t.op, cls) for t in topo_out) assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape. # Check that the shape produced agrees with the actual shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论