提交 9b4531d4 authored 作者: lucianopaz's avatar lucianopaz 提交者: Ricardo Vieira

Use _props_dict instead of _props in random_make_inplace

上级 71962518
...@@ -44,8 +44,9 @@ def random_make_inplace(fgraph, node): ...@@ -44,8 +44,9 @@ def random_make_inplace(fgraph, node):
op = node.op op = node.op
if isinstance(op, RandomVariable) and not op.inplace: if isinstance(op, RandomVariable) and not op.inplace:
name, ndim_supp, ndims_params, dtype, _ = op._props() props = op._props_dict()
new_op = type(op)(name, ndim_supp, ndims_params, dtype, True) props["inplace"] = True
new_op = type(op)(**props)
return new_op.make_node(*node.inputs).outputs return new_op.make_node(*node.inputs).outputs
return False return False
......
...@@ -87,6 +87,48 @@ def test_inplace_optimization(): ...@@ -87,6 +87,48 @@ def test_inplace_optimization():
assert np.array_equal(new_out.owner.inputs[1].data, []) assert np.array_equal(new_out.owner.inputs[1].data, [])
def test_inplace_optimization_extra_props():
class Test(RandomVariable):
name = "test"
ndim_supp = 0
ndims_params = [0]
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace", "extra")
dtype = "floatX"
_print_name = ("Test", "\\operatorname{Test}")
def __init__(self, extra, *args, **kwargs):
self.extra = extra
super().__init__(*args, **kwargs)
def make_node(self, rng, size, dtype, sigma):
return super().make_node(rng, size, dtype, sigma)
def rng_fn(self, rng, sigma, size):
return rng.normal(scale=sigma, size=size)
out = Test(extra="some value")(1)
out.owner.inputs[0].default_update = out.owner.outputs[0]
assert out.owner.op.inplace is False
f = function(
[],
out,
mode="FAST_RUN",
)
(new_out, new_rng) = f.maker.fgraph.outputs
assert new_out.type == out.type
assert isinstance(new_out.owner.op, type(out.owner.op))
assert new_out.owner.op.inplace is True
assert new_out.owner.op.extra == out.owner.op.extra
assert all(
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
assert np.array_equal(new_out.owner.inputs[1].data, [])
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dist_op, dist_params, size", "dist_op, dist_params, size",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论