提交 32aadc8c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate `extract_constant`

上级 aad6fb75
......@@ -54,6 +54,7 @@ from pytensor.scan.utils import (
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
......@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used(
client = fgraph.clients[outer_var][0][0]
if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1:
return True
return (
len(lst) == 1
and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1
)
return False
......@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node):
if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None
if isinstance(cf_slice[0], slice):
stop = pt.extract_constant(cf_slice[0].stop)
stop = get_scalar_constant_value(
cf_slice[0].stop, raise_not_constant=False
)
else:
stop = pt.extract_constant(cf_slice[0]) + 1
if stop == maxsize or stop == pt.extract_constant(length):
stop = (
get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
stop = None
else:
# there is a **gotcha** here ! Namely, scan returns an
......@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node):
cf_slice = get_canonical_form_slice(this_slice[0], length)
if isinstance(cf_slice[0], slice):
start = pt.extract_constant(cf_slice[0].start)
start = pt.get_scalar_constant_value(
cf_slice[0].start, raise_not_constant=False
)
else:
start = pt.extract_constant(cf_slice[0])
start = pt.get_scalar_constant_value(
cf_slice[0], raise_not_constant=False
)
if start == 0 or store_steps[i] == 0:
store_steps[i] = 0
......@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node):
# 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So
# don't create one.
if pt.extract_constant(node_ins[0]) == 0:
if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
return False
# Do not call make_node for test_value
......
......@@ -268,27 +268,7 @@ _scalar_constant_value_elemwise_ops = (
)
def get_scalar_constant_value(
    v, elemwise=True, only_process_constants=False, max_recur=10
):
    """Fetch the constant value underlying the scalar (0-d) variable `v`.

    Rejects any `Variable` or ndarray whose ``ndim`` is not 0 by raising
    ``NotScalarConstantError``; otherwise defers to
    ``get_underlying_scalar_constant_value`` with the same options.
    """
    # Guard clause: anything demonstrably non-scalar is rejected up front.
    if isinstance(v, Variable | np.ndarray) and v.ndim != 0:
        raise NotScalarConstantError()
    return get_underlying_scalar_constant_value(
        v, elemwise, only_process_constants, max_recur
    )
def get_underlying_scalar_constant_value(
def _get_underlying_scalar_constant_value(
orig_v, elemwise=True, only_process_constants=False, max_recur=10
):
"""Return the constant scalar(0-D) value underlying variable `v`.
......@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value(
elif isinstance(op, CheckAndRaise):
# check if all conditions are constant and true
conds = [
get_underlying_scalar_constant_value(c, max_recur=max_recur)
_get_underlying_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:]
]
if builtins.all(0 == c.ndim and c != 0 for c in conds):
......@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value(
continue
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [
get_underlying_scalar_constant_value(i, max_recur=max_recur)
_get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
]
ret = [[None]]
......@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value(
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
):
const = [
get_underlying_scalar_constant_value(i, max_recur=max_recur)
_get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
]
ret = [[None]]
......@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value(
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value(
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
try:
......@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value(
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value(
idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
ret = _get_underlying_scalar_constant_value(
ret, max_recur=max_recur
)
# MakeVector can cast implicitly its input in some case.
return np.asarray(ret, dtype=v.type.dtype)
......@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value(
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value(
idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
grandparent = leftmost_parent.owner.inputs[0]
......@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value(
grandparent.owner.op, Unbroadcast
):
ggp_shape = grandparent.owner.inputs[0].type.shape
l = [get_underlying_scalar_constant_value(s) for s in ggp_shape]
l = [
_get_underlying_scalar_constant_value(s) for s in ggp_shape
]
gp_shape = tuple(l)
if not (idx < ndim):
......@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value(
if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx])
elif isinstance(op, CSM):
data = get_underlying_scalar_constant_value(
data = _get_underlying_scalar_constant_value(
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
)
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
......@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError()
def get_underlying_scalar_constant_value(
    v,
    *,
    elemwise=True,
    only_process_constants=False,
    max_recur=10,
    raise_not_constant=True,
):
    """Return the unique constant scalar (0-D) value underlying variable `v`.

    If `v` is the output of dimshuffles, fills, allocs, casts, OutputGuard,
    DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise, and some patterns with
    Subtensor, this function digs through them.

    This function performs symbolic reasoning about the value of `v`, as
    opposed to numerical reasoning by constant-folding the inputs of `v`.

    Parameters
    ----------
    v : Variable
    elemwise : bool
        If False, we won't try to go into elemwise. So this call is faster.
        But we still investigate in Second Elemwise (as this is a substitute
        for Alloc).
    only_process_constants : bool
        If True, we only attempt to obtain the value of `v` if it's
        directly constant and don't try to dig through dimshuffles, fills,
        allocs, and others to figure out its value.
    max_recur : int
        The maximum number of recursion.
    raise_not_constant : bool, default True
        If True, raise a NotScalarConstantError if `v` does not have an
        underlying constant scalar value. If False, return `v` as is.

    Raises
    ------
    NotScalarConstantError
        `v` does not have an underlying constant scalar value.
        Only raised if `raise_not_constant` is True.

    """
    try:
        return _get_underlying_scalar_constant_value(
            v,
            elemwise=elemwise,
            only_process_constants=only_process_constants,
            max_recur=max_recur,
        )
    except NotScalarConstantError:
        # Caller opted out of the exception: return the symbolic variable
        # unchanged so it can be compared/inspected by the caller instead.
        if raise_not_constant:
            raise
        return v
def get_scalar_constant_value(
    v,
    elemwise=True,
    only_process_constants=False,
    max_recur=10,
    raise_not_constant: bool = True,
):
    """Check that `v` is scalar (ndim == 0) and fetch its underlying constant value.

    If `v` is a scalar, the underlying constant is obtained by calling
    `get_underlying_scalar_constant_value`, forwarding all options.

    Parameters
    ----------
    v : Variable or ndarray
    elemwise : bool
        If False, don't dig into Elemwise ops (faster).
    only_process_constants : bool
        If True, only consider directly-constant inputs.
    max_recur : int
        The maximum number of recursion.
    raise_not_constant : bool, default True
        If False, return `v` as-is instead of raising when no underlying
        constant scalar value exists.

    Raises
    ------
    NotScalarConstantError
        If `v` has ndim != 0, or (when `raise_not_constant` is True) if no
        underlying constant scalar value can be determined.
    """
    if isinstance(v, TensorVariable | np.ndarray):
        if v.ndim != 0:
            # Fixed: removed a leftover debug `print(v, v.ndim)` that leaked
            # diagnostics to stdout on every non-scalar input.
            raise NotScalarConstantError("Input ndim != 0")
    return get_underlying_scalar_constant_value(
        v,
        elemwise=elemwise,
        only_process_constants=only_process_constants,
        max_recur=max_recur,
        raise_not_constant=raise_not_constant,
    )
class TensorFromScalar(COp):
__props__ = ()
......@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
"""
try:
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
except NotScalarConstantError:
pass
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable):
if x.owner and isinstance(x.owner.op, ScalarFromTensor):
x = x.owner.inputs[0]
else:
x = tensor_from_scalar(x)
return x
warnings.warn(
"extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
FutureWarning,
)
return get_underlying_scalar_constant_value(
x,
elemwise=elemwise,
only_process_constants=only_process_constants,
raise_not_constant=False,
)
def transpose(x, axes=None):
......@@ -4401,7 +4469,6 @@ __all__ = [
"split",
"transpose",
"matrix_transpose",
"extract_constant",
"default",
"tensor_copy",
"transfer",
......
......@@ -30,7 +30,7 @@ import pytensor.scalar.basic as ps
from pytensor import compile, config
from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
......@@ -55,8 +55,8 @@ from pytensor.tensor.basic import (
as_tensor_variable,
atleast_Nd,
cast,
extract_constant,
fill,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
join,
ones_like,
......@@ -478,7 +478,12 @@ def local_alloc_sink_dimshuffle(fgraph, node):
output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - inp.ndim):
if extract_constant(output_shape[i], only_process_constants=True) == 1:
if (
get_scalar_constant_value(
output_shape[i], only_process_constants=True, raise_not_constant=False
)
== 1
):
num_dims_with_size_1_added_to_left += 1
else:
break
......@@ -538,93 +543,90 @@ def local_useless_elemwise(fgraph, node):
xor(x, x) -> zeros_like(x)
TODO: This implementation is painfully redundant.
TODO: Allow rewrite when useless input broadcasts output
"""
if isinstance(node.op, Elemwise):
# We call zeros_like and one_like with opt=True to generate a
# cleaner graph.
dtype = node.outputs[0].dtype
if node.op.scalar_op == ps.eq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be true
ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif node.op.scalar_op == ps.neq and len(node.inputs) == 2:
if node.inputs[0] == node.inputs[1]:
# it is the same var in the graph. That will always be false
ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif node.op.scalar_op == ps.mul and len(node.inputs) == 1:
# No need to copy over any stack trace
return [node.inputs[0]]
elif node.op.scalar_op == ps.add and len(node.inputs) == 1:
# No need to copy over any stack trace
return [node.inputs[0]]
elif node.op.scalar_op == ps.identity and len(node.inputs) == 1:
return [node.inputs[0]]
elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2:
if isinstance(node.inputs[0], TensorConstant):
const_val = extract_constant(
node.inputs[0], only_process_constants=True
)
if not isinstance(const_val, Variable):
if const_val == 0:
return [zeros_like(node.inputs[1], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND,
# and this rewrite would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], TensorConstant):
const_val = extract_constant(
node.inputs[1], only_process_constants=True
)
if not isinstance(const_val, Variable):
if const_val == 0:
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND,
# and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2:
if isinstance(node.inputs[0], TensorConstant):
const_val = extract_constant(
node.inputs[0], only_process_constants=True
)
if not isinstance(const_val, Variable):
if const_val == 0:
return [node.inputs[1].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR,
# and this rewrite would be wrong
return [ones_like(node.inputs[1], dtype=dtype, opt=True)]
if isinstance(node.inputs[1], TensorConstant):
const_val = extract_constant(
node.inputs[1], only_process_constants=True
)
if not isinstance(const_val, Variable):
if const_val == 0:
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR,
# and this rewrite would be wrong
return [ones_like(node.inputs[0], dtype=dtype, opt=True)]
elif isinstance(node.op.scalar_op, ps.XOR) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
out_bcast = node.outputs[0].type.broadcastable
dtype = node.outputs[0].type.dtype
scalar_op = node.op.scalar_op
if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be true
ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
# it is the same var in the graph. That will always be false
ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
copy_stack_trace(node.outputs[0], ret)
return [ret]
elif (
isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity)
and len(node.inputs) == 1
):
# No need to copy over any stack trace
return [node.inputs[0]]
elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2:
if (
isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
):
const_val = node.inputs[0].unique_value
if const_val is not None:
if const_val == 0:
return [zeros_like(node.inputs[1], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND,
# and this rewrite would be wrong
return [node.inputs[1].astype(node.outputs[0].dtype)]
if (
isinstance(node.inputs[1], TensorConstant)
and node.inputs[0].type.broadcastable == out_bcast
):
const_val = node.inputs[1].unique_value
if const_val is not None:
if const_val == 0:
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise AND,
# and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2:
if (
isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
):
const_val = node.inputs[0].unique_value
if const_val is not None:
if const_val == 0:
return [node.inputs[1].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR,
# and this rewrite would be wrong
return [ones_like(node.inputs[1], dtype=dtype, opt=True)]
if (
isinstance(node.inputs[1], TensorConstant)
and node.inputs[0].type.broadcastable == out_bcast
):
const_val = node.inputs[1].unique_value
if const_val is not None:
if const_val == 0:
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR,
# and this rewrite would be wrong
return [ones_like(node.inputs[0], dtype=dtype, opt=True)]
@register_specialize
......@@ -988,13 +990,10 @@ def local_useless_switch(fgraph, node):
left = node.inputs[1]
right = node.inputs[2]
cond_var = node.inputs[0]
cond = extract_constant(cond_var, only_process_constants=True)
out_bcast = node.outputs[0].type.broadcastable
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
cond, np.number | np.bool_
):
if cond == 0:
if isinstance(cond_var, TensorConstant) and cond_var.unique_value is not None:
if cond_var.unique_value == 0:
correct_out = right
else:
correct_out = left
......@@ -1014,7 +1013,7 @@ def local_useless_switch(fgraph, node):
# if left is right -> left
if equivalent_up_to_constant_casting(left, right):
if left.type.broadcastable != out_bcast:
left, _ = broadcast_arrays(left, cond)
left, _ = broadcast_arrays(left, cond_var)
out_dtype = node.outputs[0].type.dtype
if left.type.dtype != out_dtype:
......@@ -1026,13 +1025,22 @@ def local_useless_switch(fgraph, node):
# This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
if (
cond_var.owner
node.outputs[0].type.ndim == 0
and cond_var.owner
and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, ps.LE)
and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0
and extract_constant(left, only_process_constants=True) == 0
and get_scalar_constant_value(
cond_var.owner.inputs[1],
only_process_constants=True,
raise_not_constant=False,
)
== 0
and get_scalar_constant_value(
left, only_process_constants=True, raise_not_constant=False
)
== 0
and right == cond_var.owner.inputs[0]
):
assert node.outputs[0].type.is_super(right.type)
......
......@@ -28,7 +28,6 @@ from pytensor.tensor.basic import (
as_tensor_variable,
cast,
constant,
extract_constant,
get_underlying_scalar_constant_value,
moveaxis,
ones_like,
......@@ -566,11 +565,14 @@ def local_expm1(fgraph, node):
in1.owner
and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp)
and extract_constant(in2, only_process_constants=False) == 1
and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
):
in11 = in1.owner.inputs[0]
new_out = expm1(in11)
if new_out.type.broadcastable != out.type.broadcastable:
new_out = broadcast_arrays(in11, in2)[0]
if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype)
......@@ -1345,12 +1347,13 @@ def local_useless_elemwise_comparison(fgraph, node):
the graph easier to read.
"""
# TODO: Refactor this function. So much repeated code!
if node.op.scalar_op.nin != 2:
return
# We call zeros_like and one_like with opt=True to generate a
# cleaner graph.
dtype = node.outputs[0].dtype
dtype = node.outputs[0].type.dtype
out_bcast = node.outputs[0].type.broadcastable
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if (
......@@ -1361,6 +1364,7 @@ def local_useless_elemwise_comparison(fgraph, node):
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if (
isinstance(node.op.scalar_op, ps.LE | ps.GE)
......@@ -1371,6 +1375,7 @@ def local_useless_elemwise_comparison(fgraph, node):
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[{minimum,maximum}](X, X) -> X
if (
isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum)
......@@ -1386,64 +1391,72 @@ def local_useless_elemwise_comparison(fgraph, node):
isinstance(node.op.scalar_op, ps.LT)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
):
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if (
isinstance(node.op.scalar_op, ps.GE)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
):
res = ones_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if (
isinstance(node.op.scalar_op, ps.ScalarMaximum)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
):
# No need to copy over stacktrace.
return [node.inputs[0]]
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
if (
isinstance(node.op.scalar_op, ps.ScalarMaximum)
and extract_constant(node.inputs[0], only_process_constants=True) == 0
and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Shape_i)
):
# No need to copy over stacktrace.
return [node.inputs[1]]
# Elemwise[minimum](X.shape[i], 0) -> 0
if (
isinstance(node.op.scalar_op, ps.ScalarMinimum)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
):
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
if isinstance(node.op.scalar_op, ps.ScalarMaximum):
for idx in range(2):
if (
node.inputs[idx].owner
and isinstance(node.inputs[idx].owner.op, Shape_i)
and get_underlying_scalar_constant_value(
node.inputs[1 - idx],
only_process_constants=True,
raise_not_constant=False,
)
== 0
):
res = node.inputs[idx]
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1 - idx])[0]
# No need to copy over stacktrace.
return [res]
# Elemwise[minimum](0, X.shape[i]) -> 0
if (
isinstance(node.op.scalar_op, ps.ScalarMinimum)
and extract_constant(node.inputs[0], only_process_constants=True) == 0
and node.inputs[1].owner
and isinstance(node.inputs[1].owner.op, Shape_i)
):
res = zeros_like(node.inputs[1], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[minimum](X.shape[i], 0) -> 0
if isinstance(node.op.scalar_op, ps.ScalarMinimum):
for idx in range(2):
if (
node.inputs[idx].owner
and isinstance(node.inputs[idx].owner.op, Shape_i)
and get_underlying_scalar_constant_value(
node.inputs[1 - idx],
only_process_constants=True,
raise_not_constant=False,
)
== 0
):
res = zeros_like(node.inputs[idx], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1 - idx])[0]
# No need to copy over stacktrace.
return [res]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if (
......@@ -1455,12 +1468,18 @@ def local_useless_elemwise_comparison(fgraph, node):
isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs
)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
):
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if (
isinstance(node.op.scalar_op, ps.GE)
......@@ -1471,57 +1490,61 @@ def local_useless_elemwise_comparison(fgraph, node):
isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs
)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
):
res = ones_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N)
# TODO: handle the case where the -N is on either side
"""
|Elemwise{eq,no_inplace} [id B] ''
| |Subtensor{int64} [id C] ''
| | |Join [id D] ''
| | | |TensorConstant{0} [id E]
| | | |Subtensor{int64:int64:} [id F] ''
| | | | |Shape [id G] ''
"""
# Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N)
# TODO: handle the case where the -N is on either side
"""
|Elemwise{eq,no_inplace} [id B] ''
| |Subtensor{int64} [id C] ''
| | |Join [id D] ''
| | | |TensorConstant{0} [id E]
| | | |Subtensor{int64:int64:} [id F] ''
| | | | |Shape [id G] ''
"""
def investigate_if_shape(node) -> bool:
    """Return True if `node`'s outputs are known to be shapes, so >= 0.

    Fixed: the diff residue here interleaved the old `investigate`
    definition with the renamed `investigate_if_shape`, leaving duplicate
    `def` lines and duplicate returns; this is the resolved version.
    """
    if isinstance(node.op, Shape | Shape_i):
        # Direct shape ops always produce non-negative values.
        return True
    elif isinstance(node.op, Subtensor) and node.inputs[0].owner:
        # A slice of a shape-producing graph is still shape-valued.
        return investigate_if_shape(node.inputs[0].owner)
    elif isinstance(node.op, Join):
        # inputs[0] of Join is the axis; only the joined values matter.
        return all(
            v.owner and investigate_if_shape(v.owner) for v in node.inputs[1:]
        )
    elif isinstance(node.op, MakeVector):
        return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs)
    return False
if (
isinstance(node.op.scalar_op, ps.EQ)
and node.inputs[0].owner
and investigate(node.inputs[0].owner)
and investigate_if_shape(node.inputs[0].owner)
and (
isinstance(node.inputs[1], TensorConstant)
and node.inputs[1].unique_value is not None
and node.inputs[1].unique_value < 0
)
):
try:
cst = get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True
)
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
if cst < 0:
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
except NotScalarConstantError:
pass
return
......@@ -2223,12 +2246,21 @@ def local_log1p(fgraph, node):
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True)
one, other = log_arg.owner.inputs
try:
one = get_underlying_scalar_constant_value(one, only_process_constants=True)
except NotScalarConstantError:
return
if one != 1:
return
other = log_arg.owner.inputs[1]
if other.dtype != log_arg.dtype:
if other.type.broadcastable != log_arg.type.broadcastable:
other = broadcast_arrays(other, one)[0]
if other.type.dtype != log_arg.type.dtype:
other = other.astype(log_arg.dtype)
return [log1p(neg(other))]
......
......@@ -22,7 +22,7 @@ from pytensor.tensor.basic import (
as_tensor_variable,
cast,
constant,
extract_constant,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
register_infer_shape,
stack,
......@@ -354,7 +354,9 @@ class ShapeFeature(Feature):
not hasattr(r.type, "shape")
or r.type.shape[i] != 1
or self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i]))
or self.lscalar_one.equals(
get_scalar_constant_value(shape_vars[i], raise_not_constant=False)
)
for i in range(r.type.ndim)
)
self.shape_of[r] = tuple(shape_vars)
......@@ -450,7 +452,11 @@ class ShapeFeature(Feature):
)
or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True)
get_underlying_scalar_constant_value(
merged_shape[i],
only_process_constants=True,
raise_not_constant=False,
)
)
for i in range(r.type.ndim)
)
......@@ -474,7 +480,11 @@ class ShapeFeature(Feature):
not hasattr(r.type, "shape")
or r.type.shape[idx] != 1
or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx]))
or self.lscalar_one.equals(
get_underlying_scalar_constant_value(
new_shape[idx], raise_not_constant=False
)
)
for idx in range(r.type.ndim)
)
self.shape_of[r] = tuple(new_shape)
......@@ -847,7 +857,10 @@ def local_useless_reshape(fgraph, node):
outshp_i.owner
and isinstance(outshp_i.owner.op, Subtensor)
and len(outshp_i.owner.inputs) == 2
and extract_constant(outshp_i.owner.inputs[1]) == dim
and get_scalar_constant_value(
outshp_i.owner.inputs[1], raise_not_constant=False
)
== dim
):
subtensor_inp = outshp_i.owner.inputs[0]
if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
......@@ -857,7 +870,9 @@ def local_useless_reshape(fgraph, node):
continue
# Match constant if input.type.shape[dim] == constant
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1)
cst_outshp_i = get_scalar_constant_value(
outshp_i, only_process_constants=True, raise_not_constant=False
)
if inp.type.shape[dim] == cst_outshp_i:
shape_match[dim] = True
continue
......@@ -872,8 +887,12 @@ def local_useless_reshape(fgraph, node):
if shape_feature:
inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=True)
== extract_constant(outshp_i, only_process_constants=True)
get_scalar_constant_value(
inpshp_i, only_process_constants=True, raise_not_constant=False
)
== get_scalar_constant_value(
outshp_i, only_process_constants=True, raise_not_constant=False
)
):
shape_match[dim] = True
continue
......@@ -909,11 +928,14 @@ def local_reshape_to_dimshuffle(fgraph, node):
new_output_shape = []
index = 0 # index over the output of the new reshape
for i in range(output.ndim):
# Since output_shape is a symbolic vector, we trust extract_constant
# Since output_shape is a symbolic vector, we trust get_scalar_constant_value
# to go through however it is formed to see if its i-th element is 1.
# We need only_process_constants=False for that.
dim = extract_constant(
output_shape[i], only_process_constants=False, elemwise=False
dim = get_scalar_constant_value(
output_shape[i],
only_process_constants=False,
elemwise=False,
raise_not_constant=False,
)
if dim == 1:
dimshuffle_new_order.append("x")
......
......@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
as_tensor,
cast,
concatenate,
extract_constant,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
register_infer_shape,
switch,
......@@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node):
start = s.start
stop = s.stop
if start is not None and extract_constant(
start, only_process_constants=True
if start is not None and get_scalar_constant_value(
start, only_process_constants=True, raise_not_constant=False
) == (0 if positive_step else -1):
change_flag = True
start = None
......@@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node):
if (
stop is not None
and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True)
and get_scalar_constant_value(
stop, only_process_constants=True, raise_not_constant=False
)
== (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
):
change_flag = True
......@@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node):
and e.stop is None
and (
e.step is None
or extract_constant(e.step, only_process_constants=True) == -1
or get_scalar_constant_value(
e.step, only_process_constants=True, raise_not_constant=False
)
== -1
)
for e in idx_cst
):
......@@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
and
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
extract_constant(x, elemwise=False) != 0
get_underlying_scalar_constant_value(
x, elemwise=False, raise_not_constant=False
)
!= 0
):
return
......
......@@ -1383,11 +1383,11 @@ class TestLocalUselessElemwiseComparison:
if op == deep_copy_op:
assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], TensorConstant), elem
assert pt.extract_constant(elem.inputs[0]) == val, val
assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
else:
assert len(elem.inputs) == 2, elem.inputs
assert isinstance(elem.inputs[0], TensorConstant), elem
assert pt.extract_constant(elem.inputs[0]) == val, val
assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
......
......@@ -46,7 +46,6 @@ from pytensor.tensor.basic import (
default,
diag,
expand_dims,
extract_constant,
eye,
fill,
flatnonzero,
......@@ -3574,10 +3573,10 @@ class TestGetUnderlyingScalarConstantValue:
# Make sure we do not return a writeable internal storage of a constant,
# so we cannot change the value of a constant by mistake.
c = constant(3)
d = extract_constant(c)
d = get_scalar_constant_value(c)
with pytest.raises(ValueError, match="output array is read-only"):
d += 1
e = extract_constant(c)
e = get_scalar_constant_value(c)
assert e == 3, (c, d, e)
@pytest.mark.parametrize("only_process_constants", (True, False))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论