提交 99eb36f9 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

new test for connectivity matrix

上级 67074eba
......@@ -3266,6 +3266,47 @@ class T_Scan(unittest.TestCase):
f = theano.function([seq], results[1], updates=updates)
assert numpy.all(exp_out == f(inp))
def test_grad_connectivity_matrix(self):
def inner_fn(x_tm1, y_tm1, z_tm1):
x_tm1.name = 'x'
y_tm1.name = 'y'
z_tm1.name = 'z'
return x_tm1**2, x_tm1 + y_tm1, x_tm1+1
x0 = tensor.vector('X')
y0 = tensor.vector('y0')
z0 = tensor.vector('Z')
[x,y,z], _ = theano.scan(inner_fn,
outputs_info=[x0,y0,z0],
n_steps=10)
cost = (x+y+z).sum()
#gx0 = tensor.grad(cost, x0) #defined
import ipdb; ipdb.set_trace()
gy0 = tensor.grad(cost, y0) #defined
failed = True
try:
gz0 = tensor.grad(cost, z0) #disconnected
except ValueError:
failed = False
if failed:
raise ValueError('grad should have complained about '
'disconnected input')
cost = x.sum()
failed = True
try:
gy0 = tensor.grad(cost, y0) #disconnected
except ValueError:
failed = False
if failed:
raise ValueError('grad should have complained about '
'disconnected input')
def test_speed():
#
# This function prints out the speed of very simple recurrent
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论