提交 a56ed352 authored 作者: Cesar Laurent's avatar Cesar Laurent

Added tes_old_pool_interface.

上级 30a83561
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from itertools import product from itertools import product
import os
import unittest import unittest
from six.moves import cPickle
import six.moves.builtins as builtins import six.moves.builtins as builtins
import sys
import numpy import numpy
...@@ -821,9 +824,9 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -821,9 +824,9 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
for mode in ['max', 'sum', 'average_inc_pad', 'average_exc_pad']: for mode in ['max', 'sum', 'average_inc_pad', 'average_exc_pad']:
y = pool_2d(x, window_size, ignore_border, stride, padding, y = pool_2d(x, window_size, ignore_border, stride, padding,
mode) mode)
dy = theano.gradient.grad(y.sum(), x) dx = theano.gradient.grad(y.sum(), x)
var_fct = theano.function([x, window_size, stride, padding], var_fct = theano.function([x, window_size, stride, padding],
[y, dy]) [y, dx])
for ws in (4, 2, 5): for ws in (4, 2, 5):
for st in (2, 3): for st in (2, 3):
for pad in (0, 1): for pad in (0, 1):
...@@ -833,13 +836,29 @@ class TestDownsampleFactorMax(utt.InferShapeTester): ...@@ -833,13 +836,29 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
continue continue
y = pool_2d(x, (ws, ws), ignore_border, (st, st), y = pool_2d(x, (ws, ws), ignore_border, (st, st),
(pad, pad), mode) (pad, pad), mode)
dy = theano.gradient.grad(y.sum(), x) dx = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dy]) fix_fct = theano.function([x], [y, dx])
var_y, var_dy = var_fct(data, (ws, ws), (st, st), var_y, var_dx = var_fct(data, (ws, ws), (st, st),
(pad, pad)) (pad, pad))
fix_y, fix_dy = fix_fct(data) fix_y, fix_dx = fix_fct(data)
utt.assert_allclose(var_y, fix_y) utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dy, fix_dy) utt.assert_allclose(var_dx, fix_dx)
def test_old_pool_interface(self):
testfile_dir = os.path.dirname(os.path.realpath(__file__))
fname = 'old_pool_interface.pkl'
with open(os.path.join(testfile_dir, fname), 'rb') as fp:
try:
cPickle.load(fp)
except ImportError:
# Windows sometimes fail with nonsensical errors like:
# ImportError: No module named type
# ImportError: No module named copy_reg
# when "type" and "copy_reg" are builtin modules.
if sys.platform == 'win32':
exc_type, exc_value, exc_trace = sys.exc_info()
reraise(SkipTest, exc_value, exc_trace)
raise
def test_DownsampleFactorMaxGrad(self): def test_DownsampleFactorMaxGrad(self):
im = theano.tensor.tensor4() im = theano.tensor.tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论