提交 93eb73fd authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Refactor exceptions in Scan Op

上级 3cccea1f
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -89,6 +89,119 @@ from aesara.tensor.var import TensorVariable ...@@ -89,6 +89,119 @@ from aesara.tensor.var import TensorVariable
_logger = logging.getLogger("aesara.scan.op") _logger = logging.getLogger("aesara.scan.op")
# Message templates (printf-style) used when validating the inputs of a
# `Scan` node against the variables of its inner function.

# Raised when an outer variable and its inner counterpart disagree in dtype
# and/or number of dimensions.  Placeholders, in order:
#   %s  description of the kind of outer variable (e.g. "initial state ...")
#   %s  the outer variable
#   %d  argument number of the outer variable
#   %s  dtype of the outer variable
#   %d  number of dimensions of the outer variable
#   %s  the corresponding inner-function variable
#   %s  dtype of the inner-function variable
#   %d  number of dimensions of the inner-function variable
err_msg1 = (
    "When compiling the inner function of scan (the "
    "function called by scan in each of its iterations) "
    "the following error has been encountered: The "
    "%s %s (argument number %d) has dtype "
    "%s and %d dimension(s). The corresponding variable "
    "in the inner function of scan %s "
    "however has dtype %s and %d dimension(s). This "
    "variable in the inner function of scan should "
    "have the same dtype and one fewer dimension "
    "compared to its corresponding variable in the initial "
    "state (outputs_info in scan nomenclature). For example, "
    "if the inner function of scan returns a vector "
    "of size d and scan uses the values of "
    "the previous time-step, then the initial state in scan "
    "should be a matrix of shape (1, d). "
    "The first dimension of this "
    "matrix corresponds to the number of previous time-steps "
    "that scan uses in each of its iterations. "
    "In order to solve this issue if the two variables currently "
    "have the same dimensionality, you can increase the "
    "dimensionality of the variable in the initial state of scan "
    "by using dimshuffle or shape_padleft. "
)

# Raised when only the dtypes disagree.  Placeholders, in order:
#   %s  the variable
#   %d  argument number
#   %s  dtype of the initial state
#   %s  dtype of the inner function's result
err_msg2 = (
    "When compiling the inner function of scan the "
    "following error has been encountered: The "
    "initial state (`outputs_info` in scan nomenclature) "
    "of variable %s (argument number %d) "
    "has dtype %s, while the result of the inner function "
    "(`fn`) has dtype %s. This can happen if the inner "
    "function of scan results in an upcast or downcast."
)

# Raised when only the number of dimensions disagrees.  Placeholders, in order:
#   %s  the variable
#   %d  argument number
#   %d  number of dimensions of the initial state
#   %d  number of dimensions of the inner function's result
err_msg3 = (
    "When compiling the inner function of scan (the "
    "function called by scan in each of its iterations) "
    "the following error has been encountered: The "
    "initial state (`outputs_info` in scan nomenclature) "
    "of variable %s (argument number %d) has %d dimension(s), "
    "while the corresponding variable in the result of the inner "
    "function of scan (`fn`) has %d dimension(s) (it should "
    "be one less than the initial state). For example, "
    "if the inner function of scan returns a vector "
    "of size d and scan uses the values of "
    "the previous time-step, then the initial state in scan "
    "should be a matrix of shape (1, d). "
    "The first dimension of this "
    "matrix corresponds to the number of previous time-steps "
    "that scan uses in each of its iterations. "
    "In order to solve this issue if the two variables currently "
    "have the same dimensionality, you can increase the "
    "dimensionality of the variable in the initial state of scan "
    "by using dimshuffle or shape_padleft. "
)
def check_broadcast(v1, v2):
    """Check that the broadcast patterns of `v1` and `v2` are consistent.

    Controls that the broadcast pattern of the variable provided as
    input to `scan` matches the broadcast pattern provided in
    `output_info`.  The typical failure case is when the user provides
    either the input or the `output_info` (but not both) with a dimension
    fixed to 1, which may wrongly be interpreted as broadcastable.

    The two patterns are aligned on their trailing dimensions, so variables
    with a different number of dimensions are compared over the dimensions
    they share, counted from the end.

    Parameters
    ----------
    v1
        First variable; reported as "the output of scan" in the error.
    v2
        Second variable; reported as `output_info` in the error.

    Raises
    ------
    TypeError
        If a pair of aligned dimensions has different broadcastable flags.
        The axis numbers in the message are 1-based positions within each
        variable's own pattern.

    Notes
    -----
    The check is skipped only when *neither* argument has a
    ``broadcastable`` attribute.  If exactly one of them lacks it, the
    attribute access below raises `AttributeError`.
    """
    if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
        return
    msg = (
        "The broadcast pattern of the output of scan (%s) is "
        "inconsistent with the one provided in `output_info` "
        "(%s). The output on axis %d is `%r`, but it is `%r` on "
        "axis %d in `output_info`. This can happen if one of the "
        "dimension is fixed to 1 in the input, while it is still "
        "variable in the output, or vice versa. You have to make "
        "them consistent, e.g. using aesara.tensor."
        "{patternbroadcast,unbroadcast,addbroadcast}."
    )
    pattern1 = v1.broadcastable
    pattern2 = v2.broadcastable
    size = min(len(pattern1), len(pattern2))
    # Number of leading dimensions of each pattern that have no counterpart
    # in the other one.  Explicit offsets are used instead of `[-size:]`
    # because `seq[-0:]` selects the whole sequence rather than nothing.
    offset1 = len(pattern1) - size
    offset2 = len(pattern2) - size
    for n, (b1, b2) in enumerate(zip(pattern1[offset1:], pattern2[offset2:])):
        if b1 != b2:
            # Map the position within the shared trailing dimensions back to
            # a (1-based) axis of each variable.
            a1 = n + offset1 + 1
            a2 = n + offset2 + 1
            raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
def copy_var_format(var, as_var):
    """Convert ``var`` so that its dtype and type follow those of ``as_var``.

    ``var`` is first cast to the dtype of ``as_var`` (when they differ) and
    then passed through ``filter_variable`` of a type derived from
    ``as_var.type``.

    When both variables have the same number of dimensions, ``as_var.type``
    itself is used.  Otherwise, a clone of ``as_var.type`` is built whose
    broadcast pattern is the first entry of ``var``'s pattern followed by the
    whole pattern of ``as_var``; this covers the case where ``var`` has one
    more (leading) dimension than ``as_var``.

    Objects without a ``dtype`` attribute are returned untouched.
    """
    if not hasattr(var, "dtype"):
        return var

    converted = var
    wanted_dtype = as_var.type.dtype
    if converted.type.dtype != wanted_dtype:
        converted = converted.astype(wanted_dtype)

    if converted.ndim == as_var.ndim:
        return as_var.type.filter_variable(converted)

    # Keep the leading dimension's flag from the original `var` and take the
    # remaining flags from `as_var`.
    leading = tuple(var.broadcastable[:1])
    widened_type = as_var.type.clone(
        broadcastable=(leading + tuple(as_var.broadcastable))
    )
    return widened_type.filter_variable(converted)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ScanInfo: class ScanInfo:
tap_array: tuple tap_array: tuple
...@@ -721,7 +834,9 @@ class Scan(Op, ScanMethodsMixin): ...@@ -721,7 +834,9 @@ class Scan(Op, ScanMethodsMixin):
the inner function) the inner function)
""" """
assert np.all(isinstance(i, Variable) for i in inputs) if not all(isinstance(i, Variable) for i in inputs):
raise TypeError("Inputs must be `Variable` instances")
# Check that the number of inputs to the Scan node corresponds to # Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan # the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1 n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
...@@ -733,123 +848,15 @@ class Scan(Op, ScanMethodsMixin): ...@@ -733,123 +848,15 @@ class Scan(Op, ScanMethodsMixin):
+ len(self.inner_shared(self.inputs)) + len(self.inner_shared(self.inputs))
+ len(self.inner_non_seqs(self.inputs)) + len(self.inner_non_seqs(self.inputs))
) )
assert n_outer_ins == n_inner_ins, (
if n_outer_ins != n_inner_ins:
raise ValueError(
"The number of inputs given to the inner function of scan" "The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan." " does not match the number of inputs given to scan."
) )
# Force the inputs to be on the CPU # Force the inputs to be on the CPU
new_inputs = [as_tensor_variable(inputs[0])] new_inputs = [as_tensor_variable(inputs[0])]
# assert dtype is consistent
err_msg1 = (
"When compiling the inner function of scan (the "
"function called by scan in each of its iterations) "
"the following error has been encountered: The "
"%s %s (argument number %d) has dtype "
"%s and %d dimension(s). The corresponding variable "
"in the inner function of scan %s "
"however has dtype %s and %d dimension(s). This "
"variable in the inner function of scan should "
"have the same dtype and one fewer dimension "
"compared to its corresponding variable in the initial "
"state (outputs_info in scan nomenclature). For example, "
"if the inner function of scan returns a vector "
"of size d and scan uses the values of "
"the previous time-step, then the initial state in scan "
"should be a matrix of shape (1, d). "
"The first dimension of this "
"matrix corresponds to the number of previous time-steps "
"that scan uses in each of its iterations. "
"In order to solve this issue if the two variable currently "
"have the same dimensionality, you can increase the "
"dimensionality of the varialbe in the initial state of scan "
"by using dimshuffle or shape_padleft. "
)
err_msg2 = (
"When compiling the inner function of scan the "
"following error has been encountered: The "
"initial state (`outputs_info` in scan nomenclature) "
"of variable %s (argument number %d) "
"has dtype %s, while the result of the inner function "
"(`fn`) has dtype %s. This can happen if the inner "
"function of scan results in an upcast or downcast."
)
err_msg3 = (
"When compiling the inner function of scan (the "
"function called by scan in each of its iterations) "
"the following error has been encountered: The "
"initial state (`outputs_info` in scan nomenclature) "
"of variable %s (argument number %d) has %d dimension(s), "
"while the corresponding variable in the result of the inner "
"function of scan (`fn`) has %d dimension(s) (it should "
"be one less than the initial state). For example, "
"if the inner function of scan returns a vector "
"of size d and scan uses the values of "
"the previous time-step, then the initial state in scan "
"should be a matrix of shape (1, d). "
"The first dimension of this "
"matrix corresponds to the number of previous time-steps "
"that scan uses in each of its iterations. "
"In order to solve this issue if the two varialbe currently "
"have the same dimensionality, you can increase the "
"dimensionality of the variable in the initial state of scan "
"by using dimshuffle or shape_padleft. "
)
def check_broadcast(v1, v2):
"""Checks that the broadcast pattern of v1 and v2.
Controls that the broadcast pattern of the variable provided as
input to `scan` matches the broadcast pattern provided in
`output_info`. It raises an error when they don't match. The
typical case is when the user provides either the input or the
`output_info` (but not both) with a dimension fixed to 1,
which may wrongly be interpreted as broadcastable.
"""
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
return
msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
"(%s). The output on axis %d is `%r`, but it is `%r` on "
"axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
def format(var, as_var):
"""
This functions ensures that ``out`` has the same dtype as
``inp`` as well as calling filter_variable to make sure
they are both TensorType or GpuArrayType. It internally
deals with the corner case where inp.ndim + 1 = out.ndim
"""
if not hasattr(var, "dtype"):
return var
rval = var
if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype)
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.clone(
broadcastable=(
tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable)
)
)
rval = tmp.filter_variable(rval)
return rval
# Check if input sequences and variables representing a slice of # Check if input sequences and variables representing a slice of
# them have the same dtype # them have the same dtype
...@@ -858,7 +865,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -858,7 +865,7 @@ class Scan(Op, ScanMethodsMixin):
self.inner_seqs(self.inputs), self.outer_seqs(inputs) self.inner_seqs(self.inputs), self.outer_seqs(inputs)
): ):
check_broadcast(outer_seq, inner_seq) check_broadcast(outer_seq, inner_seq)
new_inputs.append(format(outer_seq, as_var=inner_seq)) new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
argoffset += len(self.outer_seqs(inputs)) argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot: # Check that this 3 things have the same dtype for mit_mot:
...@@ -872,7 +879,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -872,7 +879,7 @@ class Scan(Op, ScanMethodsMixin):
for idx, (itaps, otaps, _outer_mitmot) in enumerate( for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs)) zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs))
): ):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos]) outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot) new_inputs.append(outer_mitmot)
for k in range(len(itaps)): for k in range(len(itaps)):
if ( if (
...@@ -882,7 +889,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -882,7 +889,7 @@ class Scan(Op, ScanMethodsMixin):
raise ValueError( raise ValueError(
err_msg1 err_msg1
% ( % (
"initial state (outputs_info" " in scan nomenclature) ", "initial state (`outputs_info` in scan nomenclature) ",
str(outer_mitmot), str(outer_mitmot),
argoffset + idx, argoffset + idx,
outer_mitmot.type.dtype, outer_mitmot.type.dtype,
...@@ -926,7 +933,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -926,7 +933,7 @@ class Scan(Op, ScanMethodsMixin):
self.inner_mitsot_outs(self.outputs), self.inner_mitsot_outs(self.outputs),
) )
): ):
outer_mitsot = format(_outer_mitsot, as_var=inner_mitsots[ipos]) outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
new_inputs.append(outer_mitsot) new_inputs.append(outer_mitsot)
for k in range(len(itaps)): for k in range(len(itaps)):
...@@ -978,7 +985,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -978,7 +985,7 @@ class Scan(Op, ScanMethodsMixin):
self.inner_sitsot_outs(self.outputs), self.inner_sitsot_outs(self.outputs),
) )
): ):
outer_sitsot = format(_outer_sitsot, as_var=inner_sitsot) outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
new_inputs.append(outer_sitsot) new_inputs.append(outer_sitsot)
if inner_sitsot.ndim != outer_sitsot.ndim - 1: if inner_sitsot.ndim != outer_sitsot.ndim - 1:
raise ValueError( raise ValueError(
...@@ -1025,7 +1032,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1025,7 +1032,7 @@ class Scan(Op, ScanMethodsMixin):
self.outer_shared(inputs), self.outer_shared(inputs),
) )
): ):
outer_shared = format(_outer_shared, as_var=inner_shared) outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared) new_inputs.append(outer_shared)
if ( if (
hasattr(outer_shared, "dtype") hasattr(outer_shared, "dtype")
...@@ -1071,7 +1078,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1071,7 +1078,7 @@ class Scan(Op, ScanMethodsMixin):
inner_shared.ndim, inner_shared.ndim,
) )
) )
# We do not need to call `format` on outer_nisot arguments. # We do not need to call `copy_var_format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means # outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent # these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan # computation, and hence they do not have an initial state. The scan
...@@ -1083,15 +1090,14 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1083,15 +1090,14 @@ class Scan(Op, ScanMethodsMixin):
for inner_nonseq, _outer_nonseq in zip( for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs) self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs)
): ):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq) outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq) new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type: if inner_nonseq.type != outer_nonseq.type:
raise ValueError( raise ValueError(
( (
"Argument %s given to scan node does not" f"Argument {outer_nonseq} given to the scan node does not"
" match its correspondence %s" f" match its corresponding loop function variable {inner_nonseq}"
) )
% (str(outer_nonseq), str(inner_nonseq))
) )
for outer_nitsot in self.outer_nitsot(inputs): for outer_nitsot in self.outer_nitsot(inputs):
...@@ -1102,10 +1108,8 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1102,10 +1108,8 @@ class Scan(Op, ScanMethodsMixin):
str(outer_nitsot.type.dtype) not in integer_dtypes str(outer_nitsot.type.dtype) not in integer_dtypes
or outer_nitsot.ndim != 0 or outer_nitsot.ndim != 0
): ):
raise ValueError( raise ValueError(f"A scalar int is required for output {outer_nitsot}")
"For output %s you need to provide a " "scalar int !",
str(outer_nitsot),
)
assert len(new_inputs) == len(inputs) assert len(new_inputs) == len(inputs)
# The vector_seqs and vector_outs are just a workaround # The vector_seqs and vector_outs are just a workaround
...@@ -1505,24 +1509,16 @@ class Scan(Op, ScanMethodsMixin): ...@@ -1505,24 +1509,16 @@ class Scan(Op, ScanMethodsMixin):
# History, in the past, this was used for backward # History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan. # scan. Now we reverse the inputs outside of scan.
raise IndexError( raise IndexError(
f"Scan was asked to run for negative number of step {int(n_steps)}" f"Scan was asked to run for negative number of step {n_steps}"
) )
elif n_steps == 0: elif n_steps == 0:
raise NotImplementedError( raise NotImplementedError("n_steps == 0")
"We didn't implemented yet the case where scan do 0 iteration"
)
else: else:
for idx, seq in enumerate(inputs[1 : self.seqs_arg_offset]): for idx, seq in enumerate(inputs[1 : self.seqs_arg_offset]):
if seq.shape[0] < n_steps: if seq.shape[0] < n_steps:
raise ValueError( raise ValueError(
( f"Sequence {idx} has shape {seq.shape} "
"Sequence is shorter then the required " f"but the Scan's required number of steps is {n_steps}"
"number of steps : (n_steps, seq, "
"seq.shape):"
),
n_steps,
node.inputs[1 + idx],
seq.shape,
) )
seqs.append(seq) seqs.append(seq)
...@@ -2267,7 +2263,7 @@ class Scan(Op, ScanMethodsMixin): ...@@ -2267,7 +2263,7 @@ class Scan(Op, ScanMethodsMixin):
if str(g_y.dtype) in integer_dtypes: if str(g_y.dtype) in integer_dtypes:
raise TypeError( raise TypeError(
"Gradients may never be integers but g_y " "Gradients may never be integers but g_y "
"has type " + str(g_y.type) f"has type {g_y.type}"
) )
out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s] out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s]
......
...@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op ...@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op
def get_version(): def get_version():
return 0.299 return 0.300
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -206,16 +206,18 @@ def perform( ...@@ -206,16 +206,18 @@ def perform(
"Scan was asked to run for negative number of step %d" % "Scan was asked to run for negative number of step %d" %
n_steps) n_steps)
elif n_steps == 0: elif n_steps == 0:
raise NotImplementedError( raise NotImplementedError("n_steps == 0")
"We didn't implemented yet the case where scan do 0 iteration")
else: else:
for idx in range(n_seqs): for idx in range(n_seqs):
if args[<unsigned int>(1+idx)].shape[0] < n_steps: if args[<unsigned int>(1+idx)].shape[0] < n_steps:
raise ValueError(('Sequence is shorter than the required ' raise ValueError((
'number of steps : (n_steps, seq, ' "Sequence %s has shape %s "
'seq.shape):'), n_steps, "but the Scan's required number of steps is %s"
args[1+idx], ) % (
args[1+idx].shape) idx,
args[1+idx].shape,
n_steps,
))
# 2. Allocate memory for the outputs. Construct the list: # 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containing the length of each output # store_steps -- map containing the length of each output
# pos -- map containing the current position of each output # pos -- map containing the current position of each output
......
...@@ -21,7 +21,7 @@ if not config.cxx: ...@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.299 # must match constant returned in function get_version() version = 0.300 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论