提交 b00ab8eb authored 作者: Jesse Livezey's avatar Jesse Livezey

tests fail with original corr code

上级 f27c11e2
......@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
from nose.tools import assert_equals
import numpy
from six import integer_types
......@@ -259,6 +260,40 @@ class TestCorr2D(utt.InferShapeTester):
self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5),
'valid', input=T.dtensor3())
def dtype_upcast(self):
"""
Checks dtype upcast for CorrMM methods.
"""
def rand(shape, dtype='float64'):
r = numpy.asarray(numpy.random.rand(*shape), dtype=dtype)
return r * 2 - 1
ops = [corr.CorrMM, corr.CorrMM_gradWeights, corr.CorrMM_gradInputs]
a_shapes = [[4, 5, 6, 3], [1, 5, 6, 3], [1, 5, 6, 3]]
b_shapes = [[7, 5, 3, 2], [1, 5, 3, 1], [7, 1, 3, 1]]
dtypes = ['float32', 'float64']
for op, a_shape, b_shape in zip(ops, a_shapes, b_shapes):
for a_dtype in dtypes:
c_dtype = 'float32'
for b_dtype in dtypes:
if a_dtype == 'float32':
a_tens = T.ftensor4()
else:
c_dtype = 'float64'
a_tens = T.dtensor4()
if b_dtype == 'float32':
b_tens = T.ftensor4()
else:
c_dtype = 'float64'
b_tens = T.dtensor4()
a_tens_val = rand(a_shape, dtype=a_dtype)
b_tens_val = rand(b_shape, dtype=b_dtype)
c_tens = op()(a_tens, b_tens)
f = theano.function([a_tens, b_tens], c_tens)
assert_equals(f(a_tens_val, b_tens_val).dtype, c_dtype)
@attr('slow')
def test_infer_shape_forward(self):
if theano.config.mode == "FAST_COMPILE":
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论