Commit 610d6199 authored by Ricardo Vieira, committed by Ricardo Vieira

Don't do symbolic upcasting in `local_upcast_elemwise_constants`

This reduces the number of rewrite passes by avoiding constant folding of cast/expand_dims/alloc nodes.
Parent b5d6f92a
...@@ -30,13 +30,9 @@ from pytensor.graph.utils import InconsistencyError, MethodNotDefined ...@@ -30,13 +30,9 @@ from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
MakeVector, MakeVector,
alloc,
cast,
constant, constant,
get_underlying_scalar_constant_value,
) )
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import add, exp, mul from pytensor.tensor.math import add, exp, mul
from pytensor.tensor.rewriting.basic import ( from pytensor.tensor.rewriting.basic import (
alloc_like, alloc_like,
...@@ -44,7 +40,6 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -44,7 +40,6 @@ from pytensor.tensor.rewriting.basic import (
register_canonicalize, register_canonicalize,
register_specialize, register_specialize,
) )
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -434,66 +429,40 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
""" """
if len(node.outputs) > 1: if len(node.outputs) > 1:
return return None
try:
shape_i = fgraph.shape_feature.shape_i if getattr(node.op.scalar_op, "output_types_preference", None) not in (
except AttributeError:
shape_i = None
if isinstance(node.op, Elemwise):
scalar_op = node.op.scalar_op
# print "aa", scalar_op.output_types_preference
if getattr(scalar_op, "output_types_preference", None) in (
ps.upgrade_to_float, ps.upgrade_to_float,
ps.upcast_out, ps.upcast_out,
): ):
return None
# this is the kind of op that we can screw with the input # this is the kind of op that we can screw with the input
# dtypes by upcasting explicitly # dtypes by upcasting explicitly
output_dtype = node.outputs[0].type.dtype [old_out] = node.outputs
new_inputs = [] output_dtype = old_out.type.dtype
for i in node.inputs: new_inputs = list(node.inputs)
if i.type.dtype == output_dtype: changed = False
new_inputs.append(i) for i, inp in enumerate(node.inputs):
else: if inp.type.dtype != output_dtype and isinstance(inp, TensorConstant):
try: new_inputs[i] = constant(inp.data.astype(output_dtype))
cval_i = get_underlying_scalar_constant_value( changed = True
i, only_process_constants=True
) if not changed:
if all(i.broadcastable): return None
new_inputs.append(
shape_padleft(cast(cval_i, output_dtype), i.ndim)
)
else:
if shape_i is None:
return
new_inputs.append(
alloc(
cast(cval_i, output_dtype),
*[shape_i(d)(i) for d in range(i.ndim)],
)
)
# print >> sys.stderr, "AAA",
# *[Shape_i(d)(i) for d in range(i.ndim)]
except NotScalarConstantError:
# for the case of a non-scalar
if isinstance(i, TensorConstant):
new_inputs.append(cast(i, output_dtype))
else:
new_inputs.append(i)
if new_inputs != node.inputs: rval = node.op(*new_inputs)
rval = [node.op(*new_inputs)] if not old_out.type.is_super(rval.type):
if not node.outputs[0].type.is_super(rval[0].type):
# This can happen for example when floatX=float32 # This can happen for example when floatX=float32
# and we do the true division between and int64 # and we do the true division between and int64
# and a constant that will get typed as int8. # and a constant that will get typed as int8.
# As this is just to allow merging more case, if # As this is just to allow merging more case, if
# the upcast don't work, we can just skip it. # the upcast don't work, we can just skip it.
return return None
# Copy over output stacktrace from before upcasting # Copy over output stacktrace from before upcasting
copy_stack_trace(node.outputs[0], rval) copy_stack_trace(old_out, rval)
return rval return [rval]
@node_rewriter([add, mul]) @node_rewriter([add, mul])
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment