提交 4c4b7cc4 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier 提交者: --global

Improve infer_shape test

上级 45328ed6
import numpy import numpy
import theano import theano
import theano.tensor as T import theano.tensor as T
from theano.tests import unittest_tools as utt
from theano.tests.breakpoint import PdbBreakpoint from theano.tests.breakpoint import PdbBreakpoint
class TestPdbBreakpoint: class TestPdbBreakpoint(utt.InferShapeTester):
def setup(self): def setUp(self):
super(TestPdbBreakpoint, self).setUp()
# Sample computation that involves tensors with different numbers # Sample computation that involves tensors with different numbers
# of dimensions # of dimensions
...@@ -25,20 +28,15 @@ class TestPdbBreakpoint: ...@@ -25,20 +28,15 @@ class TestPdbBreakpoint:
def test_infer_shape(self): def test_infer_shape(self):
input1_value = numpy.arange(9).reshape(3,3).astype("float32") input1_value = numpy.arange(6).reshape(2,3).astype("float32")
input2_value = 10.0 input2_value = 10.0
fct = theano.function([self.input1, self.input2], self._compile_and_check([self.input1, self.input2],
[self.monitored_input1.shape, [self.monitored_input1,
self.monitored_input2.shape, self.monitored_input2,
self.monitored_output.shape]) self.monitored_output],
[input1_value, input2_value],
shapes = fct(input1_value, input2_value) PdbBreakpoint)
assert tuple(shapes[0]) == input1_value.shape
assert tuple(shapes[1]) == tuple()
assert tuple(shapes[2]) == (input1_value.shape[0],
input1_value.shape[0])
def test_grad(self): def test_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论