提交 a5464b63 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

removed the use of TT and replaced with tensor

上级 67739ee2
...@@ -9,7 +9,6 @@ from theano import tensor ...@@ -9,7 +9,6 @@ from theano import tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.compile.pfunc import rebuild_collect_shared from theano.compile.pfunc import rebuild_collect_shared
import theano.tensor as TT
''' '''
...@@ -2201,12 +2200,12 @@ class T_Scan(unittest.TestCase): ...@@ -2201,12 +2200,12 @@ class T_Scan(unittest.TestCase):
def test_pushout(self): def test_pushout(self):
W1 = TT.matrix('W1') W1 = tensor.matrix('W1')
W2 = TT.matrix('W2') W2 = tensor.matrix('W2')
h0 = TT.vector('h0') h0 = tensor.vector('h0')
def lambda_fn(h, W1, W2): def lambda_fn(h, W1, W2):
return TT.dot(h, W1 + W2) return tensor.dot(h, W1 + W2)
o, _ = theano.scan(lambda_fn, outputs_info= h0, o, _ = theano.scan(lambda_fn, outputs_info= h0,
non_sequences =[W1,W2], non_sequences =[W1,W2],
...@@ -2223,14 +2222,14 @@ class T_Scan(unittest.TestCase): ...@@ -2223,14 +2222,14 @@ class T_Scan(unittest.TestCase):
def test_alloc_inputs1(self): def test_alloc_inputs1(self):
W1 = TT.matrix('W1') W1 = tensor.matrix('W1')
W2 = TT.matrix('W2') W2 = tensor.matrix('W2')
h0 = TT.vector('h0') h0 = tensor.vector('h0')
def lambda_fn(h, W1, W2): def lambda_fn(h, W1, W2):
return TT.dot(h, W1 * W2) return tensor.dot(h, W1 * W2)
o, _ = theano.scan(lambda_fn, outputs_info= h0, o, _ = theano.scan(lambda_fn, outputs_info= h0,
non_sequences =[W1,TT.zeros_like(W2)], non_sequences =[W1,tensor.zeros_like(W2)],
n_steps = 5) n_steps = 5)
f = theano.function([h0,W1,W2], o) f = theano.function([h0,W1,W2], o)
...@@ -2242,17 +2241,17 @@ class T_Scan(unittest.TestCase): ...@@ -2242,17 +2241,17 @@ class T_Scan(unittest.TestCase):
def test_alloc_inputs2(self): def test_alloc_inputs2(self):
W1 = TT.matrix() W1 = tensor.matrix()
W2 = TT.matrix() W2 = tensor.matrix()
h0 = TT.vector() h0 = tensor.vector()
def lambda_fn(W1,h, W2): def lambda_fn(W1,h, W2):
return W1 * TT.dot(h, W2) return W1 * tensor.dot(h, W2)
o, _ = theano.scan(lambda_fn, o, _ = theano.scan(lambda_fn,
sequences = TT.zeros_like(W1), sequences = tensor.zeros_like(W1),
outputs_info= h0, outputs_info= h0,
non_sequences =[TT.zeros_like(W2)], non_sequences =[tensor.zeros_like(W2)],
n_steps = 5) n_steps = 5)
f = theano.function([h0,W1,W2], o) f = theano.function([h0,W1,W2], o)
...@@ -2266,21 +2265,21 @@ class T_Scan(unittest.TestCase): ...@@ -2266,21 +2265,21 @@ class T_Scan(unittest.TestCase):
def test_alloc_inputs3(self): def test_alloc_inputs3(self):
_W1 = TT.matrix() _W1 = tensor.matrix()
_W2 = TT.matrix() _W2 = tensor.matrix()
_h0 = TT.vector() _h0 = tensor.vector()
W1 = TT.specify_shape(_W1, (3,3)) W1 = tensor.specify_shape(_W1, (3,3))
W2 = TT.specify_shape(_W2, (3,3)) W2 = tensor.specify_shape(_W2, (3,3))
h0 = TT.specify_shape(_h0, (3,)) h0 = tensor.specify_shape(_h0, (3,))
def lambda_fn(W1,h, W2): def lambda_fn(W1,h, W2):
return W1 * TT.dot(h, W2) return W1 * tensor.dot(h, W2)
o, _ = theano.scan(lambda_fn, o, _ = theano.scan(lambda_fn,
sequences = TT.zeros_like(W1), sequences = tensor.zeros_like(W1),
outputs_info= h0, outputs_info= h0,
non_sequences =[TT.zeros_like(W2)], non_sequences =[tensor.zeros_like(W2)],
n_steps = 5) n_steps = 5)
f = theano.function([_h0,_W1,_W2], o) f = theano.function([_h0,_W1,_W2], o)
...@@ -2292,7 +2291,7 @@ class T_Scan(unittest.TestCase): ...@@ -2292,7 +2291,7 @@ class T_Scan(unittest.TestCase):
def test_while0(self): def test_while0(self):
x = TT.vector('x') x = tensor.vector('x')
def lambda_fn(x_t): def lambda_fn(x_t):
return x_t+1, theano.until( x_t > 3) return x_t+1, theano.until( x_t > 3)
o, _ = theano.scan(lambda_fn, x) o, _ = theano.scan(lambda_fn, x)
...@@ -2303,7 +2302,7 @@ class T_Scan(unittest.TestCase): ...@@ -2303,7 +2302,7 @@ class T_Scan(unittest.TestCase):
assert numpy.sum(out[24:]) == 0 assert numpy.sum(out[24:]) == 0
def test_while1(self): def test_while1(self):
x = TT.vector('x') x = tensor.vector('x')
def lambda_fn(x_t): def lambda_fn(x_t):
return x_t+1, theano.until( x_t > 3) return x_t+1, theano.until( x_t > 3)
o, _ = theano.scan(lambda_fn, x) o, _ = theano.scan(lambda_fn, x)
...@@ -2322,7 +2321,7 @@ class T_Scan(unittest.TestCase): ...@@ -2322,7 +2321,7 @@ class T_Scan(unittest.TestCase):
def test_while2(self): def test_while2(self):
x = TT.vector('x') x = tensor.vector('x')
def lambda_fn(x_t): def lambda_fn(x_t):
return x_t+1, theano.until( x_t > 3) return x_t+1, theano.until( x_t > 3)
o, _ = theano.scan(lambda_fn, x) o, _ = theano.scan(lambda_fn, x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论