提交 a56ed352 authored 作者: Cesar Laurent's avatar Cesar Laurent

Added tes_old_pool_interface.

上级 30a83561
from __future__ import absolute_import, print_function, division
from itertools import product
import os
import unittest
from six.moves import cPickle
import six.moves.builtins as builtins
import sys
import numpy
......@@ -821,9 +824,9 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
for mode in ['max', 'sum', 'average_inc_pad', 'average_exc_pad']:
y = pool_2d(x, window_size, ignore_border, stride, padding,
mode)
dy = theano.gradient.grad(y.sum(), x)
dx = theano.gradient.grad(y.sum(), x)
var_fct = theano.function([x, window_size, stride, padding],
[y, dy])
[y, dx])
for ws in (4, 2, 5):
for st in (2, 3):
for pad in (0, 1):
......@@ -833,13 +836,29 @@ class TestDownsampleFactorMax(utt.InferShapeTester):
continue
y = pool_2d(x, (ws, ws), ignore_border, (st, st),
(pad, pad), mode)
dy = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dy])
var_y, var_dy = var_fct(data, (ws, ws), (st, st),
dx = theano.gradient.grad(y.sum(), x)
fix_fct = theano.function([x], [y, dx])
var_y, var_dx = var_fct(data, (ws, ws), (st, st),
(pad, pad))
fix_y, fix_dy = fix_fct(data)
fix_y, fix_dx = fix_fct(data)
utt.assert_allclose(var_y, fix_y)
utt.assert_allclose(var_dy, fix_dy)
utt.assert_allclose(var_dx, fix_dx)
def test_old_pool_interface(self):
testfile_dir = os.path.dirname(os.path.realpath(__file__))
fname = 'old_pool_interface.pkl'
with open(os.path.join(testfile_dir, fname), 'rb') as fp:
try:
cPickle.load(fp)
except ImportError:
# Windows sometimes fail with nonsensical errors like:
# ImportError: No module named type
# ImportError: No module named copy_reg
# when "type" and "copy_reg" are builtin modules.
if sys.platform == 'win32':
exc_type, exc_value, exc_trace = sys.exc_info()
reraise(SkipTest, exc_value, exc_trace)
raise
def test_DownsampleFactorMaxGrad(self):
im = theano.tensor.tensor4()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论