提交 be866f9f authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Numba RVs: Include output_bc_pattern in cache key

Size is ignored when the output_bc_pattern implies it must be length 1, but shouldn't be ignored otherwise
上级 f091b86b
...@@ -485,16 +485,14 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs ...@@ -485,16 +485,14 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
# If the core RV can't be cached, then the whole RV can't be cached # If the core RV can't be cached, then the whole RV can't be cached
random_rv_key = None # type: ignore[unreachable] random_rv_key = None # type: ignore[unreachable]
else: else:
rv_op_props_dict = rv_op.props_dict() if hasattr(rv_op, "props_dict") else {}
random_rv_key_contents = ( random_rv_key_contents = (
type(op), type(op),
type(rv_op), type(rv_op),
rv_op, tuple(rv_op._props_dict().items()), # type: ignore[attr-defined]
tuple(rv_op_props_dict.items()),
size_len, size_len,
core_shape_len, core_shape_len,
inplace,
input_bc_patterns, input_bc_patterns,
output_bc_patterns,
core_cache_key, core_cache_key,
) )
random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest() random_rv_key = sha256(str(random_rv_key_contents).encode()).hexdigest()
......
...@@ -689,6 +689,15 @@ def test_dirichlet_discrete_alpha(): ...@@ -689,6 +689,15 @@ def test_dirichlet_discrete_alpha():
assert np.unique(res).size > 2 # Make sure we have more than just 0s and 1s assert np.unique(res).size > 2 # Make sure we have more than just 0s and 1s
def test_cache_size_bcast_change():
# Regression bug for caching with the same key in case where size is meaningful vs not
alpha = pt.dvector()
for s in (1, 2, 3):
x = ptr.dirichlet(alpha, size=(s,))
fn = function([alpha], x, mode=numba_mode)
assert fn([0.2, 0.3, 0.5]).shape == (s, 3)
def test_rv_inside_ofg(): def test_rv_inside_ofg():
rng_np = np.random.default_rng(562) rng_np = np.random.default_rng(562)
rng = shared(rng_np) rng = shared(rng_np)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论