提交 4c4b7cc4 authored 作者: Pierre Luc Carrier's avatar Pierre Luc Carrier 提交者: --global

Improve infer_shape test

上级 45328ed6
import numpy
import theano
import theano.tensor as T
from theano.tests import unittest_tools as utt
from theano.tests.breakpoint import PdbBreakpoint
class TestPdbBreakpoint:
class TestPdbBreakpoint(utt.InferShapeTester):
def setup(self):
def setUp(self):
super(TestPdbBreakpoint, self).setUp()
# Sample computation that involves tensors with different numbers
# of dimensions
......@@ -25,20 +28,15 @@ class TestPdbBreakpoint:
def test_infer_shape(self):
input1_value = numpy.arange(9).reshape(3,3).astype("float32")
input1_value = numpy.arange(6).reshape(2,3).astype("float32")
input2_value = 10.0
fct = theano.function([self.input1, self.input2],
[self.monitored_input1.shape,
self.monitored_input2.shape,
self.monitored_output.shape])
shapes = fct(input1_value, input2_value)
assert tuple(shapes[0]) == input1_value.shape
assert tuple(shapes[1]) == tuple()
assert tuple(shapes[2]) == (input1_value.shape[0],
input1_value.shape[0])
self._compile_and_check([self.input1, self.input2],
[self.monitored_input1,
self.monitored_input2,
self.monitored_output],
[input1_value, input2_value],
PdbBreakpoint)
def test_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论