提交 b128827c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't introduce Second in AlegrabicCanonizer because of shape specialization…

Don't introduce Second in AlegrabicCanonizer because of shape specialization (only for broadcasting)
上级 5b50d27d
...@@ -1135,10 +1135,12 @@ class AlgebraicCanonizer(NodeRewriter): ...@@ -1135,10 +1135,12 @@ class AlgebraicCanonizer(NodeRewriter):
if new.type.dtype != out.type.dtype: if new.type.dtype != out.type.dtype:
new = cast(new, out.type.dtype) new = cast(new, out.type.dtype)
if new.type != out.type: if new.type.broadcastable != out.type.broadcastable:
new = fill_chain(new, node.inputs)[0] new = fill_chain(new, node.inputs)[0]
if new.type == out.type: if (new.type.dtype == out.type.dtype) and (
new.type.broadcastable == out.type.broadcastable
):
new.tag.values_eq_approx = values_eq_approx_remove_inf_nan new.tag.values_eq_approx = values_eq_approx_remove_inf_nan
copy_stack_trace(out, new) copy_stack_trace(out, new)
return [new] return [new]
......
...@@ -30,7 +30,7 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph ...@@ -30,7 +30,7 @@ from pytensor.graph.rewriting.utils import is_same_graph, rewrite_graph
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import debugprint from pytensor.printing import debugprint
from pytensor.tensor import inplace from pytensor.tensor import inplace
from pytensor.tensor.basic import Alloc, join, switch from pytensor.tensor.basic import Alloc, join, second, switch
from pytensor.tensor.blas import Dot22, Gemv from pytensor.tensor.blas import Dot22, Gemv
from pytensor.tensor.blas_c import CGemv from pytensor.tensor.blas_c import CGemv
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
...@@ -96,7 +96,7 @@ from pytensor.tensor.rewriting.math import ( ...@@ -96,7 +96,7 @@ from pytensor.tensor.rewriting.math import (
perform_sigm_times_exp, perform_sigm_times_exp,
simplify_mul, simplify_mul,
) )
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape, specify_shape
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
cmatrix, cmatrix,
...@@ -979,6 +979,28 @@ class TestAlgebraicCanonizer: ...@@ -979,6 +979,28 @@ class TestAlgebraicCanonizer:
# No rewrite was applied # No rewrite was applied
assert z_rewritten is z assert z_rewritten is z
def test_shape_specified_by_constant(self):
x = vector("x")
const = np.full(shape=(5,), fill_value=2.0).astype(config.floatX)
out = x * const
new_out = rewrite_graph(
out, custom_rewrite=in2out(local_mul_canonizer, name="test")
)
expected_out = np.array([2.0]).astype(config.floatX) * specify_shape(x, (5,))
assert equal_computations([new_out], [expected_out])
def test_broadcasted_by_constant(self):
x = vector("x")
const = np.full(shape=(3, 5), fill_value=2.0).astype(config.floatX)
out = x * const
new_out = rewrite_graph(
out, custom_rewrite=in2out(local_mul_canonizer, name="test")
)
expected_out = second(const, np.array([[2.0]], dtype=config.floatX) * x)
assert equal_computations([new_out], [expected_out])
def test_local_merge_abs(): def test_local_merge_abs():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论