提交 9fb384e9 authored 作者: Jesse Livezey's avatar Jesse Livezey

more cleanup

上级 7c8b5808
......@@ -260,7 +260,7 @@ class TestCorr2D(utt.InferShapeTester):
self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5),
'valid', input=T.dtensor3())
def dtype_upcast(self):
def test_dtype_upcast(self):
"""
Checks dtype upcast for CorrMM methods.
"""
......@@ -275,18 +275,10 @@ class TestCorr2D(utt.InferShapeTester):
for op, a_shape, b_shape in zip(ops, a_shapes, b_shapes):
for a_dtype in dtypes:
c_dtype = 'float32'
for b_dtype in dtypes:
if a_dtype == 'float32':
a_tens = T.ftensor4()
else:
c_dtype = 'float64'
a_tens = T.dtensor4()
if b_dtype == 'float32':
b_tens = T.ftensor4()
else:
c_dtype = 'float64'
b_tens = T.dtensor4()
c_dtype = theano.scalar.upcast(a_dtype, b_dtype)
a_tens = T.tensor4(dtype=a_dtype)
b_tens = T.tensor4(dtype=b_dtype)
a_tens_val = rand(a_shape, dtype=a_dtype)
b_tens_val = rand(b_shape, dtype=b_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论