提交 c84ff2b1 authored 作者: Yikang Shen's avatar Yikang Shen

debug

上级 ce687283
......@@ -866,8 +866,10 @@ class MRG_RandomStreams(object):
rstates = self.get_substream_rstates(nstreams, dtype)
d = {}
if kwargs.has_key('target'):
if 'target' in kwargs:
d = dict(target=kwargs.pop('target'))
if len(kwargs) > 0:
raise TypeError("uniform() got unexpected keyword arguements %s" % (str(kwargs.keys())))
node_rstate = shared(rstates, **d)
u = self.pretty_return(node_rstate,
*mrg_uniform.new(node_rstate,
......@@ -883,8 +885,6 @@ class MRG_RandomStreams(object):
'`low` and `high` arguments')
assert r.dtype == dtype
if len(kwargs) > 0:
raise TypeError("uniform() got unexpected keyword arguements %s" % (str(kwargs.keys())))
return r
def binomial(self, size=None, n=1, p=0.5, ndim=None, dtype='int64',
......
......@@ -753,6 +753,21 @@ def test_f16_nonzero(mode=None, op_to_check=rng_mrg.mrg_uniform):
assert np.all((0 < m_val) & (m_val < 1))
def test_target_parameter():
srng = MRG_RandomStreams()
pvals = np.array([[.98, .01, .01], [.01, .49, .50]])
def basic_target_parameter_test(x):
f = theano.function([], x)
assert isinstance(f(), np.ndarray)
basic_target_parameter_test(srng.uniform((3, 2), target='cpu'))
basic_target_parameter_test(srng.binomial((3, 2), target='cpu'))
basic_target_parameter_test(srng.multinomial(pvals=pvals.astype('float32'), target='cpu'))
basic_target_parameter_test(srng.choice(p=pvals.astype('float32'), replace=False, target='cpu'))
basic_target_parameter_test(srng.multinomial_wo_replacement(pvals=pvals.astype('float32'), target='cpu'))
if __name__ == "__main__":
rng = MRG_RandomStreams(np.random.randint(2147462579))
print(theano.__file__)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论