提交 ac7b5225 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Replace theano.tensor alias T with tt in tests.tensor.signal

上级 e103e7fb
...@@ -2,13 +2,15 @@ import pytest ...@@ -2,13 +2,15 @@ import pytest
import numpy as np import numpy as np
import theano import theano
import theano.tensor as T import theano.tensor as tt
from theano.tensor.signal import conv from theano.tensor.signal import conv
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
from tests import unittest_tools as utt from tests import unittest_tools as utt
_ = pytest.importorskip("scipy.signal")
class TestSignalConv2D: class TestSignalConv2D:
def setup_method(self): def setup_method(self):
...@@ -18,8 +20,8 @@ class TestSignalConv2D: ...@@ -18,8 +20,8 @@ class TestSignalConv2D:
image_dim = len(image_shape) image_dim = len(image_shape)
filter_dim = len(filter_shape) filter_dim = len(filter_shape)
input = T.TensorType("float64", [False] * image_dim)() input = tt.TensorType("float64", [False] * image_dim)()
filters = T.TensorType("float64", [False] * filter_dim)() filters = tt.TensorType("float64", [False] * filter_dim)()
bsize = image_shape[0] bsize = image_shape[0]
if image_dim != 3: if image_dim != 3:
...@@ -84,8 +86,8 @@ class TestSignalConv2D: ...@@ -84,8 +86,8 @@ class TestSignalConv2D:
utt.verify_grad(sym_conv2d, [image_data, filter_data]) utt.verify_grad(sym_conv2d, [image_data, filter_data])
@pytest.mark.skipif( @pytest.mark.skipif(
not theano.tensor.nnet.conv.imported_scipy_signal and theano.config.cxx == "", theano.config.cxx == "",
reason="conv2d tests need SciPy or a c++ compiler", reason="conv2d tests need a c++ compiler",
) )
def test_basic(self): def test_basic(self):
# Basic functionality of nnet.conv.ConvOp is already tested by # Basic functionality of nnet.conv.ConvOp is already tested by
...@@ -101,15 +103,15 @@ class TestSignalConv2D: ...@@ -101,15 +103,15 @@ class TestSignalConv2D:
# Test that conv2d fails for dimensions other than 2 or 3. # Test that conv2d fails for dimensions other than 2 or 3.
with pytest.raises(Exception): with pytest.raises(Exception):
conv.conv2d(T.dtensor4(), T.dtensor3()) conv.conv2d(tt.dtensor4(), tt.dtensor3())
with pytest.raises(Exception): with pytest.raises(Exception):
conv.conv2d(T.dtensor3(), T.dvector()) conv.conv2d(tt.dtensor3(), tt.dvector())
def test_bug_josh_reported(self): def test_bug_josh_reported(self):
# Test refers to a bug reported by Josh, when due to a bad merge these # Test refers to a bug reported by Josh, when due to a bad merge these
# few lines of code failed. See # few lines of code failed. See
# http://groups.google.com/group/theano-dev/browse_thread/thread/8856e7ca5035eecb # http://groups.google.com/group/theano-dev/browse_thread/thread/8856e7ca5035eecb
m1 = theano.tensor.matrix() m1 = tt.matrix()
m2 = theano.tensor.matrix() m2 = tt.matrix()
conv.conv2d(m1, m2) conv.conv2d(m1, m2)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论