提交 32aadc8c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate `extract_constant`

上级 aad6fb75
...@@ -54,6 +54,7 @@ from pytensor.scan.utils import ( ...@@ -54,6 +54,7 @@ from pytensor.scan.utils import (
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
...@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used( ...@@ -665,8 +666,10 @@ def inner_sitsot_only_last_step_used(
client = fgraph.clients[outer_var][0][0] client = fgraph.clients[outer_var][0][0]
if isinstance(client, Apply) and isinstance(client.op, Subtensor): if isinstance(client, Apply) and isinstance(client.op, Subtensor):
lst = get_idx_list(client.inputs, client.op.idx_list) lst = get_idx_list(client.inputs, client.op.idx_list)
if len(lst) == 1 and pt.extract_constant(lst[0]) == -1: return (
return True len(lst) == 1
and get_scalar_constant_value(lst[0], raise_not_constant=False) == -1
)
return False return False
...@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node): ...@@ -1341,10 +1344,17 @@ def scan_save_mem(fgraph, node):
if isinstance(this_slice[0], slice) and this_slice[0].stop is None: if isinstance(this_slice[0], slice) and this_slice[0].stop is None:
global_nsteps = None global_nsteps = None
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
stop = pt.extract_constant(cf_slice[0].stop) stop = get_scalar_constant_value(
cf_slice[0].stop, raise_not_constant=False
)
else: else:
stop = pt.extract_constant(cf_slice[0]) + 1 stop = (
if stop == maxsize or stop == pt.extract_constant(length): get_scalar_constant_value(cf_slice[0], raise_not_constant=False)
+ 1
)
if stop == maxsize or stop == get_scalar_constant_value(
length, raise_not_constant=False
):
stop = None stop = None
else: else:
# there is a **gotcha** here ! Namely, scan returns an # there is a **gotcha** here ! Namely, scan returns an
...@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node): ...@@ -1448,9 +1458,13 @@ def scan_save_mem(fgraph, node):
cf_slice = get_canonical_form_slice(this_slice[0], length) cf_slice = get_canonical_form_slice(this_slice[0], length)
if isinstance(cf_slice[0], slice): if isinstance(cf_slice[0], slice):
start = pt.extract_constant(cf_slice[0].start) start = pt.get_scalar_constant_value(
cf_slice[0].start, raise_not_constant=False
)
else: else:
start = pt.extract_constant(cf_slice[0]) start = pt.get_scalar_constant_value(
cf_slice[0], raise_not_constant=False
)
if start == 0 or store_steps[i] == 0: if start == 0 or store_steps[i] == 0:
store_steps[i] = 0 store_steps[i] = 0
...@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node): ...@@ -1625,7 +1639,7 @@ def scan_save_mem(fgraph, node):
# 3.6 Compose the new scan # 3.6 Compose the new scan
# TODO: currently we don't support scan with 0 step. So # TODO: currently we don't support scan with 0 step. So
# don't create one. # don't create one.
if pt.extract_constant(node_ins[0]) == 0: if get_scalar_constant_value(node_ins[0], raise_not_constant=False) == 0:
return False return False
# Do not call make_node for test_value # Do not call make_node for test_value
......
...@@ -268,27 +268,7 @@ _scalar_constant_value_elemwise_ops = ( ...@@ -268,27 +268,7 @@ _scalar_constant_value_elemwise_ops = (
) )
def get_scalar_constant_value( def _get_underlying_scalar_constant_value(
v, elemwise=True, only_process_constants=False, max_recur=10
):
"""
Checks whether 'v' is a scalar (ndim = 0).
If 'v' is a scalar then this function fetches the underlying constant by calling
'get_underlying_scalar_constant_value()'.
If 'v' is not a scalar, it raises a NotScalarConstantError.
"""
if isinstance(v, Variable | np.ndarray):
if v.ndim != 0:
raise NotScalarConstantError()
return get_underlying_scalar_constant_value(
v, elemwise, only_process_constants, max_recur
)
def get_underlying_scalar_constant_value(
orig_v, elemwise=True, only_process_constants=False, max_recur=10 orig_v, elemwise=True, only_process_constants=False, max_recur=10
): ):
"""Return the constant scalar(0-D) value underlying variable `v`. """Return the constant scalar(0-D) value underlying variable `v`.
...@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value( ...@@ -381,7 +361,7 @@ def get_underlying_scalar_constant_value(
elif isinstance(op, CheckAndRaise): elif isinstance(op, CheckAndRaise):
# check if all conditions are constant and true # check if all conditions are constant and true
conds = [ conds = [
get_underlying_scalar_constant_value(c, max_recur=max_recur) _get_underlying_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:] for c in v.owner.inputs[1:]
] ]
if builtins.all(0 == c.ndim and c != 0 for c in conds): if builtins.all(0 == c.ndim and c != 0 for c in conds):
...@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value( ...@@ -395,7 +375,7 @@ def get_underlying_scalar_constant_value(
continue continue
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [ const = [
get_underlying_scalar_constant_value(i, max_recur=max_recur) _get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
] ]
ret = [[None]] ret = [[None]]
...@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value( ...@@ -414,7 +394,7 @@ def get_underlying_scalar_constant_value(
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
): ):
const = [ const = [
get_underlying_scalar_constant_value(i, max_recur=max_recur) _get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
] ]
ret = [[None]] ret = [[None]]
...@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value( ...@@ -457,7 +437,7 @@ def get_underlying_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
try: try:
...@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value( ...@@ -491,14 +471,13 @@ def get_underlying_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx] ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur) ret = _get_underlying_scalar_constant_value(
ret, max_recur=max_recur
)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
return np.asarray(ret, dtype=v.type.dtype) return np.asarray(ret, dtype=v.type.dtype)
...@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value( ...@@ -513,7 +492,7 @@ def get_underlying_scalar_constant_value(
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_underlying_scalar_constant_value( idx = _get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
...@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value( ...@@ -523,7 +502,9 @@ def get_underlying_scalar_constant_value(
grandparent.owner.op, Unbroadcast grandparent.owner.op, Unbroadcast
): ):
ggp_shape = grandparent.owner.inputs[0].type.shape ggp_shape = grandparent.owner.inputs[0].type.shape
l = [get_underlying_scalar_constant_value(s) for s in ggp_shape] l = [
_get_underlying_scalar_constant_value(s) for s in ggp_shape
]
gp_shape = tuple(l) gp_shape = tuple(l)
if not (idx < ndim): if not (idx < ndim):
...@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value( ...@@ -545,7 +526,7 @@ def get_underlying_scalar_constant_value(
if isinstance(grandparent, Constant): if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx]) return np.asarray(np.shape(grandparent.data)[idx])
elif isinstance(op, CSM): elif isinstance(op, CSM):
data = get_underlying_scalar_constant_value( data = _get_underlying_scalar_constant_value(
v.owner.inputs, elemwise=elemwise, max_recur=max_recur v.owner.inputs, elemwise=elemwise, max_recur=max_recur
) )
# Sparse variable can only be constant if zero (or I guess if homogeneously dense) # Sparse variable can only be constant if zero (or I guess if homogeneously dense)
...@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value( ...@@ -556,6 +537,93 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
def get_underlying_scalar_constant_value(
    v,
    *,
    elemwise=True,
    only_process_constants=False,
    max_recur=10,
    raise_not_constant=True,
):
    """Return the unique constant scalar (0-D) value underlying variable `v`.

    If `v` is the output of dimshuffles, fills, allocs, casts, OutputGuard,
    DeepCopyOp, ScalarFromTensor, ScalarOp, Elemwise, or some Subtensor
    patterns, this function digs through them to find the underlying constant.

    This function performs symbolic reasoning about the value of `v`, as
    opposed to numerical reasoning by constant-folding the inputs of `v`.

    Parameters
    ----------
    v : Variable
        The variable to inspect.
    elemwise : bool
        If False, we won't try to go into elemwise ops, so this call is
        faster. We still investigate in ``Second`` Elemwise (as this is a
        substitute for ``Alloc``).
    only_process_constants : bool
        If True, we only attempt to obtain the value of `v` if it's
        directly constant, and don't try to dig through dimshuffles, fills,
        allocs, and other ops to figure out its value.
    max_recur : int
        The maximum number of recursions.
    raise_not_constant : bool, default True
        If True, raise a NotScalarConstantError if `v` does not have an
        underlying constant scalar value. If False, return `v` as is.

    Raises
    ------
    NotScalarConstantError
        `v` does not have an underlying constant scalar value.
        Only raised if `raise_not_constant` is True.
    """
    try:
        # Delegate the actual graph traversal to the private helper; this
        # wrapper only adds the `raise_not_constant` behavior.
        return _get_underlying_scalar_constant_value(
            v,
            elemwise=elemwise,
            only_process_constants=only_process_constants,
            max_recur=max_recur,
        )
    except NotScalarConstantError:
        if raise_not_constant:
            raise
        # Caller opted out of the exception: hand back the variable unchanged.
        return v
def get_scalar_constant_value(
    v,
    elemwise=True,
    only_process_constants=False,
    max_recur=10,
    raise_not_constant: bool = True,
):
    """Return the constant value of the scalar (0-D) variable `v`.

    Checks whether `v` is a scalar (``ndim == 0``); if so, fetches the
    underlying constant by calling `get_underlying_scalar_constant_value`.
    If `v` is not a scalar, raises a NotScalarConstantError.

    Parameters
    ----------
    v : Variable or np.ndarray
        The variable (or array) to inspect. Must be 0-dimensional.
    elemwise : bool
        If False, don't dig into elemwise ops (faster).
    only_process_constants : bool
        If True, only return a value when `v` is directly constant.
    max_recur : int
        The maximum number of recursions when digging through the graph.
    raise_not_constant : bool, default True
        If True, raise NotScalarConstantError when no underlying constant
        is found; if False, return `v` as is in that case.

    Raises
    ------
    NotScalarConstantError
        If `v` has ``ndim != 0``, or (when `raise_not_constant` is True)
        `v` has no underlying constant scalar value.
    """
    if isinstance(v, TensorVariable | np.ndarray):
        if v.ndim != 0:
            # Fix: removed a leftover debug `print(v, v.ndim)` that polluted
            # stdout every time a non-scalar reached this check.
            raise NotScalarConstantError("Input ndim != 0")
    return get_underlying_scalar_constant_value(
        v,
        elemwise=elemwise,
        only_process_constants=only_process_constants,
        max_recur=max_recur,
        raise_not_constant=raise_not_constant,
    )
class TensorFromScalar(COp): class TensorFromScalar(COp):
__props__ = () __props__ = ()
...@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False): ...@@ -2012,16 +2080,16 @@ def extract_constant(x, elemwise=True, only_process_constants=False):
ScalarVariable, we convert it to a tensor with tensor_from_scalar. ScalarVariable, we convert it to a tensor with tensor_from_scalar.
""" """
try: warnings.warn(
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants) "extract_constant is deprecated. Use `get_underlying_scalar_constant_value(..., raise_not_constant=False)`",
except NotScalarConstantError: FutureWarning,
pass )
if isinstance(x, ps.ScalarVariable | ps.sharedvar.ScalarSharedVariable): return get_underlying_scalar_constant_value(
if x.owner and isinstance(x.owner.op, ScalarFromTensor): x,
x = x.owner.inputs[0] elemwise=elemwise,
else: only_process_constants=only_process_constants,
x = tensor_from_scalar(x) raise_not_constant=False,
return x )
def transpose(x, axes=None): def transpose(x, axes=None):
...@@ -4401,7 +4469,6 @@ __all__ = [ ...@@ -4401,7 +4469,6 @@ __all__ = [
"split", "split",
"transpose", "transpose",
"matrix_transpose", "matrix_transpose",
"extract_constant",
"default", "default",
"tensor_copy", "tensor_copy",
"transfer", "transfer",
......
...@@ -22,7 +22,7 @@ from pytensor.tensor.basic import ( ...@@ -22,7 +22,7 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
cast, cast,
constant, constant,
extract_constant, get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
stack, stack,
...@@ -354,7 +354,9 @@ class ShapeFeature(Feature): ...@@ -354,7 +354,9 @@ class ShapeFeature(Feature):
not hasattr(r.type, "shape") not hasattr(r.type, "shape")
or r.type.shape[i] != 1 or r.type.shape[i] != 1
or self.lscalar_one.equals(shape_vars[i]) or self.lscalar_one.equals(shape_vars[i])
or self.lscalar_one.equals(extract_constant(shape_vars[i])) or self.lscalar_one.equals(
get_scalar_constant_value(shape_vars[i], raise_not_constant=False)
)
for i in range(r.type.ndim) for i in range(r.type.ndim)
) )
self.shape_of[r] = tuple(shape_vars) self.shape_of[r] = tuple(shape_vars)
...@@ -450,7 +452,11 @@ class ShapeFeature(Feature): ...@@ -450,7 +452,11 @@ class ShapeFeature(Feature):
) )
or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals( or self.lscalar_one.equals(
extract_constant(merged_shape[i], only_process_constants=True) get_underlying_scalar_constant_value(
merged_shape[i],
only_process_constants=True,
raise_not_constant=False,
)
) )
for i in range(r.type.ndim) for i in range(r.type.ndim)
) )
...@@ -474,7 +480,11 @@ class ShapeFeature(Feature): ...@@ -474,7 +480,11 @@ class ShapeFeature(Feature):
not hasattr(r.type, "shape") not hasattr(r.type, "shape")
or r.type.shape[idx] != 1 or r.type.shape[idx] != 1
or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(extract_constant(new_shape[idx])) or self.lscalar_one.equals(
get_underlying_scalar_constant_value(
new_shape[idx], raise_not_constant=False
)
)
for idx in range(r.type.ndim) for idx in range(r.type.ndim)
) )
self.shape_of[r] = tuple(new_shape) self.shape_of[r] = tuple(new_shape)
...@@ -847,7 +857,10 @@ def local_useless_reshape(fgraph, node): ...@@ -847,7 +857,10 @@ def local_useless_reshape(fgraph, node):
outshp_i.owner outshp_i.owner
and isinstance(outshp_i.owner.op, Subtensor) and isinstance(outshp_i.owner.op, Subtensor)
and len(outshp_i.owner.inputs) == 2 and len(outshp_i.owner.inputs) == 2
and extract_constant(outshp_i.owner.inputs[1]) == dim and get_scalar_constant_value(
outshp_i.owner.inputs[1], raise_not_constant=False
)
== dim
): ):
subtensor_inp = outshp_i.owner.inputs[0] subtensor_inp = outshp_i.owner.inputs[0]
if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
...@@ -857,7 +870,9 @@ def local_useless_reshape(fgraph, node): ...@@ -857,7 +870,9 @@ def local_useless_reshape(fgraph, node):
continue continue
# Match constant if input.type.shape[dim] == constant # Match constant if input.type.shape[dim] == constant
cst_outshp_i = extract_constant(outshp_i, only_process_constants=1) cst_outshp_i = get_scalar_constant_value(
outshp_i, only_process_constants=True, raise_not_constant=False
)
if inp.type.shape[dim] == cst_outshp_i: if inp.type.shape[dim] == cst_outshp_i:
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -872,8 +887,12 @@ def local_useless_reshape(fgraph, node): ...@@ -872,8 +887,12 @@ def local_useless_reshape(fgraph, node):
if shape_feature: if shape_feature:
inpshp_i = shape_feature.get_shape(inp, dim) inpshp_i = shape_feature.get_shape(inp, dim)
if inpshp_i == outshp_i or ( if inpshp_i == outshp_i or (
extract_constant(inpshp_i, only_process_constants=True) get_scalar_constant_value(
== extract_constant(outshp_i, only_process_constants=True) inpshp_i, only_process_constants=True, raise_not_constant=False
)
== get_scalar_constant_value(
outshp_i, only_process_constants=True, raise_not_constant=False
)
): ):
shape_match[dim] = True shape_match[dim] = True
continue continue
...@@ -909,11 +928,14 @@ def local_reshape_to_dimshuffle(fgraph, node): ...@@ -909,11 +928,14 @@ def local_reshape_to_dimshuffle(fgraph, node):
new_output_shape = [] new_output_shape = []
index = 0 # index over the output of the new reshape index = 0 # index over the output of the new reshape
for i in range(output.ndim): for i in range(output.ndim):
# Since output_shape is a symbolic vector, we trust extract_constant # Since output_shape is a symbolic vector, we trust get_scalar_constant_value
# to go through however it is formed to see if its i-th element is 1. # to go through however it is formed to see if its i-th element is 1.
# We need only_process_constants=False for that. # We need only_process_constants=False for that.
dim = extract_constant( dim = get_scalar_constant_value(
output_shape[i], only_process_constants=False, elemwise=False output_shape[i],
only_process_constants=False,
elemwise=False,
raise_not_constant=False,
) )
if dim == 1: if dim == 1:
dimshuffle_new_order.append("x") dimshuffle_new_order.append("x")
......
...@@ -26,7 +26,7 @@ from pytensor.tensor.basic import ( ...@@ -26,7 +26,7 @@ from pytensor.tensor.basic import (
as_tensor, as_tensor,
cast, cast,
concatenate, concatenate,
extract_constant, get_scalar_constant_value,
get_underlying_scalar_constant_value, get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
switch, switch,
...@@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node): ...@@ -390,8 +390,8 @@ def local_useless_slice(fgraph, node):
start = s.start start = s.start
stop = s.stop stop = s.stop
if start is not None and extract_constant( if start is not None and get_scalar_constant_value(
start, only_process_constants=True start, only_process_constants=True, raise_not_constant=False
) == (0 if positive_step else -1): ) == (0 if positive_step else -1):
change_flag = True change_flag = True
start = None start = None
...@@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node): ...@@ -399,7 +399,9 @@ def local_useless_slice(fgraph, node):
if ( if (
stop is not None stop is not None
and x.type.shape[dim] is not None and x.type.shape[dim] is not None
and extract_constant(stop, only_process_constants=True) and get_scalar_constant_value(
stop, only_process_constants=True, raise_not_constant=False
)
== (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1) == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
): ):
change_flag = True change_flag = True
...@@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node): ...@@ -889,7 +891,10 @@ def local_useless_inc_subtensor(fgraph, node):
and e.stop is None and e.stop is None
and ( and (
e.step is None e.step is None
or extract_constant(e.step, only_process_constants=True) == -1 or get_scalar_constant_value(
e.step, only_process_constants=True, raise_not_constant=False
)
== -1
) )
for e in idx_cst for e in idx_cst
): ):
...@@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node): ...@@ -1490,7 +1495,10 @@ def local_adv_sub1_adv_inc_sub1(fgraph, node):
and and
# Don't use only_process_constants=True. We need to # Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape. # investigate Alloc of 0s but with non constant shape.
extract_constant(x, elemwise=False) != 0 get_underlying_scalar_constant_value(
x, elemwise=False, raise_not_constant=False
)
!= 0
): ):
return return
......
...@@ -1383,11 +1383,11 @@ class TestLocalUselessElemwiseComparison: ...@@ -1383,11 +1383,11 @@ class TestLocalUselessElemwiseComparison:
if op == deep_copy_op: if op == deep_copy_op:
assert len(elem.inputs) == 1, elem.inputs assert len(elem.inputs) == 1, elem.inputs
assert isinstance(elem.inputs[0], TensorConstant), elem assert isinstance(elem.inputs[0], TensorConstant), elem
assert pt.extract_constant(elem.inputs[0]) == val, val assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
else: else:
assert len(elem.inputs) == 2, elem.inputs assert len(elem.inputs) == 2, elem.inputs
assert isinstance(elem.inputs[0], TensorConstant), elem assert isinstance(elem.inputs[0], TensorConstant), elem
assert pt.extract_constant(elem.inputs[0]) == val, val assert pt.get_underlying_scalar_constant_value(elem.inputs[0]) == val, val
def assert_identity(self, f): def assert_identity(self, f):
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
......
...@@ -46,7 +46,6 @@ from pytensor.tensor.basic import ( ...@@ -46,7 +46,6 @@ from pytensor.tensor.basic import (
default, default,
diag, diag,
expand_dims, expand_dims,
extract_constant,
eye, eye,
fill, fill,
flatnonzero, flatnonzero,
...@@ -3574,10 +3573,10 @@ class TestGetUnderlyingScalarConstantValue: ...@@ -3574,10 +3573,10 @@ class TestGetUnderlyingScalarConstantValue:
# Make sure we do not return a writeable internal storage of a constant, # Make sure we do not return a writeable internal storage of a constant,
# so we cannot change the value of a constant by mistake. # so we cannot change the value of a constant by mistake.
c = constant(3) c = constant(3)
d = extract_constant(c) d = get_scalar_constant_value(c)
with pytest.raises(ValueError, match="output array is read-only"): with pytest.raises(ValueError, match="output array is read-only"):
d += 1 d += 1
e = extract_constant(c) e = get_scalar_constant_value(c)
assert e == 3, (c, d, e) assert e == 3, (c, d, e)
@pytest.mark.parametrize("only_process_constants", (True, False)) @pytest.mark.parametrize("only_process_constants", (True, False))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论