Commit 1dad197f, authored by lamblin

Merge pull request #477 from nouiz/infer_shape

Infer shape
@@ -140,29 +140,7 @@ class T_transpose(unittest.TestCase):
        self.assertTrue(vta.shape == (3, 5))
class SparseInferShapeTester(utt.InferShapeTester):
def setUp(self):
    """Seed the RNG and pick the compilation mode for the shape tests."""
    utt.seed_rng()
    # "canonicalize" appears to be the smallest optimization set that
    # still enables the shape_i optimizations, so we do not have to
    # enumerate those optimizations one by one.
    base_mode = theano.compile.get_default_mode()
    self.mode = base_mode.including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
    """Check that ``cls``'s infer_shape is used and is correct.

    Compiles two functions: one computing `outputs`, one computing only
    their shapes. Verifies that:
      1. the shape-only graph no longer applies `cls` (its infer_shape
         replaced the actual computation),
      2. the output graph still applies `cls`,
      3. the inferred shapes agree with the actually computed ones.

    :param inputs: symbolic inputs of the graph.
    :param outputs: symbolic outputs whose shapes are checked.
    :param numeric_inputs: concrete values to evaluate the graphs on.
    :param cls: the Op class whose infer_shape is under test.
    """
    outputs_function = theano.function(inputs, outputs, mode=self.mode)
    shapes_function = theano.function(inputs, [o.shape for o in outputs],
                                      mode=self.mode)
    # NOTE: a leftover theano.printing.debugprint(shapes_function) call
    # was removed here; it dumped the whole graph on every invocation.
    # Check that the Op is removed from the compiled shape function.
    topo_shape = shapes_function.maker.env.toposort()
    assert not any(isinstance(t.op, cls) for t in topo_shape)
    # The graph computing the outputs, however, must still use the Op.
    topo_out = outputs_function.maker.env.toposort()
    assert any(isinstance(t.op, cls) for t in topo_out)
    # Check that the shape produced agrees with the actual shape.
    numeric_outputs = outputs_function(*numeric_inputs)
    numeric_shapes = shapes_function(*numeric_inputs)
    for out, shape in zip(numeric_outputs, numeric_shapes):
        assert numpy.all(out.shape == shape)
def test_getitem_2d(self):
    raise SkipTest('infer_shape not implemented for GetItem2d yet')
...
@@ -5807,10 +5807,16 @@ class SortOp(theano.Op):
        z[0] = numpy.sort(a, axis, self.kind, self.order)
def infer_shape(self, node, inputs_shapes):
    """Shape of the sorted output.

    With ``axis=None`` the input is flattened before being sorted, so
    the output is a vector with as many elements as the input has in
    total; otherwise sorting happens along a single axis and the shape
    is unchanged.
    """
    if (isinstance(node.inputs[1], Constant) and
            node.inputs[1].data is None):
        # axis = None: the array is flattened before being sorted,
        # so the result is 1-d with the product of the input's dims.
        return [(mul(*inputs_shapes[0]),)]
    # axis is not None, so there should be the same number of
    # dimensions in the input and output.
    assert node.inputs[0].ndim == node.outputs[0].ndim
    # The axis input is a scalar, so its shape is the empty tuple.
    # Compare with == rather than `is`: identity of the empty tuple
    # is a CPython implementation detail.
    assert inputs_shapes[1] == ()
    return [inputs_shapes[0]]
#**** It needs argsort, so we can't do it now.
...
@@ -5641,6 +5641,22 @@ class test_sort(unittest.TestCase):
        assert numpy.allclose(f(self.m_val),
                              numpy.sort(self.m_val, None))
class TensorInferShapeTester(utt.InferShapeTester):
    """Shape-inference checks for tensor Ops."""

    def test_sort(self):
        # Exercise SortOp's infer_shape both with the default axis and
        # with axis=None (which flattens the input before sorting).
        x = tensor.matrix()
        for out in (sort(x), sort(x, axis=None)):
            self._compile_and_check(
                [x],
                [out],
                [numpy.random.randn(10, 40).astype(config.floatX)],
                SortOp)
if __name__ == '__main__':
    if 0:
        unittest.main()
......
from copy import copy, deepcopy
import sys
import unittest
import numpy
import theano
import theano.tensor as T
from theano.configparser import config, AddConfigVar, StrParam
try:
...
@@ -148,3 +150,27 @@ class T_OpContractMixin(object):
    for op in self.ops:
        s = str(op)  # show that str works
        assert s  # names should not be empty
class InferShapeTester(unittest.TestCase):
    """Base TestCase providing a helper to verify an Op's infer_shape.

    Subclasses call `_compile_and_check` with a symbolic graph, concrete
    inputs, and the Op class under test.
    """

    def setUp(self):
        seed_rng()
        # This mode seems to be the minimal one including the shape_i
        # optimizations, if we don't want to enumerate them explicitly.
        self.mode = theano.compile.get_default_mode().including("canonicalize")

    def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
        """Check that ``cls``'s infer_shape is used and is correct.

        Compiles one function computing `outputs` and one computing only
        their shapes, then verifies that:
          1. the shape-only graph no longer applies `cls` (its
             infer_shape replaced the actual computation),
          2. the output graph still applies `cls`,
          3. the inferred shapes agree with the computed ones.

        :param inputs: symbolic inputs of the graph.
        :param outputs: symbolic outputs whose shapes are checked.
        :param numeric_inputs: concrete values for the evaluation.
        :param cls: the Op class whose infer_shape is under test.
        """
        outputs_function = theano.function(inputs, outputs, mode=self.mode)
        shapes_function = theano.function(inputs, [o.shape for o in outputs],
                                          mode=self.mode)
        # NOTE: a leftover theano.printing.debugprint(shapes_function)
        # call was removed here; it printed the full graph on every run.
        # Check that the Op is removed from the compiled shape function.
        topo_shape = shapes_function.maker.env.toposort()
        assert not any(isinstance(t.op, cls) for t in topo_shape)
        # The graph computing the outputs must still apply the Op.
        topo_out = outputs_function.maker.env.toposort()
        assert any(isinstance(t.op, cls) for t in topo_out)
        # Check that the shape produced agrees with the actual shape.
        numeric_outputs = outputs_function(*numeric_inputs)
        numeric_shapes = shapes_function(*numeric_inputs)
        for out, shape in zip(numeric_outputs, numeric_shapes):
            assert numpy.all(out.shape == shape)
Markdown format
0%
You added 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment