提交 bb27565a authored 作者: Frederic Bastien's avatar Frederic Bastien

2 bug fixes. Make choice on the GPU with replace work well.

上级 3a4d1d15
...@@ -507,7 +507,7 @@ def local_gpua_multinomial_wor(op, context_name, inputs, outputs): ...@@ -507,7 +507,7 @@ def local_gpua_multinomial_wor(op, context_name, inputs, outputs):
p, u, n = inputs p, u, n = inputs
m, = outputs m, = outputs
if ((p.dtype == u.dtype == 'float32') and (m.dtype == 'int64')): if ((p.dtype == u.dtype == 'float32') and (m.dtype == 'int64')):
gpu_op = GPUAChoiceFromUniform(op.odtype) gpu_op = GPUAChoiceFromUniform(**op._props_dict())
return GpuDimShuffle([False, False], [1, 0])( return GpuDimShuffle([False, False], [1, 0])(
gpu_op(p, u, n)) gpu_op(p, u, n))
......
...@@ -327,10 +327,13 @@ def test_gpu_opt_wor(): ...@@ -327,10 +327,13 @@ def test_gpu_opt_wor():
p = tensor.fmatrix() p = tensor.fmatrix()
u = tensor.fvector() u = tensor.fvector()
n = tensor.iscalar() n = tensor.iscalar()
m = multinomial.ChoiceFromUniform(odtype='auto')(p, u, n) for replace in [False, True]:
m = multinomial.ChoiceFromUniform(odtype='auto',
replace=replace)(p, u, n)
assert m.dtype == 'int64', m.dtype assert m.dtype == 'int64', m.dtype
f = function([p, u, n], m, allow_input_downcast=True, mode=mode_with_gpu) f = function([p, u, n], m, allow_input_downcast=True,
mode=mode_with_gpu)
assert any([type(node.op) is GPUAChoiceFromUniform assert any([type(node.op) is GPUAChoiceFromUniform
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
n_samples = 3 n_samples = 3
...@@ -341,10 +344,11 @@ def test_gpu_opt_wor(): ...@@ -341,10 +344,11 @@ def test_gpu_opt_wor():
# Test with a row, it was failing in the past. # Test with a row, it was failing in the past.
r = tensor.frow() r = tensor.frow()
m = multinomial.ChoiceFromUniform('auto')(r, u, n) m = multinomial.ChoiceFromUniform('auto', replace=replace)(r, u, n)
assert m.dtype == 'int64', m.dtype assert m.dtype == 'int64', m.dtype
f = function([r, u, n], m, allow_input_downcast=True, mode=mode_with_gpu) f = function([r, u, n], m, allow_input_downcast=True,
mode=mode_with_gpu)
assert any([type(node.op) is GPUAChoiceFromUniform assert any([type(node.op) is GPUAChoiceFromUniform
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
pval = np.arange(1 * 4, dtype='float32').reshape((1, 4)) + 0.1 pval = np.arange(1 * 4, dtype='float32').reshape((1, 4)) + 0.1
......
...@@ -212,7 +212,7 @@ class ChoiceFromUniform(MultinomialFromUniform): ...@@ -212,7 +212,7 @@ class ChoiceFromUniform(MultinomialFromUniform):
""" """
__props__ = ("replace",) __props__ = ("odtype", "replace",)
def __init__(self, odtype, replace=False, *args, **kwargs): def __init__(self, odtype, replace=False, *args, **kwargs):
self.replace = replace self.replace = replace
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论