提交 c4fb0cfa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Raise explicitly on Python methods that are incompatible with lazy variables

Notably changes the behavior of `__bool__` to always raise. Before there was a hack based on whether a variable had been compared to something before.
上级 e00abf32
...@@ -569,7 +569,7 @@ def construct_pfunc_ins_and_outs( ...@@ -569,7 +569,7 @@ def construct_pfunc_ins_and_outs(
if not fgraph: if not fgraph:
# Extend the outputs with the updates on input variables so they are # Extend the outputs with the updates on input variables so they are
# also cloned # also cloned
additional_outputs = [i.update for i in inputs if i.update] additional_outputs = [i.update for i in inputs if i.update is not None]
if outputs is None: if outputs is None:
out_list = [] out_list = []
else: else:
...@@ -608,7 +608,7 @@ def construct_pfunc_ins_and_outs( ...@@ -608,7 +608,7 @@ def construct_pfunc_ins_and_outs(
new_i.variable = iv new_i.variable = iv
# If needed, replace the input's update by its cloned equivalent # If needed, replace the input's update by its cloned equivalent
if i.update: if i.update is not None:
new_i.update = clone_d[i.update] new_i.update = clone_d[i.update]
new_inputs.append(new_i) new_inputs.append(new_i)
......
...@@ -198,7 +198,7 @@ def std_fgraph( ...@@ -198,7 +198,7 @@ def std_fgraph(
update_mapping = {} update_mapping = {}
out_idx = len(output_specs) out_idx = len(output_specs)
for idx, input_spec in enumerate(input_specs): for idx, input_spec in enumerate(input_specs):
if input_spec.update: if input_spec.update is not None:
updates.append(input_spec.update) updates.append(input_spec.update)
update_mapping[out_idx] = idx update_mapping[out_idx] = idx
out_idx += 1 out_idx += 1
...@@ -1195,7 +1195,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1195,7 +1195,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
updated_fgraph_inputs = { updated_fgraph_inputs = {
fgraph_i fgraph_i
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True) for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
if getattr(i, "update", False) if getattr(i, "update", None) is not None
} }
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
...@@ -1351,7 +1351,11 @@ class FunctionMaker: ...@@ -1351,7 +1351,11 @@ class FunctionMaker:
ancestors( ancestors(
( (
[o.variable for o in outputs] [o.variable for o in outputs]
+ [i.update for i in inputs if getattr(i, "update", False)] + [
i.update
for i in inputs
if getattr(i, "update", None) is not None
]
), ),
blockers=[i.variable for i in inputs], blockers=[i.variable for i in inputs],
) )
......
...@@ -36,7 +36,7 @@ def _is_numeric_value(arr, var): ...@@ -36,7 +36,7 @@ def _is_numeric_value(arr, var):
return False return False
elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator): elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
return False return False
elif var and isinstance(var.type, RandomType): elif var is not None and isinstance(var.type, RandomType):
return False return False
elif isinstance(arr, slice): elif isinstance(arr, slice):
return False return False
......
...@@ -823,6 +823,37 @@ discrete_dtypes = tuple(t.dtype for t in discrete_types) ...@@ -823,6 +823,37 @@ discrete_dtypes = tuple(t.dtype for t in discrete_types)
class _scalar_py_operators: class _scalar_py_operators:
# These can't work because Python requires native output types
    def __bool__(self):
        """Always raise: a lazy scalar variable has no concrete truth value."""
        raise TypeError(
            "ScalarVariable cannot be converted to Python boolean. "
            "Call `.astype(bool)` for the symbolic equivalent."
        )
    def __index__(self):
        """Always raise: a lazy scalar cannot be used where Python needs an int (e.g. indexing)."""
        raise TypeError(
            "ScalarVariable cannot be converted to Python integer. "
            "Call `.astype(int)` for the symbolic equivalent."
        )
    def __int__(self):
        """Always raise: a lazy scalar has no concrete integer value."""
        raise TypeError(
            "ScalarVariable cannot be converted to Python integer. "
            "Call `.astype(int)` for the symbolic equivalent."
        )
    def __float__(self):
        """Always raise: a lazy scalar has no concrete float value."""
        raise TypeError(
            "ScalarVariable cannot be converted to Python float. "
            "Call `.astype(float)` for the symbolic equivalent."
        )
    def __complex__(self):
        """Always raise: a lazy scalar has no concrete complex value."""
        raise TypeError(
            "ScalarVariable cannot be converted to Python complex number. "
            "Call `.astype(complex)` for the symbolic equivalent."
        )
# So that we can simplify checking code when we have a mixture of ScalarType # So that we can simplify checking code when we have a mixture of ScalarType
# variables and Tensor variables # variables and Tensor variables
ndim = 0 ndim = 0
...@@ -843,11 +874,6 @@ class _scalar_py_operators: ...@@ -843,11 +874,6 @@ class _scalar_py_operators:
def __neg__(self): def __neg__(self):
return neg(self) return neg(self)
# CASTS
# def __int__(self): return AsInt(self).out
# def __float__(self): return AsDouble(self).out
# def __complex__(self): return AsComplex(self).out
# BITWISE # BITWISE
def __invert__(self): def __invert__(self):
return invert(self) return invert(self)
......
...@@ -60,12 +60,12 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -60,12 +60,12 @@ class ScalarLoop(ScalarInnerGraphOp):
constant = [] constant = []
if not len(init) == len(update): if not len(init) == len(update):
raise ValueError("An update must be given for each init variable") raise ValueError("An update must be given for each init variable")
if until: if until is not None:
inputs, outputs = clone([*init, *constant], [*update, until]) inputs, outputs = clone([*init, *constant], [*update, until])
else: else:
inputs, outputs = clone([*init, *constant], update) inputs, outputs = clone([*init, *constant], update)
self.is_while = bool(until) self.is_while = until is not None
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
self._validate_updates(self.inputs, self.outputs) self._validate_updates(self.inputs, self.outputs)
......
...@@ -856,7 +856,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): ...@@ -856,7 +856,7 @@ def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")):
dfac = k_minus_one_minus_n * dfac + fac dfac = k_minus_one_minus_n * dfac + fac
fac *= k_minus_one_minus_n fac *= k_minus_one_minus_n
delta = dfac / xpow delta = dfac / xpow
return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), () return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), None
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
constant = [x] constant = [x]
......
...@@ -979,7 +979,7 @@ def scan( ...@@ -979,7 +979,7 @@ def scan(
# user-specified within the inner-function (e.g. by returning an update # user-specified within the inner-function (e.g. by returning an update
# `dict`) or the `SharedVariable.default_update`s of a shared variable # `dict`) or the `SharedVariable.default_update`s of a shared variable
# created in the inner-function. # created in the inner-function.
if input.update and (is_local or input.variable in updates): if input.update is not None and (is_local or input.variable in updates):
# We need to remove the `default_update`s on the shared # We need to remove the `default_update`s on the shared
# variables created within the context of the loop function # variables created within the context of the loop function
# (e.g. via use of `RandomStream`); otherwise, they'll get # (e.g. via use of `RandomStream`); otherwise, they'll get
......
...@@ -3430,7 +3430,14 @@ class _nd_grid: ...@@ -3430,7 +3430,14 @@ class _nd_grid:
raise NotImplementedError( raise NotImplementedError(
"Not implemented for slices whose step is complex" "Not implemented for slices whose step is complex"
) )
ranges = [arange(sl.start or 0, sl.stop, sl.step or 1) for sl in args[0]] ranges = [
arange(
sl.start if sl.start is not None else 0,
sl.stop,
sl.step if sl.step is not None else 1,
)
for sl in args[0]
]
shapes = [ shapes = [
tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j)) tuple([1] * j + [r.shape[0]] + [1] * (ndim - 1 - j))
for j, r in enumerate(ranges) for j, r in enumerate(ranges)
......
...@@ -2199,7 +2199,7 @@ class BaseAbstractConv(Op): ...@@ -2199,7 +2199,7 @@ class BaseAbstractConv(Op):
): ):
border_mode = "valid" border_mode = "valid"
self.imshp = tuple(imshp) if imshp else (None,) * (2 + convdim) self.imshp = tuple(imshp) if imshp is not None else (None,) * (2 + convdim)
for imshp_i in self.imshp: for imshp_i in self.imshp:
if imshp_i is not None: if imshp_i is not None:
# Components of imshp should be constant or ints # Components of imshp should be constant or ints
...@@ -2209,7 +2209,7 @@ class BaseAbstractConv(Op): ...@@ -2209,7 +2209,7 @@ class BaseAbstractConv(Op):
raise ValueError( raise ValueError(
"imshp should be None or a tuple of constant int values" "imshp should be None or a tuple of constant int values"
).with_traceback(sys.exc_info()[2]) ).with_traceback(sys.exc_info()[2])
if kshp: if kshp is not None:
self.kshp = tuple(kshp) self.kshp = tuple(kshp)
else: else:
self.kshp = (None,) * ((2 + 2 * convdim) if unshared else (2 + convdim)) self.kshp = (None,) * ((2 + 2 * convdim) if unshared else (2 + convdim))
......
...@@ -1811,14 +1811,14 @@ class Dot(Op): ...@@ -1811,14 +1811,14 @@ class Dot(Op):
if eval_points[0] is None and eval_points[1] is None: if eval_points[0] is None and eval_points[1] is None:
return [None] return [None]
if eval_points[0]: if eval_points[0] is not None:
t1 = self(eval_points[0], inputs[1]) t1 = self(eval_points[0], inputs[1])
if eval_points[1]: if eval_points[1] is not None:
t2 = self(inputs[0], eval_points[1]) t2 = self(inputs[0], eval_points[1])
if eval_points[0] and eval_points[1]: if eval_points[0] is not None and eval_points[1] is not None:
return [t1 + t2] return [t1 + t2]
elif eval_points[0]: elif eval_points[0] is not None:
return [t1] return [t1]
else: else:
return [t2] return [t2]
......
...@@ -803,7 +803,7 @@ def local_dot22_to_dot22scalar(fgraph, node): ...@@ -803,7 +803,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
""" """
if node.op != mul: if node.op != mul:
return False return False
i_dot22 = [x.owner and x.owner.op == _dot22 for x in node.inputs] i_dot22 = [x.owner is not None and x.owner.op == _dot22 for x in node.inputs]
if not any(i_dot22): if not any(i_dot22):
return False # no dot22 return False # no dot22
if i_dot22.count(True) > 1: if i_dot22.count(True) > 1:
...@@ -813,14 +813,16 @@ def local_dot22_to_dot22scalar(fgraph, node): ...@@ -813,14 +813,16 @@ def local_dot22_to_dot22scalar(fgraph, node):
dot22_idx = i_dot22.index(True) dot22_idx = i_dot22.index(True)
d = node.inputs[dot22_idx] d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs] i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar): if all(i is None for i in i_scalar):
# Check if we can reorder the graph as this mul have a mul in inputs. # Check if we can reorder the graph as this mul have a mul in inputs.
# We support only 1 additional level of mul. # We support only 1 additional level of mul.
# The canonizer should have merged those mul together. # The canonizer should have merged those mul together.
i_mul = [ i_mul = [
x.owner x.owner
and x.owner.op == mul and x.owner.op == mul
and any(_as_scalar(x_i, dtype=d.dtype) for x_i in x.owner.inputs) and any(
_as_scalar(x_i, dtype=d.dtype) is not None for x_i in x.owner.inputs
)
for x in node.inputs for x in node.inputs
] ]
if not any(i_mul): if not any(i_mul):
...@@ -834,7 +836,7 @@ def local_dot22_to_dot22scalar(fgraph, node): ...@@ -834,7 +836,7 @@ def local_dot22_to_dot22scalar(fgraph, node):
scalar_idx = -1 scalar_idx = -1
for i, x in enumerate(m.owner.inputs): for i, x in enumerate(m.owner.inputs):
if _as_scalar(x, dtype=d.dtype) and ( if _as_scalar(x, dtype=d.dtype) is not None and (
pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype pytensor.scalar.upcast(x.type.dtype, d.type.dtype) == d.type.dtype
): ):
scalar_idx = i scalar_idx = i
......
...@@ -1331,14 +1331,14 @@ def local_sum_prod_of_mul_or_div(fgraph, node): ...@@ -1331,14 +1331,14 @@ def local_sum_prod_of_mul_or_div(fgraph, node):
# If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements # If we have a `Prod`, then the outside terms need to be raised to the power of the number of elements
# that were contracted in the input # that were contracted in the input
if isinstance(node.op, Prod) and inner_term: if isinstance(node.op, Prod) and inner_term is not None:
dtype = inner_term.dtype dtype = inner_term.dtype
n_reduced_elements = prod( n_reduced_elements = prod(
[inner_term.shape[i].astype(dtype) for i in reduced_axes] [inner_term.shape[i].astype(dtype) for i in reduced_axes]
) )
outer_term = outer_term**n_reduced_elements outer_term = outer_term**n_reduced_elements
if not inner_term: if inner_term is None:
# Sum/Prod is useless, just return the outer_term # Sum/Prod is useless, just return the outer_term
# (This can only happen for mul, not division) # (This can only happen for mul, not division)
new_out = outer_term new_out = outer_term
...@@ -1992,7 +1992,7 @@ def local_pow_canonicalize(fgraph, node): ...@@ -1992,7 +1992,7 @@ def local_pow_canonicalize(fgraph, node):
# x ** 1 = x # x ** 1 = x
new_out = broadcast_arrays(*node.inputs)[0] new_out = broadcast_arrays(*node.inputs)[0]
if not new_out: if new_out is None:
return return
if new_out.dtype != node.out.dtype: if new_out.dtype != node.out.dtype:
...@@ -2119,7 +2119,7 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -2119,7 +2119,7 @@ def local_pow_to_nested_squaring(fgraph, node):
rval1_scal = None rval1_scal = None
while y_to_do > 0: while y_to_do > 0:
log_to_do = int(np.log2(y_to_do)) log_to_do = int(np.log2(y_to_do))
if rval1: if rval1 is not None:
rval1 *= pow2[log_to_do] rval1 *= pow2[log_to_do]
rval1_scal *= pow2_scal[log_to_do] rval1_scal *= pow2_scal[log_to_do]
else: else:
...@@ -2137,7 +2137,7 @@ def local_pow_to_nested_squaring(fgraph, node): ...@@ -2137,7 +2137,7 @@ def local_pow_to_nested_squaring(fgraph, node):
rval = [reciprocal(rval1)] rval = [reciprocal(rval1)]
else: else:
rval = [rval1] rval = [rval1]
if rval: if rval is not None:
rval[0] = cast(rval[0], odtype) rval[0] = cast(rval[0], odtype)
return rval return rval
......
...@@ -162,7 +162,7 @@ def softmax_simplifier(numerators, denominators): ...@@ -162,7 +162,7 @@ def softmax_simplifier(numerators, denominators):
matching_denom = denominator matching_denom = denominator
break break
if matching_denom: if matching_denom is not None:
softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0]) softmax = Softmax(axis=sum_axis)(numerator.owner.inputs[0])
copy_stack_trace(numerator, softmax) copy_stack_trace(numerator, softmax)
numerators.remove(numerator) numerators.remove(numerator)
......
...@@ -26,53 +26,54 @@ _TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType) ...@@ -26,53 +26,54 @@ _TensorTypeType = TypeVar("_TensorTypeType", bound=TensorType)
class _tensor_py_operators: class _tensor_py_operators:
# These can't work because Python requires native output types
    def __bool__(self):
        """Always raise: a lazy tensor variable has no concrete truth value.

        This also blocks chained comparisons like ``a < b < c``, which Python
        implements via an implicit boolean coercion of the first comparison.
        """
        raise TypeError(
            "TensorVariable cannot be converted to Python boolean. "
            "Call `.astype(bool)` for the symbolic equivalent."
        )
    def __index__(self):
        """Always raise: a lazy tensor cannot be used where Python needs an int (e.g. indexing)."""
        raise TypeError(
            "TensorVariable cannot be converted to Python integer. "
            "Call `.astype(int)` for the symbolic equivalent."
        )
    def __int__(self):
        """Always raise: a lazy tensor has no concrete integer value."""
        raise TypeError(
            "TensorVariable cannot be converted to Python integer. "
            "Call `.astype(int)` for the symbolic equivalent."
        )
def __float__(self):
raise TypeError(
"TensorVariables cannot be converted to Python float. "
"Call `.astype(float)` for the symbolic equivalent."
)
def __complex__(self):
raise TypeError(
"TensorVariables cannot be converted to Python complex number. "
"Call `.astype(complex)` for the symbolic equivalent."
)
def __abs__(self): def __abs__(self):
return pt.math.abs(self) return pt.math.abs(self)
def __neg__(self): def __neg__(self):
return pt.math.neg(self) return pt.math.neg(self)
# These won't work because Python requires an int return value
# def __int__(self): return convert_to_int32(self)
# def __float__(self): return convert_to_float64(self)
# def __complex__(self): return convert_to_complex128(self)
_is_nonzero = True
def __lt__(self, other): def __lt__(self, other):
rval = pt.math.lt(self, other) return pt.math.lt(self, other)
rval._is_nonzero = False
return rval
def __le__(self, other): def __le__(self, other):
rval = pt.math.le(self, other) return pt.math.le(self, other)
rval._is_nonzero = False
return rval
def __gt__(self, other): def __gt__(self, other):
rval = pt.math.gt(self, other) return pt.math.gt(self, other)
rval._is_nonzero = False
return rval
def __ge__(self, other): def __ge__(self, other):
rval = pt.math.ge(self, other) return pt.math.ge(self, other)
rval._is_nonzero = False
return rval
    def __bool__(self):
        """Truth value of a variable: True unless flagged via ``_is_nonzero``.

        NOTE(review): comparison operators set ``_is_nonzero = False`` on
        their results so that only those raise here — presumably to keep
        ``if a: ...`` working while rejecting ``a < b < c``; confirm against
        the comparison-operator definitions.
        """
        # This is meant to prohibit stuff like a < b < c, which is internally
        # implemented as (a < b) and (b < c). The trouble with this is the
        # side-effect that checking for a non-NULL a by typing "if a: ..."
        # uses the same __nonzero__ method. We want these both to work, but
        # it seems impossible. Currently, all vars evaluate to nonzero except
        # the return values of comparison operators, which raise this
        # exception. If you can think of a better solution, go for it!
        #
        # __bool__ is Python 3.x data model. __nonzero__ is Python 2.x.
        if self._is_nonzero:
            return True
        else:
            raise TypeError("Variables do not support boolean operations.")
def __invert__(self): def __invert__(self):
return pt.math.invert(self) return pt.math.invert(self)
......
...@@ -399,7 +399,7 @@ def test_tensor_creator_dtype_catch(dtype): ...@@ -399,7 +399,7 @@ def test_tensor_creator_dtype_catch(dtype):
tensor(dtype, shape=(None,)) tensor(dtype, shape=(None,))
# This should work # This should work
assert tensor(dtype=dtype, shape=(None,)) assert tensor(dtype=dtype, shape=(None,)) is not None
def test_tensor_creator_ignores_rare_dtype_name(): def test_tensor_creator_ignores_rare_dtype_name():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论