提交 36586878 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Speedup random perform

上级 815b2585
...@@ -387,24 +387,17 @@ class RandomVariable(Op): ...@@ -387,24 +387,17 @@ class RandomVariable(Op):
return node.inputs[2:] return node.inputs[2:]
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
rng_var_out, smpl_out = outputs
rng, size, *args = inputs rng, size, *args = inputs
# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if not self.inplace: if not self.inplace:
rng = copy(rng) rng = copy(rng)
rng_var_out[0] = rng outputs[0][0] = rng
outputs[1][0] = np.asarray(
if size is not None: self.rng_fn(rng, *args, None if size is None else tuple(size)),
size = tuple(size) dtype=self.dtype,
smpl_val = self.rng_fn(rng, *([*args, size])) )
if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
smpl_val = np.asarray(smpl_val, dtype=self.dtype)
smpl_out[0] = smpl_val
def grad(self, inputs, outputs): def grad(self, inputs, outputs):
return [ return [
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论