提交 88626f5d authored 作者: khaotik's avatar khaotik

add test for NullType and DisconnectedType

上级 c0f7334c
......@@ -4,6 +4,8 @@ import numpy as np
from theano import config, shared
from theano.gradient import DisconnectedType
from theano.gof.null_type import NullType
from theano.compile import function
from theano import tensor as T
......@@ -173,6 +175,19 @@ class T_OpFromGraph(unittest_tools.InferShapeTester):
assert np.allclose(xv * 1.5, dwv)
assert np.allclose(np.ones(16, dtype=config.floatX), dbv)
# NullType and DisconnectedType
op_linear2 = cls_ofg([x, w, b], [x * w + b], grad_overrides=[go1, None, 0])
zz2 = T.sum(op_linear2(xx, ww, bb))
dx2, dw2, db2 = T.grad(
zz2, [xx, ww, bb],
return_disconnected='Disconnected',
null_gradients='return')
fn2 = function([xx, ww, bb], [dx2, dw2, db2])
dxv2, dwv2, dbv2 = fn2(xv, wv, bv)
assert numpy.allclose(wv * 2, dxv)
assert isinstance(dwv2.type, NullType)
assert isinstance(dbv2.type, DisconnectedType)
@test_params
def test_rop(self, cls_ofg):
a = T.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论