提交 610d6199 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Don't do symbolic upcasting in `local_upcast_elemwise_constants`

This reduces the number of rewrite passes by avoiding constant folding of cast/expand_dims/alloc.
上级 b5d6f92a
...@@ -30,13 +30,9 @@ from pytensor.graph.utils import InconsistencyError, MethodNotDefined ...@@ -30,13 +30,9 @@ from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc,
cast,
constant, constant,
get_underlying_scalar_constant_value,
) )
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, mul from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
alloc_like, alloc_like,
...@@ -44,7 +40,6 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -44,7 +40,6 @@ from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
) )
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
""" """
if len(node.outputs) > 1: if len(node.outputs) > 1:
return return None
try:
shape_i = fgraph.shape_feature.shape_i if getattr(node.op.scalar_op, "output_types_preference", None) not in (
except AttributeError: ps.upgrade_to_float,
shape_i = None ps.upcast_out,
if isinstance(node.op, Elemwise): ):
scalar_op = node.op.scalar_op return None
# print "aa", scalar_op.output_types_preference
if getattr(scalar_op, "output_types_preference", None) in (
ps.upgrade_to_float,
ps.upcast_out,
):
# this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype
new_inputs = []
for i in node.inputs:
if i.type.dtype == output_dtype:
new_inputs.append(i)
else:
try:
cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
if all(i.broadcastable):
new_inputs.append(
shape_padleft(cast(cval_i, output_dtype), i.ndim)
)
else:
if shape_i is None:
return
new_inputs.append(
alloc(
cast(cval_i, output_dtype),
*[shape_i(d)(i) for d in range(i.ndim)],
)
)
# print >> sys.stderr, "AAA",
# *[Shape_i(d)(i) for d in range(i.ndim)]
except NotScalarConstantError:
# for the case of a non-scalar
if isinstance(i, TensorConstant):
new_inputs.append(cast(i, output_dtype))
else:
new_inputs.append(i)
if new_inputs != node.inputs: # this is the kind of op that we can screw with the input
rval = [node.op(*new_inputs)] # dtypes by upcasting explicitly
if not node.outputs[0].type.is_super(rval[0].type): [old_out] = node.outputs
# This can happen for example when floatX=float32 output_dtype = old_out.type.dtype
# and we do the true division between and int64 new_inputs = list(node.inputs)
# and a constant that will get typed as int8. changed = False
for i, inp in enumerate(node.inputs):
if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
new_inputs[i] = constant(inp.data.astype(output_dtype))
changed = True
if not changed:
return None
# As this is just to allow merging more case, if rval = node.op(*new_inputs)
# the upcast don't work, we can just skip it. if not old_out.type.is_super(rval.type):
return # This can happen for example when floatX=float32
# and we do the true division between and int64
# and a constant that will get typed as int8.
# As this is just to allow merging more case, if
# the upcast don't work, we can just skip it.
return None
# Copy over output stacktrace from before upcasting # Copy over output stacktrace from before upcasting
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(old_out, rval)
return rval return [rval]
@node_rewriter([add, mul]) @node_rewriter([add, mul])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论