提交 e73258b4 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix bug in `local_reduce_join` rewrite.

The helper `apply_local_dimshuffle_lift` requires a FunctionGraph when elemwise inputs are involved.
上级 e934ac7c
...@@ -1620,7 +1620,7 @@ def local_reduce_join(fgraph, node): ...@@ -1620,7 +1620,7 @@ def local_reduce_join(fgraph, node):
if not inp.type.broadcastable[join_axis]: if not inp.type.broadcastable[join_axis]:
return None return None
# Most times inputs to join have an expand_dims, we eagerly clean up those here # Most times inputs to join have an expand_dims, we eagerly clean up those here
new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis)) new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
new_inputs.append(new_input) new_inputs.append(new_input)
ret = Elemwise(node.op.scalar_op)(*new_inputs) ret = Elemwise(node.op.scalar_op)(*new_inputs)
......
...@@ -103,6 +103,7 @@ from pytensor.tensor.rewriting.math import ( ...@@ -103,6 +103,7 @@ from pytensor.tensor.rewriting.math import (
local_mul_canonizer, local_mul_canonizer,
local_mul_switch_sink, local_mul_switch_sink,
local_reduce_chain, local_reduce_chain,
local_reduce_join,
local_sum_prod_of_mul_or_div, local_sum_prod_of_mul_or_div,
mul_canonizer, mul_canonizer,
parse_mul_tree, parse_mul_tree,
...@@ -3415,6 +3416,24 @@ class TestReduceJoin: ...@@ -3415,6 +3416,24 @@ class TestReduceJoin:
f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0) f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
) )
def test_non_ds_inputs(self):
"""Make sure rewrite works when inputs to join are not the usual DimShuffle.
Sum{axis=1} [id A] <Vector(float64, shape=(3,))>
└─ Join [id B] <Matrix(float64, shape=(3, 3))>
├─ 1 [id C] <Scalar(int8, shape=())>
├─ ExpandDims{axis=1} [id D] <Matrix(float64, shape=(3, 1))>
├─ Sub [id E] <Matrix(float64, shape=(3, 1))>
└─ Sub [id F] <Matrix(float64, shape=(3, 1))>
"""
x = vector("x")
out = join(0, exp(x[None]), log(x[None])).sum(axis=0)
fg = FunctionGraph([x], [out], clone=False)
[rewritten_out] = local_reduce_join.transform(fg, out.owner)
expected_out = add(exp(x), log(x))
assert equal_computations([rewritten_out], [expected_out])
def test_local_useless_adds(): def test_local_useless_adds():
default_mode = get_default_mode() default_mode = get_default_mode()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论