提交 aae8382f authored 作者: Frederic's avatar Frederic

Refactored the infer_shape tester to allow using it for tensor tests.

上级 7add8bc5
...@@ -140,29 +140,7 @@ class T_transpose(unittest.TestCase): ...@@ -140,29 +140,7 @@ class T_transpose(unittest.TestCase):
self.assertTrue(vta.shape == (3, 5)) self.assertTrue(vta.shape == (3, 5))
class SparseInferShapeTester(unittest.TestCase): class SparseInferShapeTester(utt.InferShapeTester):
def setUp(self):
utt.seed_rng()
# This mode seems to be the minimal one including the shape_i
# optimizations, if we don't want to enumerate them explicitly.
self.mode = theano.compile.get_default_mode().including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs, mode=self.mode)
shapes_function = theano.function(inputs, [o.shape for o in outputs],
mode=self.mode)
theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.env.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
topo_out = outputs_function.maker.env.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes):
assert numpy.all(out.shape == shape)
def test_getitem_2d(self): def test_getitem_2d(self):
raise SkipTest('infer_shape not implemented for GetItem2d yet') raise SkipTest('infer_shape not implemented for GetItem2d yet')
......
from copy import copy, deepcopy from copy import copy, deepcopy
import sys import sys
import unittest
import numpy import numpy
import theano
import theano.tensor as T import theano.tensor as T
from theano.configparser import config, AddConfigVar, StrParam from theano.configparser import config, AddConfigVar, StrParam
try: try:
...@@ -148,3 +150,27 @@ class T_OpContractMixin(object): ...@@ -148,3 +150,27 @@ class T_OpContractMixin(object):
for op in self.ops: for op in self.ops:
s = str(op) # show that str works s = str(op) # show that str works
assert s # names should not be empty assert s # names should not be empty
class InferShapeTester(unittest.TestCase):
def setUp(self):
seed_rng()
# This mode seems to be the minimal one including the shape_i
# optimizations, if we don't want to enumerate them explicitly.
self.mode = theano.compile.get_default_mode().including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs, mode=self.mode)
shapes_function = theano.function(inputs, [o.shape for o in outputs],
mode=self.mode)
theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.env.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
topo_out = outputs_function.maker.env.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes):
assert numpy.all(out.shape == shape)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论