提交 5b58807d authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor random tests and include explicit broadcasting tests

上级 87bc36c7
...@@ -29,8 +29,7 @@ except AttributeError: ...@@ -29,8 +29,7 @@ except AttributeError:
from numpy.lib.stride_tricks import _broadcast_shape from numpy.lib.stride_tricks import _broadcast_shape
def broadcast_shapes(*shapes): def broadcast_shapes(*shapes):
arrays = [np.empty(x, dtype=[]) for x in shapes] return _broadcast_shape(*[np.empty(x, dtype=[]) for x in shapes])
return _broadcast_shape(arrays)
class ScipyRandomVariable(RandomVariable): class ScipyRandomVariable(RandomVariable):
......
...@@ -330,10 +330,6 @@ class RandomVariable(Op): ...@@ -330,10 +330,6 @@ class RandomVariable(Op):
def make_node(self, rng, size, dtype, *dist_params): def make_node(self, rng, size, dtype, *dist_params):
"""Create a random variable node. """Create a random variable node.
XXX: Unnamed/non-keyword arguments are considered distribution
parameters! If you want to set `size`, `rng`, and/or `name`, use their
keywords.
Parameters Parameters
---------- ----------
rng: RandomGeneratorType or RandomStateType rng: RandomGeneratorType or RandomStateType
......
...@@ -10,6 +10,7 @@ import aesara ...@@ -10,6 +10,7 @@ import aesara
from aesara.compile.debugmode import str_diagnostic from aesara.compile.debugmode import str_diagnostic
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.gradient import verify_grad as orig_verify_grad from aesara.gradient import verify_grad as orig_verify_grad
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import _allclose from aesara.tensor.math import _allclose
from aesara.tensor.math import add as aet_add from aesara.tensor.math import add as aet_add
...@@ -392,3 +393,10 @@ def assertFailure_fast(f): ...@@ -392,3 +393,10 @@ def assertFailure_fast(f):
return test_with_assert return test_with_assert
else: else:
return f return f
def create_aesara_param(param_value):
"""Create a `Variable` from a value and set its test value."""
p_aet = as_tensor_variable(param_value).type()
p_aet.tag.test_value = param_value
return p_aet
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论