提交 aae8382f authored 作者: Frederic's avatar Frederic

Refactored the infer_shape tester to allow using it for tensor tests.

上级 7add8bc5
......@@ -140,29 +140,7 @@ class T_transpose(unittest.TestCase):
self.assertTrue(vta.shape == (3, 5))
class SparseInferShapeTester(unittest.TestCase):
def setUp(self):
utt.seed_rng()
# This mode seems to be the minimal one including the shape_i
# optimizations, if we don't want to enumerate them explicitly.
self.mode = theano.compile.get_default_mode().including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs, mode=self.mode)
shapes_function = theano.function(inputs, [o.shape for o in outputs],
mode=self.mode)
theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.env.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
topo_out = outputs_function.maker.env.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes):
assert numpy.all(out.shape == shape)
class SparseInferShapeTester(utt.InferShapeTester):
def test_getitem_2d(self):
raise SkipTest('infer_shape not implemented for GetItem2d yet')
......
from copy import copy, deepcopy
import sys
import unittest
import numpy
import theano
import theano.tensor as T
from theano.configparser import config, AddConfigVar, StrParam
try:
......@@ -148,3 +150,27 @@ class T_OpContractMixin(object):
for op in self.ops:
s = str(op) # show that str works
assert s # names should not be empty
class InferShapeTester(unittest.TestCase):
def setUp(self):
seed_rng()
# This mode seems to be the minimal one including the shape_i
# optimizations, if we don't want to enumerate them explicitly.
self.mode = theano.compile.get_default_mode().including("canonicalize")
def _compile_and_check(self, inputs, outputs, numeric_inputs, cls):
outputs_function = theano.function(inputs, outputs, mode=self.mode)
shapes_function = theano.function(inputs, [o.shape for o in outputs],
mode=self.mode)
theano.printing.debugprint(shapes_function)
# Check that the Op is removed from the compiled function.
topo_shape = shapes_function.maker.env.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
topo_out = outputs_function.maker.env.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
numeric_outputs = outputs_function(*numeric_inputs)
numeric_shapes = shapes_function(*numeric_inputs)
for out, shape in zip(numeric_outputs, numeric_shapes):
assert numpy.all(out.shape == shape)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论