提交 b128827c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't introduce Second in AlegrabicCanonizer because of shape specialization…

Don't introduce Second in AlegrabicCanonizer because of shape specialization (only for broadcasting)
上级 5b50d27d
......@@ -1135,10 +1135,12 @@ class AlgebraicCanonizer(NodeRewriter):
if new.type.dtype != out.type.dtype:
new = cast(new, out.type.dtype)
if new.type != out.type:
if new.type.broadcastable != out.type.broadcastable:
new = fill_chain(new, node.inputs)[0]
if new.type == out.type:
if (new.type.dtype == out.type.dtype) and (
new.type.broadcastable == out.type.broadcastable
):
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
copy_stack_trace(out, new)
return [new]
......
......@@ -30,7 +30,7 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint
from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, switch
from pytensor.tensor.basic import Alloc, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
......@@ -96,7 +96,7 @@ from pytensor.tensor.rewriting.math import (
perform_sigm_times_exp,
simplify_mul,
)
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
from pytensor.tensor.type import (
TensorType,
cmatrix,
......@@ -979,6 +979,28 @@ class TestAlgebraicCanonizer:
# No rewrite was applied
assert z_rewritten is z
def test_shape_specified_by_constant(self):
x = vector("x")
const = np.full(shape=(5,), fill_value=2.0).astype(config.floatX)
out = x * const
new_out = rewrite_graph(
out, custom_rewrite=in2out(local_mul_canonizer, name="test")
)
expected_out = np.array([2.0]).astype(config.floatX) * specify_shape(x, (5,))
assert equal_computations([new_out], [expected_out])
def test_broadcasted_by_constant(self):
x = vector("x")
const = np.full(shape=(3, 5), fill_value=2.0).astype(config.floatX)
out = x * const
new_out = rewrite_graph(
out, custom_rewrite=in2out(local_mul_canonizer, name="test")
)
expected_out = second(const, np.array([[2.0]], dtype=config.floatX) * x)
assert equal_computations([new_out], [expected_out])
def test_local_merge_abs():
x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论