提交 4b5680fe authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use the correct destroy_map in RandomVariable

上级 1a969e3d
...@@ -145,7 +145,7 @@ class RandomVariable(Op): ...@@ -145,7 +145,7 @@ class RandomVariable(Op):
self.ndims_params = tuple(self.ndims_params) self.ndims_params = tuple(self.ndims_params)
if self.inplace: if self.inplace:
self.destroy_map = {0: [len(self.ndims_params) + 1]} self.destroy_map = {0: [0]}
def _shape_from_params(self, dist_params, **kwargs): def _shape_from_params(self, dist_params, **kwargs):
"""Determine the shape of a `RandomVariable`'s output given its parameters. """Determine the shape of a `RandomVariable`'s output given its parameters.
......
...@@ -88,7 +88,7 @@ def test_RandomVariable_basics(): ...@@ -88,7 +88,7 @@ def test_RandomVariable_basics():
) )
assert rv.inplace assert rv.inplace
assert rv.destroy_map == {0: [3]} assert rv.destroy_map == {0: [0]}
# A no-params `RandomVariable` # A no-params `RandomVariable`
rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=()) rv = RandomVariable(name="test_rv", ndim_supp=0, ndims_params=())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论