提交 9fb384e9 authored 作者: Jesse Livezey's avatar Jesse Livezey

more cleanup

上级 7c8b5808
...@@ -260,7 +260,7 @@ class TestCorr2D(utt.InferShapeTester): ...@@ -260,7 +260,7 @@ class TestCorr2D(utt.InferShapeTester):
self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5), self.assertRaises(Exception, self.validate, (3, 2, 8, 8), (4, 2, 5, 5),
'valid', input=T.dtensor3()) 'valid', input=T.dtensor3())
def dtype_upcast(self): def test_dtype_upcast(self):
""" """
Checks dtype upcast for CorrMM methods. Checks dtype upcast for CorrMM methods.
""" """
...@@ -275,18 +275,10 @@ class TestCorr2D(utt.InferShapeTester): ...@@ -275,18 +275,10 @@ class TestCorr2D(utt.InferShapeTester):
for op, a_shape, b_shape in zip(ops, a_shapes, b_shapes): for op, a_shape, b_shape in zip(ops, a_shapes, b_shapes):
for a_dtype in dtypes: for a_dtype in dtypes:
c_dtype = 'float32'
for b_dtype in dtypes: for b_dtype in dtypes:
if a_dtype == 'float32': c_dtype = theano.scalar.upcast(a_dtype, b_dtype)
a_tens = T.ftensor4() a_tens = T.tensor4(dtype=a_dtype)
else: b_tens = T.tensor4(dtype=b_dtype)
c_dtype = 'float64'
a_tens = T.dtensor4()
if b_dtype == 'float32':
b_tens = T.ftensor4()
else:
c_dtype = 'float64'
b_tens = T.dtensor4()
a_tens_val = rand(a_shape, dtype=a_dtype) a_tens_val = rand(a_shape, dtype=a_dtype)
b_tens_val = rand(b_shape, dtype=b_dtype) b_tens_val = rand(b_shape, dtype=b_dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论