提交 5db0d833 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Be consistent about second vs alloc in rewrites

上级 67519be2
""" Tensor optimizations addressing the ops in basic.py.""" """ Tensor optimizations addressing the ops in basic.py.
Notes
-----
There are two ways of broadcasting arrays:
second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape))
The second form (`alloc`) can be more efficient, because `x` doesn't usually need to be computed when we only want its shape.
It may also allow other rewrites that don't try to modify x when it has multiple clients (for fear of duplicating computation).
However, the first form (`second`) is easier to reason about.
Knowing we have such a graph allows to do certain rewrites such as "sinking" broadcasting operations below Elemwise.
The same rewrites with alloc would be more complicated as we would need to symbolically combine the shapes of each one.
As an example contrast rewriting the following two equivalent graphs
alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y
second(y, x) + second(x, y) -> x + y
Theano developers (mostly) preferred to use the first form during canonicalization and introduce the second form later,
via rewrites like `local_fill_to_alloc`, and using the `alloc_like` helper inside rewrites.
Many stabilize and specialize rewrites refuse to be applied when a variable has multiple clients, so this is important.
"""
import logging import logging
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
......
...@@ -30,7 +30,6 @@ from pytensor.tensor.basic import ( ...@@ -30,7 +30,6 @@ from pytensor.tensor.basic import (
cast, cast,
constant, constant,
extract_constant, extract_constant,
fill,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
ones_like, ones_like,
switch, switch,
...@@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node): ...@@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([at_pow]) @node_rewriter([at_pow])
def local_pow_specialize(fgraph, node): def local_pow_specialize(fgraph, node):
# here, we are past the point of canonicalization, so we don't want
# to put in un-necessary fills.
if node.op == at_pow: if node.op == at_pow:
# the idea here is that we have pow(x, y) # the idea here is that we have pow(x, y)
odtype = node.outputs[0].dtype odtype = node.outputs[0].dtype
...@@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node): ...@@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node):
if np.all(y == 1): if np.all(y == 1):
rval = [xsym] rval = [xsym]
if np.all(y == 0): if np.all(y == 0):
rval = [fill(xsym, np.asarray(1, dtype=odtype))] rval = [alloc_like(1, xsym, fgraph)]
if np.all(y == 0.5): if np.all(y == 0.5):
rval = [sqrt(xsym)] rval = [sqrt(xsym)]
if np.all(y == -0.5): if np.all(y == -0.5):
...@@ -2158,9 +2155,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2158,9 +2155,7 @@ def local_mul_specialize(fgraph, node):
mul(-1, x, y) -/-> neg(mul(x, y)) mul(-1, x, y) -/-> neg(mul(x, y))
""" """
# here, we are past the point of canonicalization, so we don't
# want to put in un-necessary fills.
#
# at this point [post canonicalize], mul() may have many inputs. # at this point [post canonicalize], mul() may have many inputs.
if node.op == mul: if node.op == mul:
# the idea here is that we have pow(x, y) # the idea here is that we have pow(x, y)
...@@ -2221,16 +2216,7 @@ def local_mul_specialize(fgraph, node): ...@@ -2221,16 +2216,7 @@ def local_mul_specialize(fgraph, node):
@register_specialize @register_specialize
@node_rewriter([add]) @node_rewriter([add])
def local_add_specialize(fgraph, node): def local_add_remove_zeros(fgraph, node):
"""Remove zeros from ``add``s.
TODO: This should be a canonicalization, no?
"""
# here, we are past the point of canonicalization, so we don't want
# to put in un-necessary fills.
if node.op != add:
return False
new_inputs = [] new_inputs = []
for inp in node.inputs: for inp in node.inputs:
try: try:
...@@ -2253,12 +2239,12 @@ def local_add_specialize(fgraph, node): ...@@ -2253,12 +2239,12 @@ def local_add_specialize(fgraph, node):
# Reuse call to constant for cache() # Reuse call to constant for cache()
cst = constant(np.zeros((1,) * ndim, dtype=dtype)) cst = constant(np.zeros((1,) * ndim, dtype=dtype))
assert cst.type.broadcastable == (True,) * ndim assert cst.type.broadcastable == (True,) * ndim
return [broadcast_arrays(cst, *node.inputs)[0]] return [alloc_like(cst, node_output, fgraph)]
if len(new_inputs) == 1: if len(new_inputs) == 1:
ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]] ret = [alloc_like(new_inputs[0], node_output, fgraph)]
else: else:
ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]] ret = [alloc_like(add(*new_inputs), node_output, fgraph)]
# The dtype should not be changed. It can happen if the input # The dtype should not be changed. It can happen if the input
# that was forcing upcasting was equal to 0. # that was forcing upcasting was equal to 0.
...@@ -2376,7 +2362,7 @@ def local_log1p(fgraph, node): ...@@ -2376,7 +2362,7 @@ def local_log1p(fgraph, node):
ninp = nonconsts[0] ninp = nonconsts[0]
if ninp.dtype != log_arg.type.dtype: if ninp.dtype != log_arg.type.dtype:
ninp = ninp.astype(node.outputs[0].dtype) ninp = ninp.astype(node.outputs[0].dtype)
return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]] return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
elif log_arg.owner and log_arg.owner.op == sub: elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
...@@ -3572,10 +3558,11 @@ def local_reciprocal_1_plus_exp(fgraph, node): ...@@ -3572,10 +3558,11 @@ def local_reciprocal_1_plus_exp(fgraph, node):
if nonconsts[0].owner and nonconsts[0].owner.op == exp: if nonconsts[0].owner and nonconsts[0].owner.op == exp:
if scalars_ and np.allclose(np.sum(scalars_), 1): if scalars_ and np.allclose(np.sum(scalars_), 1):
out = [ out = [
broadcast_arrays( alloc_like(
sigmoid(neg(nonconsts[0].owner.inputs[0])), sigmoid(neg(nonconsts[0].owner.inputs[0])),
*scalar_inputs, node.outputs[0],
)[0] fgraph,
)
] ]
# keep combined stack traces of # keep combined stack traces of
# exp(x): nonconsts[0], # exp(x): nonconsts[0],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论