提交 610d6199 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't do symbolic upcasting in `local_upcast_elemwise_constants`

This reduces the number of rewrite passes, by avoiding constant fold of cast/expand_dims/alloc
上级 b5d6f92a
......@@ -30,13 +30,9 @@ from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import (
MakeVector,
alloc,
cast,
constant,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import (
alloc_like,
......@@ -44,7 +40,6 @@ from pytensor.tensor.rewriting.basic import (
register_canonicalize,
register_specialize,
)
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
"""
if len(node.outputs) > 1:
return
try:
shape_i = fgraph.shape_feature.shape_i
except AttributeError:
shape_i = None
if isinstance(node.op, Elemwise):
scalar_op = node.op.scalar_op
# print "aa", scalar_op.output_types_preference
if getattr(scalar_op, "output_types_preference", None) in (
ps.upgrade_to_float,
ps.upcast_out,
):
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype
new_inputs = []
for i in node.inputs:
if i.type.dtype == output_dtype:
new_inputs.append(i)
else:
try:
cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
if all(i.broadcastable):
new_inputs.append(
shape_padleft(cast(cval_i, output_dtype), i.ndim)
)
else:
if shape_i is None:
return
new_inputs.append(
alloc(
cast(cval_i, output_dtype),
*[shape_i(d)(i) for d in range(i.ndim)],
)
)
# print >> sys.stderr, "AAA",
# *[Shape_i(d)(i) for d in range(i.ndim)]
except NotScalarConstantError:
# for the case of a non-scalar
if isinstance(i, TensorConstant):
new_inputs.append(cast(i, output_dtype))
else:
new_inputs.append(i)
return None
if getattr(node.op.scalar_op, "output_types_preference", None) not in (
ps.upgrade_to_float,
ps.upcast_out,
):
return None
if new_inputs != node.inputs:
rval = [node.op(*new_inputs)]
if not node.outputs[0].type.is_super(rval[0].type):
# This can happen for example when floatX=float32
# and we do the true division between and int64
# and a constant that will get typed as int8.
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
[old_out] = node.outputs
output_dtype = old_out.type.dtype
new_inputs = list(node.inputs)
changed = False
for i, inp in enumerate(node.inputs):
if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
new_inputs[i] = constant(inp.data.astype(output_dtype))
changed = True
if not changed:
return None
# As this is just to allow merging more case, if
# the upcast don't work, we can just skip it.
return
rval = node.op(*new_inputs)
if not old_out.type.is_super(rval.type):
# This can happen for example when floatX=float32
# and we do the true division between and int64
# and a constant that will get typed as int8.
# As this is just to allow merging more case, if
# the upcast don't work, we can just skip it.
return None
# Copy over output stacktrace from before upcasting
copy_stack_trace(node.outputs[0], rval)
return rval
# Copy over output stacktrace from before upcasting
copy_stack_trace(old_out, rval)
return [rval]
@node_rewriter([add, mul])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论