提交 32aadc8c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate `extract_constant`

上级 aad6fb75
...@@ -54,6 +54,7 @@ from pytensor.scan.utils import ( ...@@ -54,6 +54,7 @@ from pytensor.scan.utils import (
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
...@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used( ...@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used(
client = fgraph.clients[outer_var][0][0] client = fgraph.clients[outer_var][0][0]
if isinstance(client, Apply) and isinstance(client.op, Subtensor): if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list) lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1: return (
return True len(lst) == 1
and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1
)
return False return False
...@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node): ...@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node):
if isinstance(this_slice[0], slice) and this_slice[0].stop is None: if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None global_nsteps = None
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
stop = pt.extract_constant(cf_slice[0].stop) stop = get_scalar_constant_value(
cf_slice[0].stop, raise_not_constant=False
)
else: else:
stop = pt.extract_constant(cf_slice[0]) + 1 stop = (
if stop == maxsize or stop == pt.extract_constant(length): get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
stop = None stop = None
else: else:
# there is a **gotcha** here ! Namely, scan returns an # there is a **gotcha** here ! Namely, scan returns an
...@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node): ...@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node):
cf_slice = get_canonical_form_slice(this_slice[0], length) cf_slice = get_canonical_form_slice(this_slice[0], length)
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
start = pt.extract_constant(cf_slice[0].start) start = pt.get_scalar_constant_value(
cf_slice[0].start, raise_not_constant=False
)
else: else:
start = pt.extract_constant(cf_slice[0]) start = pt.get_scalar_constant_value(
cf_slice[0], raise_not_constant=False
)
if start == 0 or store_steps[i] == 0: if start == 0 or store_steps[i] == 0:
store_steps[i] = 0 store_steps[i] = 0
...@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node): ...@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node):
# 3.6 Compose the new scan # 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So # TODO: currently we don't support scan with 0 step. So
# don't create one. # don't create one.
if pt.extract_constant(node_ins[0]) == 0: if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
return False return False
# Do not call make_node for test_value # Do not call make_node for test_value
......
...@@ -268,27 +268,7 @@ _scalar_constant_value_elemwise_ops = ( ...@@ -268,27 +268,7 @@ _scalar_constant_value_elemwise_ops = (
) )
def get_scalar_constant_value( def _get_underlying_scalar_constant_value(
v, elemwise=True, only_process_constants=False, max_recur=10
):
"""
Checks whether 'v' is a scalar (ndim = 0).
If 'v' is a scalar then this function fetches the underlying constant by calling
'get_underlying_scalar_constant_value()'.
If 'v' is not a scalar, it raises a NotScalarConstantError.
"""
if isinstance(v, Variable | np.ndarray):
if v.ndim != 0:
raise NotScalarConstantError()
return get_underlying_scalar_constant_value(
v, elemwise, only_process_constants, max_recur
)
def get_underlying_scalar_constant_value(
orig_v, elemwise=True, only_process_constants=False, max_recur=10 orig_v, elemwise=True, only_process_constants=False, max_recur=10
): ):
"""Return the constant scalar(0-D) value underlying variable `v`. """Return the constant scalar(0-D) value underlying variable `v`.
...@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value( ...@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value(
elif isinstance(op, CheckAndRaise): elif isinstance(op, CheckAndRaise):
# check if all conditions are constant and true # check if all conditions are constant and true
conds = [ conds = [
get_underlying_scalar_constant_value(c, max_recur=max_recur) _get_underlying_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:] for c in v.owner.inputs[1:]
] ]
if builtins.all(0 == c.ndim and c != 0 for c in conds): if builtins.all(0 == c.ndim and c != 0 for c in conds):
...@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value( ...@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value(
continue continue
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [ const = [
get_underlying_scalar_constant_value(i, max_recur=max_recur) _get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
] ]
ret = [[None]] ret = [[None]]
...@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value( ...@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value(
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
): ):
const = [ const = [
get_underlying_scalar_constant_value(i, max_recur=max_recur) _get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
] ]
ret = [[None]] ret = [[None]]
...@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value( ...@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
try: try:
...@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value( ...@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx] ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur) ret = _get_underlying_scalar_constant_value(
ret, max_recur=max_recur
)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
return np.asarray(ret, dtype=v.type.dtype) return np.asarray(ret, dtype=v.type.dtype)
...@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value( ...@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value(
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
...@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value( ...@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value(
grandparent.owner.op, Unbroadcast grandparent.owner.op, Unbroadcast
): ):
ggp_shape = grandparent.owner.inputs[0].type.shape ggp_shape = grandparent.owner.inputs[0].type.shape
l = [get_underlying_scalar_constant_value(s) for s in ggp_shape] l = [
_get_underlying_scalar_constant_value(s) for s in ggp_shape
]
gp_shape = tuple(l) gp_shape = tuple(l)
if not (idx < ndim): if not (idx < ndim):
...@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value( ...@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value(
if isinstance(grandparent, Constant): if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx]) return np.asarray(np.shape(grandparent.data)[idx])
elif isinstance(op, CSM): elif isinstance(op, CSM):
data = get_underlying_scalar_constant_value( data = _get_underlying_scalar_constant_value(
v.owner.inputs, elemwise=elemwise, max_recur=max_recur v.owner.inputs, elemwise=elemwise, max_recur=max_recur
) )
# Sparse variable can only be constant if zero (or I guess if homogeneously dense) # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
...@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value( ...@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
def get_underlying_scalar_constant_value(
    v,
    *,
    elemwise=True,
    only_process_constants=False,
    max_recur=10,
    raise_not_constant=True,
):
    """Return the unique constant scalar (0-D) value underlying variable `v`.

    When `v` is produced by dimshuffles, fills, allocs, casts, OutputGuard,
    DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise, or certain Subtensor
    patterns, the search digs through those operations to reach the constant.

    The reasoning is symbolic: the value of `v` is deduced from the structure
    of the graph, not by numerically constant-folding the inputs of `v`.

    Parameters
    ----------
    v : Variable
        The variable whose underlying constant value is requested.
    elemwise : bool
        If False, Elemwise operations are not explored, which makes the call
        faster. Second Elemwise is still investigated, since it is a
        substitute for Alloc.
    only_process_constants : bool
        If True, a value is only returned when `v` is directly a constant;
        dimshuffles, fills, allocs and similar operations are not traversed.
    max_recur : int
        The maximum number of recursion.
    raise_not_constant : bool, default True
        If True, raise a NotScalarConstantError when `v` has no underlying
        constant scalar value. If False, return `v` unchanged in that case.

    Raises
    ------
    NotScalarConstantError
        `v` does not have an underlying constant scalar value.
        Only raised if `raise_not_constant` is True.

    """
    try:
        constant_value = _get_underlying_scalar_constant_value(
            v,
            elemwise=elemwise,
            only_process_constants=only_process_constants,
            max_recur=max_recur,
        )
    except NotScalarConstantError:
        if not raise_not_constant:
            # The caller opted out of the exception: hand the input back as is.
            return v
        raise
    return constant_value
def get_scalar_constant_value(
    v,
    elemwise=True,
    only_process_constants=False,
    max_recur=10,
    raise_not_constant: bool = True,
):
    """Return the constant value underlying `v`, requiring `v` to be a scalar.

    When `v` is a TensorVariable or a NumPy array, it must have ``ndim == 0``.
    Inputs that pass this check, as well as inputs of any other type, are
    forwarded to `get_underlying_scalar_constant_value` together with all
    remaining arguments.

    Parameters
    ----------
    v
        The variable or array to inspect.
    elemwise : bool
        Forwarded to `get_underlying_scalar_constant_value`.
    only_process_constants : bool
        Forwarded to `get_underlying_scalar_constant_value`.
    max_recur : int
        Forwarded to `get_underlying_scalar_constant_value`.
    raise_not_constant : bool, default True
        Forwarded to `get_underlying_scalar_constant_value`. It does not
        affect the dimensionality check performed here, which raises
        unconditionally.

    Raises
    ------
    NotScalarConstantError
        If `v` is a TensorVariable or NumPy array with ``ndim != 0``
        (regardless of `raise_not_constant`), or if it is raised by
        `get_underlying_scalar_constant_value`.

    """
    # Reject non-scalar tensors/arrays up front; no debugging output is emitted
    # here so that callers probing many variables do not pollute stdout.
    if isinstance(v, TensorVariable | np.ndarray) and v.ndim != 0:
        raise NotScalarConstantError("Input ndim != 0")

    return get_underlying_scalar_constant_value(
        v,
        elemwise=elemwise,
        only_process_constants=only_process_constants,
        max_recur=max_recur,
        raise_not_constant=raise_not_constant,
    )
class TensorFromScalar(COp): class TensorFromScalar(COp):
__props__ = () __props__ = ()
...@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False): ...@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
ScalarVariable, we convert it to a tensor with tensor_from_scalar. ScalarVariable, we convert it to a tensor with tensor_from_scalar.
""" """
try: warnings.warn(
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants) "extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
except NotScalarConstantError: FutureWarning,
pass )
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable): return get_underlying_scalar_constant_value(
if x.owner and isinstance(x.owner.op, ScalarFromTensor): x,
x = x.owner.inputs[0] elemwise=elemwise,
else: only_process_constants=only_process_constants,
x = tensor_from_scalar(x) raise_not_constant=False,
return x )
def transpose(x, axes=None): def transpose(x, axes=None):
...@@ -4401,7 +4469,6 @@ __all__ = [ ...@@ -4401,7 +4469,6 @@ __all__ = [
"split", "split",
"transpose", "transpose",
"matrix_transpose", "matrix_transpose",
"extract_constant",
"default", "default",
"tensor_copy", "tensor_copy",
"transfer", "transfer",
......
...@@ -30,7 +30,7 @@ import pytensor.scalar.basic as ps ...@@ -30,7 +30,7 @@ import pytensor.scalar.basic as ps
from pytensor import compile, config from pytensor import compile, config
from pytensor.compile.ops import ViewOp from pytensor.compile.ops import ViewOp
from pytensor.graph import FunctionGraph from pytensor.graph import FunctionGraph
from pytensor.graph.basic import Constant, Variable from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
NodeRewriter, NodeRewriter,
...@@ -55,8 +55,8 @@ from pytensor.tensor.basic import ( ...@@ -55,8 +55,8 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
atleast_Nd, atleast_Nd,
cast, cast,
extract_constant,
fill, fill,
get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
join, join,
ones_like, ones_like,
...@@ -478,7 +478,12 @@ def local_alloc_sink_dimshuffle(fgraph, node): ...@@ -478,7 +478,12 @@ def local_alloc_sink_dimshuffle(fgraph, node):
output_shape = node.inputs[1:] output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0 num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - inp.ndim): for i in range(len(output_shape) - inp.ndim):
if extract_constant(output_shape[i], only_process_constants=True) == 1: if (
get_scalar_constant_value(
output_shape[i], only_process_constants=True, raise_not_constant=False
)
== 1
):
num_dims_with_size_1_added_to_left += 1 num_dims_with_size_1_added_to_left += 1
else: else:
break break
...@@ -538,93 +543,90 @@ def local_useless_elemwise(fgraph, node): ...@@ -538,93 +543,90 @@ def local_useless_elemwise(fgraph, node):
xor(x, x) -> zeros_like(x) xor(x, x) -> zeros_like(x)
TODO: This implementation is painfully redundant. TODO: This implementation is painfully redundant.
TODO: Allow rewrite when useless input broadcasts output
""" """
if isinstance(node.op, Elemwise): out_bcast = node.outputs[0].type.broadcastable
# We call zeros_like and one_like with opt=True to generate a dtype = node.outputs[0].type.dtype
# cleaner graph. scalar_op = node.op.scalar_op
dtype = node.outputs[0].dtype
if isinstance(scalar_op, ps.EQ) and len(node.inputs) == 2:
if node.op.scalar_op == ps.eq and len(node.inputs) == 2: if node.inputs[0] is node.inputs[1]:
if node.inputs[0] == node.inputs[1]: # it is the same var in the graph. That will always be true
# it is the same var in the graph. That will always be true ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
ret = ones_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
# Copy stack trace from input to constant output copy_stack_trace(node.outputs[0], ret)
copy_stack_trace(node.outputs[0], ret) return [ret]
return [ret] elif isinstance(scalar_op, ps.NEQ | ps.XOR) and len(node.inputs) == 2:
elif node.op.scalar_op == ps.neq and len(node.inputs) == 2: if node.inputs[0] is node.inputs[1]:
if node.inputs[0] == node.inputs[1]: # it is the same var in the graph. That will always be false
# it is the same var in the graph. That will always be false ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
ret = zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy stack trace from input to constant output
# Copy stack trace from input to constant output copy_stack_trace(node.outputs[0], ret)
copy_stack_trace(node.outputs[0], ret) return [ret]
return [ret]
elif (
elif node.op.scalar_op == ps.mul and len(node.inputs) == 1: isinstance(node.op.scalar_op, ps.Mul | ps.Add | ps.Identity)
# No need to copy over any stack trace and len(node.inputs) == 1
return [node.inputs[0]] ):
# No need to copy over any stack trace
elif node.op.scalar_op == ps.add and len(node.inputs) == 1: return [node.inputs[0]]
# No need to copy over any stack trace
return [node.inputs[0]] elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2:
elif node.op.scalar_op == ps.identity and len(node.inputs) == 1: if (
return [node.inputs[0]] isinstance(node.inputs[0], TensorConstant)
and node.inputs[1].type.broadcastable == out_bcast
elif isinstance(node.op.scalar_op, ps.AND) and len(node.inputs) == 2: ):
if isinstance(node.inputs[0], TensorConstant): const_val = node.inputs[0].unique_value
const_val = extract_constant( if const_val is not None:
node.inputs[0], only_process_constants=True if const_val == 0:
) return [zeros_like(node.inputs[1], dtype=dtype, opt=True)]
if not isinstance(const_val, Variable): elif node.outputs[0].dtype == "bool":
if const_val == 0: # If the output is not Boolean, it is the bitwise AND,
return [zeros_like(node.inputs[1], dtype=dtype, opt=True)] # and this rewrite would be wrong
elif node.outputs[0].dtype == "bool": return [node.inputs[1].astype(node.outputs[0].dtype)]
# If the output is not Boolean, it is the bitwise AND,
# and this rewrite would be wrong if (
return [node.inputs[1].astype(node.outputs[0].dtype)] isinstance(node.inputs[1], TensorConstant)
and node.inputs[0].type.broadcastable == out_bcast
if isinstance(node.inputs[1], TensorConstant): ):
const_val = extract_constant( const_val = node.inputs[1].unique_value
node.inputs[1], only_process_constants=True if const_val is not None:
) if const_val == 0:
if not isinstance(const_val, Variable): return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
if const_val == 0: elif node.outputs[0].dtype == "bool":
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)] # If the output is not Boolean, it is the bitwise AND,
elif node.outputs[0].dtype == "bool": # and this rewrite would be wrong
# If the output is not Boolean, it is the bitwise AND, return [node.inputs[0].astype(node.outputs[0].dtype)]
# and this rewrite would be wrong
return [node.inputs[0].astype(node.outputs[0].dtype)] elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2:
if (
elif isinstance(node.op.scalar_op, ps.OR) and len(node.inputs) == 2: isinstance(node.inputs[0], TensorConstant)
if isinstance(node.inputs[0], TensorConstant): and node.inputs[1].type.broadcastable == out_bcast
const_val = extract_constant( ):
node.inputs[0], only_process_constants=True const_val = node.inputs[0].unique_value
) if const_val is not None:
if not isinstance(const_val, Variable): if const_val == 0:
if const_val == 0: return [node.inputs[1].astype(node.outputs[0].dtype)]
return [node.inputs[1].astype(node.outputs[0].dtype)] elif node.outputs[0].dtype == "bool":
elif node.outputs[0].dtype == "bool": # If the output is not Boolean, it is the bitwise OR,
# If the output is not Boolean, it is the bitwise OR, # and this rewrite would be wrong
# and this rewrite would be wrong return [ones_like(node.inputs[1], dtype=dtype, opt=True)]
return [ones_like(node.inputs[1], dtype=dtype, opt=True)]
if (
if isinstance(node.inputs[1], TensorConstant): isinstance(node.inputs[1], TensorConstant)
const_val = extract_constant( and node.inputs[0].type.broadcastable == out_bcast
node.inputs[1], only_process_constants=True ):
) const_val = node.inputs[1].unique_value
if not isinstance(const_val, Variable): if const_val is not None:
if const_val == 0: if const_val == 0:
return [node.inputs[0].astype(node.outputs[0].dtype)] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif node.outputs[0].dtype == "bool": elif node.outputs[0].dtype == "bool":
# If the output is not Boolean, it is the bitwise OR, # If the output is not Boolean, it is the bitwise OR,
# and this rewrite would be wrong # and this rewrite would be wrong
return [ones_like(node.inputs[0], dtype=dtype, opt=True)] return [ones_like(node.inputs[0], dtype=dtype, opt=True)]
elif isinstance(node.op.scalar_op, ps.XOR) and len(node.inputs) == 2:
if node.inputs[0] is node.inputs[1]:
return [zeros_like(node.inputs[0], dtype=dtype, opt=True)]
@register_specialize @register_specialize
...@@ -988,13 +990,10 @@ def local_useless_switch(fgraph, node): ...@@ -988,13 +990,10 @@ def local_useless_switch(fgraph, node):
left = node.inputs[1] left = node.inputs[1]
right = node.inputs[2] right = node.inputs[2]
cond_var = node.inputs[0] cond_var = node.inputs[0]
cond = extract_constant(cond_var, only_process_constants=True)
out_bcast = node.outputs[0].type.broadcastable out_bcast = node.outputs[0].type.broadcastable
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( if isinstance(cond_var, TensorConstant) and cond_var.unique_value is not None:
cond, np.number | np.bool_ if cond_var.unique_value == 0:
):
if cond == 0:
correct_out = right correct_out = right
else: else:
correct_out = left correct_out = left
...@@ -1014,7 +1013,7 @@ def local_useless_switch(fgraph, node): ...@@ -1014,7 +1013,7 @@ def local_useless_switch(fgraph, node):
# if left is right -> left # if left is right -> left
if equivalent_up_to_constant_casting(left, right): if equivalent_up_to_constant_casting(left, right):
if left.type.broadcastable != out_bcast: if left.type.broadcastable != out_bcast:
left, _ = broadcast_arrays(left, cond) left, _ = broadcast_arrays(left, cond_var)
out_dtype = node.outputs[0].type.dtype out_dtype = node.outputs[0].type.dtype
if left.type.dtype != out_dtype: if left.type.dtype != out_dtype:
...@@ -1026,13 +1025,22 @@ def local_useless_switch(fgraph, node): ...@@ -1026,13 +1025,22 @@ def local_useless_switch(fgraph, node):
# This case happens with scan. # This case happens with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) # Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
if ( if (
cond_var.owner node.outputs[0].type.ndim == 0
and cond_var.owner
and isinstance(cond_var.owner.op, Elemwise) and isinstance(cond_var.owner.op, Elemwise)
and isinstance(cond_var.owner.op.scalar_op, ps.LE) and isinstance(cond_var.owner.op.scalar_op, ps.LE)
and cond_var.owner.inputs[0].owner and cond_var.owner.inputs[0].owner
and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i) and isinstance(cond_var.owner.inputs[0].owner.op, Shape_i)
and extract_constant(cond_var.owner.inputs[1], only_process_constants=True) == 0 and get_scalar_constant_value(
and extract_constant(left, only_process_constants=True) == 0 cond_var.owner.inputs[1],
only_process_constants=True,
raise_not_constant=False,
)
== 0
and get_scalar_constant_value(
left, only_process_constants=True, raise_not_constant=False
)
== 0
and right == cond_var.owner.inputs[0] and right == cond_var.owner.inputs[0]
): ):
assert node.outputs[0].type.is_super(right.type) assert node.outputs[0].type.is_super(right.type)
......
...@@ -28,7 +28,6 @@ from pytensor.tensor.basic import ( ...@@ -28,7 +28,6 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
cast, cast,
constant, constant,
extract_constant,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
moveaxis, moveaxis,
ones_like, ones_like,
...@@ -566,11 +565,14 @@ def local_expm1(fgraph, node): ...@@ -566,11 +565,14 @@ def local_expm1(fgraph, node):
in1.owner in1.owner
and isinstance(in1.owner.op, Elemwise) and isinstance(in1.owner.op, Elemwise)
and isinstance(in1.owner.op.scalar_op, ps.Exp) and isinstance(in1.owner.op.scalar_op, ps.Exp)
and extract_constant(in2, only_process_constants=False) == 1 and get_underlying_scalar_constant_value(in2, raise_not_constant=False) == 1
): ):
in11 = in1.owner.inputs[0] in11 = in1.owner.inputs[0]
new_out = expm1(in11) new_out = expm1(in11)
if new_out.type.broadcastable != out.type.broadcastable:
new_out = broadcast_arrays(in11, in2)[0]
if new_out.dtype != out.dtype: if new_out.dtype != out.dtype:
new_out = cast(new_out, dtype=out.dtype) new_out = cast(new_out, dtype=out.dtype)
...@@ -1345,12 +1347,13 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1345,12 +1347,13 @@ def local_useless_elemwise_comparison(fgraph, node):
the graph easier to read. the graph easier to read.
""" """
# TODO: Refactor this function. So much repeated code!
if node.op.scalar_op.nin != 2: if node.op.scalar_op.nin != 2:
return return
# We call zeros_like and one_like with opt=True to generate a dtype = node.outputs[0].type.dtype
# cleaner graph. out_bcast = node.outputs[0].type.broadcastable
dtype = node.outputs[0].dtype
# Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X) # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
if ( if (
...@@ -1361,6 +1364,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1361,6 +1364,7 @@ def local_useless_elemwise_comparison(fgraph, node):
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
# Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X) # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
if ( if (
isinstance(node.op.scalar_op, ps.LE | ps.GE) isinstance(node.op.scalar_op, ps.LE | ps.GE)
...@@ -1371,6 +1375,7 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1371,6 +1375,7 @@ def local_useless_elemwise_comparison(fgraph, node):
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
# Elemwise[{minimum,maximum}](X, X) -> X # Elemwise[{minimum,maximum}](X, X) -> X
if ( if (
isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum) isinstance(node.op.scalar_op, ps.ScalarMinimum | ps.ScalarMaximum)
...@@ -1386,64 +1391,72 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1386,64 +1391,72 @@ def local_useless_elemwise_comparison(fgraph, node):
isinstance(node.op.scalar_op, ps.LT) isinstance(node.op.scalar_op, ps.LT)
and node.inputs[0].owner and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i) and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0 and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
): ):
res = zeros_like(node.inputs[0], dtype=dtype, opt=True) res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
# Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X) # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
if ( if (
isinstance(node.op.scalar_op, ps.GE) isinstance(node.op.scalar_op, ps.GE)
and node.inputs[0].owner and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i) and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0 and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
): ):
res = ones_like(node.inputs[0], dtype=dtype, opt=True) res = ones_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
# Elemwise[maximum](X.shape[i], 0) -> X.shape[i] # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
if ( if isinstance(node.op.scalar_op, ps.ScalarMaximum):
isinstance(node.op.scalar_op, ps.ScalarMaximum) for idx in range(2):
and node.inputs[0].owner if (
and isinstance(node.inputs[0].owner.op, Shape_i) node.inputs[idx].owner
and extract_constant(node.inputs[1], only_process_constants=True) == 0 and isinstance(node.inputs[idx].owner.op, Shape_i)
): and get_underlying_scalar_constant_value(
# No need to copy over stacktrace. node.inputs[1 - idx],
return [node.inputs[0]] only_process_constants=True,
# Elemwise[maximum](0, X.shape[i]) -> X.shape[i] raise_not_constant=False,
if ( )
isinstance(node.op.scalar_op, ps.ScalarMaximum) == 0
and extract_constant(node.inputs[0], only_process_constants=True) == 0 ):
and node.inputs[1].owner res = node.inputs[idx]
and isinstance(node.inputs[1].owner.op, Shape_i) if res.type.broadcastable != out_bcast:
): res = broadcast_arrays(res, node.inputs[1 - idx])[0]
# No need to copy over stacktrace. # No need to copy over stacktrace.
return [node.inputs[1]] return [res]
# Elemwise[minimum](X.shape[i], 0) -> 0
if (
isinstance(node.op.scalar_op, ps.ScalarMinimum)
and node.inputs[0].owner
and isinstance(node.inputs[0].owner.op, Shape_i)
and extract_constant(node.inputs[1], only_process_constants=True) == 0
):
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
# Elemwise[minimum](0, X.shape[i]) -> 0 # Elemwise[minimum](X.shape[i], 0) -> 0
if ( if isinstance(node.op.scalar_op, ps.ScalarMinimum):
isinstance(node.op.scalar_op, ps.ScalarMinimum) for idx in range(2):
and extract_constant(node.inputs[0], only_process_constants=True) == 0 if (
and node.inputs[1].owner node.inputs[idx].owner
and isinstance(node.inputs[1].owner.op, Shape_i) and isinstance(node.inputs[idx].owner.op, Shape_i)
): and get_underlying_scalar_constant_value(
res = zeros_like(node.inputs[1], dtype=dtype, opt=True) node.inputs[1 - idx],
# Copy over stacktrace from previous output. only_process_constants=True,
copy_stack_trace(node.outputs, res) raise_not_constant=False,
return [res] )
== 0
):
res = zeros_like(node.inputs[idx], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1 - idx])[0]
# No need to copy over stacktrace.
return [res]
# Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X) # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
if ( if (
...@@ -1455,12 +1468,18 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1455,12 +1468,18 @@ def local_useless_elemwise_comparison(fgraph, node):
isinstance(var.owner and var.owner.op, Shape_i) isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs for var in node.inputs[0].owner.inputs
) )
and extract_constant(node.inputs[1], only_process_constants=True) == 0 and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
): ):
res = zeros_like(node.inputs[0], dtype=dtype, opt=True) res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
# Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X) # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
if ( if (
isinstance(node.op.scalar_op, ps.GE) isinstance(node.op.scalar_op, ps.GE)
...@@ -1471,57 +1490,61 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1471,57 +1490,61 @@ def local_useless_elemwise_comparison(fgraph, node):
isinstance(var.owner and var.owner.op, Shape_i) isinstance(var.owner and var.owner.op, Shape_i)
for var in node.inputs[0].owner.inputs for var in node.inputs[0].owner.inputs
) )
and extract_constant(node.inputs[1], only_process_constants=True) == 0 and get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True, raise_not_constant=False
)
== 0
): ):
res = ones_like(node.inputs[0], dtype=dtype, opt=True) res = ones_like(node.inputs[0], dtype=dtype, opt=True)
if res.type.broadcastable != out_bcast:
res = broadcast_arrays(res, node.inputs[1])[0]
# Copy over stacktrace from previous output. # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res) copy_stack_trace(node.outputs, res)
return [res] return [res]
# Elemwise[EQ](Subtensor(Shape(x)), -N) # Elemwise[EQ](Subtensor(Shape(x)), -N)
# Elemwise[EQ](somegraph that only depend of shape, -N) # Elemwise[EQ](somegraph that only depend of shape, -N)
# TODO: handle the case where the -N is on either side # TODO: handle the case where the -N is on either side
""" """
|Elemwise{eq,no_inplace} [id B] '' |Elemwise{eq,no_inplace} [id B] ''
| |Subtensor{int64} [id C] '' | |Subtensor{int64} [id C] ''
| | |Join [id D] '' | | |Join [id D] ''
| | | |TensorConstant{0} [id E] | | | |TensorConstant{0} [id E]
| | | |Subtensor{int64:int64:} [id F] '' | | | |Subtensor{int64:int64:} [id F] ''
| | | | |Shape [id G] '' | | | | |Shape [id G] ''
""" """
def investigate(node): def investigate_if_shape(node) -> bool:
"Return True if values will be shapes, so >= 0" "Return True if values will be shapes, so >= 0"
if isinstance(node.op, Shape | Shape_i): if isinstance(node.op, Shape | Shape_i):
return True return True
elif isinstance(node.op, Subtensor) and node.inputs[0].owner: elif isinstance(node.op, Subtensor) and node.inputs[0].owner:
return investigate(node.inputs[0].owner) return investigate_if_shape(node.inputs[0].owner)
elif isinstance(node.op, Join): elif isinstance(node.op, Join):
return all(v.owner and investigate(v.owner) for v in node.inputs[1:]) return all(
v.owner and investigate_if_shape(v.owner) for v in node.inputs[1:]
)
elif isinstance(node.op, MakeVector): elif isinstance(node.op, MakeVector):
return all(v.owner and investigate(v.owner) for v in node.inputs) return all(v.owner and investigate_if_shape(v.owner) for v in node.inputs)
return False
if ( if (
isinstance(node.op.scalar_op, ps.EQ) isinstance(node.op.scalar_op, ps.EQ)
and node.inputs[0].owner and node.inputs[0].owner
and investigate(node.inputs[0].owner) and investigate_if_shape(node.inputs[0].owner)
and (
isinstance(node.inputs[1], TensorConstant)
and node.inputs[1].unique_value is not None
and node.inputs[1].unique_value < 0
)
): ):
try: res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
cst = get_underlying_scalar_constant_value( if res.type.broadcastable != out_bcast:
node.inputs[1], only_process_constants=True res = broadcast_arrays(res, node.inputs[1])[0]
) # Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
res = zeros_like(node.inputs[0], dtype=dtype, opt=True) return [res]
if cst < 0:
# Copy over stacktrace from previous output.
copy_stack_trace(node.outputs, res)
return [res]
except NotScalarConstantError:
pass
return return
...@@ -2223,12 +2246,21 @@ def local_log1p(fgraph, node): ...@@ -2223,12 +2246,21 @@ def local_log1p(fgraph, node):
return [alloc_like(log1p(ninp), node.outputs[0], fgraph)] return [alloc_like(log1p(ninp), node.outputs[0], fgraph)]
elif log_arg.owner and log_arg.owner.op == sub: elif log_arg.owner and log_arg.owner.op == sub:
one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) one, other = log_arg.owner.inputs
try:
one = get_underlying_scalar_constant_value(one, only_process_constants=True)
except NotScalarConstantError:
return
if one != 1: if one != 1:
return return
other = log_arg.owner.inputs[1]
if other.dtype != log_arg.dtype: if other.type.broadcastable != log_arg.type.broadcastable:
other = broadcast_arrays(other, one)[0]
if other.type.dtype != log_arg.type.dtype:
other = other.astype(log_arg.dtype) other = other.astype(log_arg.dtype)
return [log1p(neg(other))] return [log1p(neg(other))]
......
...@@ -22,7 +22,7 @@ from pytensor.tensor.basic import ( ...@@ -22,7 +22,7 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
cast, cast,
constant, constant,
extract_constant, get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
stack, stack,
...@@ -354,7 +354,9 @@ class ShapeFeature(Feature): ...@@ -354,7 +354,9 @@ class ShapeFeature(Feature):
not hasattr(r.type, "shape") not hasattr(r.type, "shape")
or r.type.shape[i] != 1 or r.type.shape[i] != 1
or self.lscalar_one.equals(shape_vars[i]) or self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i])) or self.lscalar_one.equals(
get_scalar_constant_value(shape_vars[i], raise_not_constant=False)
)
for i in range(r.type.ndim) for i in range(r.type.ndim)
) )
self.shape_of[r] = tuple(shape_vars) self.shape_of[r] = tuple(shape_vars)
...@@ -450,7 +452,11 @@ class ShapeFeature(Feature): ...@@ -450,7 +452,11 @@ class ShapeFeature(Feature):
) )
or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals( or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True) get_underlying_scalar_constant_value(
merged_shape[i],
only_process_constants=True,
raise_not_constant=False,
)
) )
for i in range(r.type.ndim) for i in range(r.type.ndim)
) )
...@@ -474,7 +480,11 @@ class ShapeFeature(Feature): ...@@ -474,7 +480,11 @@ class ShapeFeature(Feature):
not hasattr(r.type, "shape") not hasattr(r.type, "shape")
or r.type.shape[idx] != 1 or r.type.shape[idx] != 1
or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx])) or self.lscalar_one.equals(
get_underlying_scalar_constant_value(
new_shape[idx], raise_not_constant=False
)
)
for idx in range(r.type.ndim) for idx in range(r.type.ndim)
) )
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
...@@ -847,7 +857,10 @@ def local_useless_reshape(fgraph, node): ...@@ -847,7 +857,10 @@ def local_useless_reshape(fgraph, node):
outshp_i.owner outshp_i.owner
and isinstance(outshp_i.owner.op, Subtensor) and isinstance(outshp_i.owner.op, Subtensor)
and len(outshp_i.owner.inputs) == 2 and len(outshp_i.owner.inputs) == 2
and extract_constant(outshp_i.owner.inputs[1]) == dim and get_scalar_constant_value(
outshp_i.owner.inputs[1], raise_not_constant=False
)
== dim
): ):
subtensor_inp = outshp_i.owner.inputs[0] subtensor_inp = outshp_i.owner.inputs[0]
if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
...@@ -857,7 +870,9 @@ def local_useless_reshape(fgraph, node): ...@@ -857,7 +870,9 @@ def local_useless_reshape(fgraph, node):
continue continue
# Match constant if input.type.shape[dim] == constant # Match constant if input.type.shape[dim] == constant
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) cst_outshp_i = get_scalar_constant_value(
outshp_i, only_process_constants=True, raise_not_constant=False
)
if inp.type.shape[dim] == cst_outshp_i: if inp.type.shape[dim] == cst_outshp_i:
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -872,8 +887,12 @@ def local_useless_reshape(fgraph, node): ...@@ -872,8 +887,12 @@ def local_useless_reshape(fgraph, node):
if shape_feature: if shape_feature:
inpshp_i = shape_feature.get_shape(inp, dim) inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or ( if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=True) get_scalar_constant_value(
== extract_constant(outshp_i, only_process_constants=True) inpshp_i, only_process_constants=True, raise_not_constant=False
)
== get_scalar_constant_value(
outshp_i, only_process_constants=True, raise_not_constant=False
)
): ):
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -909,11 +928,14 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -909,11 +928,14 @@ def local_reshape_to_dimshuffle(fgraph, node):
new_output_shape = [] new_output_shape = []
index = 0 # index over the output of the new reshape index = 0 # index over the output of the new reshape
for i in range(output.ndim): for i in range(output.ndim):
# Since output_shape is a symbolic vector, we trust extract_constant # Since output_shape is a symbolic vector, we trust get_scalar_constant_value
# to go through however it is formed to see if its i-th element is 1. # to go through however it is formed to see if its i-th element is 1.
# We need only_process_constants=False for that. # We need only_process_constants=False for that.
dim = extract_constant( dim = get_scalar_constant_value(
output_shape[i], only_process_constants=False, elemwise=False output_shape[i],
only_process_constants=False,
elemwise=False,
raise_not_constant=False,
) )
if dim == 1: if dim == 1:
dimshuffle_new_order.append("x") dimshuffle_new_order.append("x")
......
...@@ -26,7 +26,7 @@ from pytensor.tensor.basic import ( ...@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
as_tensor, as_tensor,
cast, cast,
concatenate, concatenate,
extract_constant, get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
switch, switch,
...@@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node): ...@@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node):
start = s.start start = s.start
stop = s.stop stop = s.stop
if start is not None and extract_constant( if start is not None and get_scalar_constant_value(
start, only_process_constants=True start, only_process_constants=True, raise_not_constant=False
) == (0 if positive_step else -1): ) == (0 if positive_step else -1):
change_flag = True change_flag = True
start = None start = None
...@@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node): ...@@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node):
if ( if (
stop is not None stop is not None
and x.type.shape[dim] is not None and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) and get_scalar_constant_value(
stop, only_process_constants=True, raise_not_constant=False
)
== (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1) == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
): ):
change_flag = True change_flag = True
...@@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node): ...@@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node):
and e.stop is None and e.stop is None
and ( and (
e.step is None e.step is None
or extract_constant(e.step, only_process_constants=True) == -1 or get_scalar_constant_value(
e.step, only_process_constants=True, raise_not_constant=False
)
== -1
) )
for e in idx_cst for e in idx_cst
): ):
...@@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node): ...@@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
and and
# Don't use only_process_constants=True. We need to # Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape. # investigate Alloc of 0s but with non constant shape.
extract_constant(x, elemwise=False) != 0 get_underlying_scalar_constant_value(
x, elemwise=False, raise_not_constant=False
)
!= 0
): ):
return return
......
...@@ -1383,11 +1383,11 @@ class TestLocalUselessElemwiseComparison: ...@@ -1383,11 +1383,11 @@ class TestLocalUselessElemwiseComparison:
if op == deep_copy_op: if op == deep_copy_op:
assert len(elem.inputs) == 1, elem.inputs assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], TensorConstant), elem assert isinstance(elem.inputs[0], TensorConstant), elem
assert pt.extract_constant(elem.inputs[0]) == val, val assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
else: else:
assert len(elem.inputs) == 2, elem.inputs assert len(elem.inputs) == 2, elem.inputs
assert isinstance(elem.inputs[0], TensorConstant), elem assert isinstance(elem.inputs[0], TensorConstant), elem
assert pt.extract_constant(elem.inputs[0]) == val, val assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
def assert_identity(self, f): def assert_identity(self, f):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
......
...@@ -46,7 +46,6 @@ from pytensor.tensor.basic import ( ...@@ -46,7 +46,6 @@ from pytensor.tensor.basic import (
default, default,
diag, diag,
expand_dims, expand_dims,
extract_constant,
eye, eye,
fill, fill,
flatnonzero, flatnonzero,
...@@ -3574,10 +3573,10 @@ class TestGetUnderlyingScalarConstantValue: ...@@ -3574,10 +3573,10 @@ class TestGetUnderlyingScalarConstantValue:
# Make sure we do not return a writeable internal storage of a constant, # Make sure we do not return a writeable internal storage of a constant,
# so we cannot change the value of a constant by mistake. # so we cannot change the value of a constant by mistake.
c = constant(3) c = constant(3)
d = extract_constant(c) d = get_scalar_constant_value(c)
with pytest.raises(ValueError, match="output array is read-only"): with pytest.raises(ValueError, match="output array is read-only"):
d += 1 d += 1
e = extract_constant(c) e = get_scalar_constant_value(c)
assert e == 3, (c, d, e) assert e == 3, (c, d, e)
@pytest.mark.parametrize("only_process_constants", (True, False)) @pytest.mark.parametrize("only_process_constants", (True, False))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论