Commit 93eb73fd authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor exceptions in Scan Op

Parent 3cccea1f
This source diff could not be displayed because it is too large.
@@ -89,6 +89,119 @@ from aesara.tensor.var import TensorVariable
_logger = logging.getLogger("aesara.scan.op")
err_msg1 = (
"When compiling the inner function of scan (the "
"function called by scan in each of its iterations) "
"the following error has been encountered: The "
"%s %s (argument number %d) has dtype "
"%s and %d dimension(s). The corresponding variable "
"in the inner function of scan %s "
"however has dtype %s and %d dimension(s). This "
"variable in the inner function of scan should "
"have the same dtype and one fewer dimension "
"compared to its corresponding variable in the initial "
"state (outputs_info in scan nomenclature). For example, "
"if the inner function of scan returns a vector "
"of size d and scan uses the values of "
"the previous time-step, then the initial state in scan "
"should be a matrix of shape (1, d). "
"The first dimension of this "
"matrix corresponds to the number of previous time-steps "
"that scan uses in each of its iterations. "
"In order to solve this issue if the two variable currently "
"have the same dimensionality, you can increase the "
"dimensionality of the varialbe in the initial state of scan "
"by using dimshuffle or shape_padleft. "
)
err_msg2 = (
"When compiling the inner function of scan the "
"following error has been encountered: The "
"initial state (`outputs_info` in scan nomenclature) "
"of variable %s (argument number %d) "
"has dtype %s, while the result of the inner function "
"(`fn`) has dtype %s. This can happen if the inner "
"function of scan results in an upcast or downcast."
)
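For illustration, here is a minimal sketch of the situation `err_msg2` describes, using the public `aesara.scan` API (the variable names and inner function are made up): an inner function that upcasts its float32 state makes the result's dtype disagree with the initial state's, so scan should reject the graph with an error like the one above.

import numpy as np
import aesara
import aesara.tensor as at

x0 = at.vector("x0", dtype="float32")
# Adding a float64 constant upcasts the inner result to float64, which no
# longer matches the float32 initial state given in `outputs_info`.
ys, _ = aesara.scan(
    fn=lambda x_tm1: x_tm1 + np.float64(1.0),
    outputs_info=x0,
    n_steps=3,
)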
err_msg3 = (
"When compiling the inner function of scan (the "
"function called by scan in each of its iterations) "
"the following error has been encountered: The "
"initial state (`outputs_info` in scan nomenclature) "
"of variable %s (argument number %d) has %d dimension(s), "
"while the corresponding variable in the result of the inner "
"function of scan (`fn`) has %d dimension(s) (it should "
"be one less than the initial state). For example, "
"if the inner function of scan returns a vector "
"of size d and scan uses the values of "
"the previous time-step, then the initial state in scan "
"should be a matrix of shape (1, d). "
"The first dimension of this "
"matrix corresponds to the number of previous time-steps "
"that scan uses in each of its iterations. "
"In order to solve this issue if the two varialbe currently "
"have the same dimensionality, you can increase the "
"dimensionality of the variable in the initial state of scan "
"by using dimshuffle or shape_padleft. "
)
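As a concrete illustration of the shape rule spelled out in `err_msg1` and `err_msg3`, here is a sketch using the public `aesara.scan` API (shapes and the inner function are assumed): with taps `[-2, -1]` the initial state carries one extra leading time dimension, while the inner function returns a single step.

import aesara
import aesara.tensor as at

# Initial state covering two previous time steps: shape (2, d). The inner
# function receives two vectors of size d and returns one vector, i.e. one
# dimension fewer than the initial state, exactly as the messages require.
x0 = at.matrix("x0")
ys, _ = aesara.scan(
    fn=lambda x_tm2, x_tm1: x_tm1 + x_tm2,
    outputs_info=dict(initial=x0, taps=[-2, -1]),
    n_steps=5,
)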
def check_broadcast(v1, v2):
"""Checks that the broadcast pattern of v1 and v2.
Controls that the broadcast pattern of the variable provided as
input to `scan` matches the broadcast pattern provided in
`output_info`. It raises an error when they don't match. The
typical case is when the user provides either the input or the
`output_info` (but not both) with a dimension fixed to 1,
which may wrongly be interpreted as broadcastable.
"""
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
return
msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
"(%s). The output on axis %d is `%r`, but it is `%r` on "
"axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
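A minimal sketch of the mismatch this guards against (hypothetical variables; inside the Op, `check_broadcast` is called on the outer/inner variable pairs during `make_node`):

import aesara.tensor as at

v_fixed = at.row("v_fixed")   # last axis fixed to 1: broadcastable (False, True)
v_free = at.matrix("v_free")  # last axis variable:   broadcastable (False, False)
# The patterns disagree on the last axis, so this raises the TypeError
# built from `msg` above.
check_broadcast(v_free, v_fixed)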
def copy_var_format(var, as_var):
"""
This function ensures that ``var`` has the same dtype as ``as_var``, as
well as calling `filter_variable` to make sure they are both `TensorType`
or `GpuArrayType`.
It internally deals with the corner case where ``var.ndim == as_var.ndim + 1``.
"""
if not hasattr(var, "dtype"):
return var
rval = var
if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype)
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.clone(
broadcastable=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
)
rval = tmp.filter_variable(rval)
return rval
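A small sketch of what `copy_var_format` does (dtypes and shapes assumed): the variable is cast to the dtype of `as_var`, and the extra leading (time) dimension survives through the cloned type.

import aesara.tensor as at

inner = at.vector("inner", dtype="float64")
outer = at.matrix("outer", dtype="float32")  # one extra leading dimension
res = copy_var_format(outer, as_var=inner)
assert res.dtype == "float64"  # dtype taken from `as_var`
assert res.ndim == 2           # the var.ndim == as_var.ndim + 1 corner case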
@dataclasses.dataclass(frozen=True)
class ScanInfo:
tap_array: tuple
@@ -721,7 +834,9 @@ class Scan(Op, ScanMethodsMixin):
the inner function)
"""
assert np.all(isinstance(i, Variable) for i in inputs)
if not all(isinstance(i, Variable) for i in inputs):
raise TypeError("Inputs must be `Variable` instances")
# Check that the number of inputs to the Scan node corresponds to
# the number of inputs of the inner function of scan
n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
@@ -733,123 +848,15 @@ class Scan(Op, ScanMethodsMixin):
+ len(self.inner_shared(self.inputs))
+ len(self.inner_non_seqs(self.inputs))
)
assert n_outer_ins == n_inner_ins, (
"The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan."
)
if n_outer_ins != n_inner_ins:
raise ValueError(
"The number of inputs given to the inner function of scan"
" does not match the number of inputs given to scan."
)
# Force the inputs to be on the CPU
new_inputs = [as_tensor_variable(inputs[0])]
# assert dtype is consistent
err_msg1 = (
"When compiling the inner function of scan (the "
"function called by scan in each of its iterations) "
"the following error has been encountered: The "
"%s %s (argument number %d) has dtype "
"%s and %d dimension(s). The corresponding variable "
"in the inner function of scan %s "
"however has dtype %s and %d dimension(s). This "
"variable in the inner function of scan should "
"have the same dtype and one fewer dimension "
"compared to its corresponding variable in the initial "
"state (outputs_info in scan nomenclature). For example, "
"if the inner function of scan returns a vector "
"of size d and scan uses the values of "
"the previous time-step, then the initial state in scan "
"should be a matrix of shape (1, d). "
"The first dimension of this "
"matrix corresponds to the number of previous time-steps "
"that scan uses in each of its iterations. "
"In order to solve this issue if the two variable currently "
"have the same dimensionality, you can increase the "
"dimensionality of the varialbe in the initial state of scan "
"by using dimshuffle or shape_padleft. "
)
err_msg2 = (
"When compiling the inner function of scan the "
"following error has been encountered: The "
"initial state (`outputs_info` in scan nomenclature) "
"of variable %s (argument number %d) "
"has dtype %s, while the result of the inner function "
"(`fn`) has dtype %s. This can happen if the inner "
"function of scan results in an upcast or downcast."
)
err_msg3 = (
"When compiling the inner function of scan (the "
"function called by scan in each of its iterations) "
"the following error has been encountered: The "
"initial state (`outputs_info` in scan nomenclature) "
"of variable %s (argument number %d) has %d dimension(s), "
"while the corresponding variable in the result of the inner "
"function of scan (`fn`) has %d dimension(s) (it should "
"be one less than the initial state). For example, "
"if the inner function of scan returns a vector "
"of size d and scan uses the values of "
"the previous time-step, then the initial state in scan "
"should be a matrix of shape (1, d). "
"The first dimension of this "
"matrix corresponds to the number of previous time-steps "
"that scan uses in each of its iterations. "
"In order to solve this issue if the two varialbe currently "
"have the same dimensionality, you can increase the "
"dimensionality of the variable in the initial state of scan "
"by using dimshuffle or shape_padleft. "
)
def check_broadcast(v1, v2):
"""Checks that the broadcast pattern of v1 and v2.
Controls that the broadcast pattern of the variable provided as
input to `scan` matches the broadcast pattern provided in
`output_info`. It raises an error when they don't match. The
typical case is when the user provides either the input or the
`output_info` (but not both) with a dimension fixed to 1,
which may wrongly be interpreted as broadcastable.
"""
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
return
msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
"(%s). The output on axis %d is `%r`, but it is `%r` on "
"axis %d in `output_info`. This can happen if one of the "
"dimension is fixed to 1 in the input, while it is still "
"variable in the output, or vice-verca. You have to make "
"them consistent, e.g. using aesara.tensor."
"{patternbroadcast,unbroadcast,addbroadcast}."
)
size = min(len(v1.broadcastable), len(v2.broadcastable))
for n, (b1, b2) in enumerate(
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
):
if b1 != b2:
a1 = n + size - len(v1.broadcastable) + 1
a2 = n + size - len(v2.broadcastable) + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
def format(var, as_var):
"""
This function ensures that ``var`` has the same dtype as
``as_var``, as well as calling filter_variable to make sure
they are both TensorType or GpuArrayType. It internally
deals with the corner case where var.ndim == as_var.ndim + 1
"""
if not hasattr(var, "dtype"):
return var
rval = var
if rval.type.dtype != as_var.type.dtype:
rval = rval.astype(as_var.type.dtype)
if rval.ndim == as_var.ndim:
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.clone(
broadcastable=(
tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable)
)
)
rval = tmp.filter_variable(rval)
return rval
# Force the inputs to be on the CPU
new_inputs = [as_tensor_variable(inputs[0])]
# Check if input sequences and variables representing a slice of
# them have the same dtype
@@ -858,7 +865,7 @@ class Scan(Op, ScanMethodsMixin):
self.inner_seqs(self.inputs), self.outer_seqs(inputs)
):
check_broadcast(outer_seq, inner_seq)
new_inputs.append(format(outer_seq, as_var=inner_seq))
new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))
argoffset += len(self.outer_seqs(inputs))
# Check that this 3 things have the same dtype for mit_mot:
@@ -872,7 +879,7 @@ class Scan(Op, ScanMethodsMixin):
for idx, (itaps, otaps, _outer_mitmot) in enumerate(
zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs))
):
outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
new_inputs.append(outer_mitmot)
for k in range(len(itaps)):
if (
@@ -882,7 +889,7 @@ class Scan(Op, ScanMethodsMixin):
raise ValueError(
err_msg1
% (
"initial state (outputs_info" " in scan nomenclature) ",
"initial state (`outputs_info` in scan nomenclature) ",
str(outer_mitmot),
argoffset + idx,
outer_mitmot.type.dtype,
@@ -926,7 +933,7 @@ class Scan(Op, ScanMethodsMixin):
self.inner_mitsot_outs(self.outputs),
)
):
outer_mitsot = format(_outer_mitsot, as_var=inner_mitsots[ipos])
outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
new_inputs.append(outer_mitsot)
for k in range(len(itaps)):
@@ -978,7 +985,7 @@ class Scan(Op, ScanMethodsMixin):
self.inner_sitsot_outs(self.outputs),
)
):
outer_sitsot = format(_outer_sitsot, as_var=inner_sitsot)
outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
new_inputs.append(outer_sitsot)
if inner_sitsot.ndim != outer_sitsot.ndim - 1:
raise ValueError(
@@ -1025,7 +1032,7 @@ class Scan(Op, ScanMethodsMixin):
self.outer_shared(inputs),
)
):
outer_shared = format(_outer_shared, as_var=inner_shared)
outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
new_inputs.append(outer_shared)
if (
hasattr(outer_shared, "dtype")
@@ -1071,7 +1078,7 @@ class Scan(Op, ScanMethodsMixin):
inner_shared.ndim,
)
)
# We do not need to call `format` on outer_nisot arguments.
# We do not need to call `copy_var_format` on outer_nisot arguments.
# outer_nitsot stands for no input tap single output tap. This means
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
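For instance (a sketch using the public API; names assumed), a map-like scan produces exactly such nitsot outputs, since nothing is fed back and no `outputs_info` is given:

import aesara
import aesara.tensor as at

xs = at.vector("xs")
# No outputs_info: the output has no input taps (nitsot), so there is no
# initial state; scan only needs a length to allocate the output.
ys, _ = aesara.scan(fn=lambda x: x ** 2, sequences=xs)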
@@ -1083,15 +1090,14 @@ class Scan(Op, ScanMethodsMixin):
for inner_nonseq, _outer_nonseq in zip(
self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs)
):
outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
new_inputs.append(outer_nonseq)
if inner_nonseq.type != outer_nonseq.type:
raise ValueError(
(
"Argument %s given to scan node does not"
" match its correspondence %s"
)
% (str(outer_nonseq), str(inner_nonseq))
)
raise ValueError(
f"Argument {outer_nonseq} given to the scan node does not"
f" match its corresponding loop function variable {inner_nonseq}"
)
for outer_nitsot in self.outer_nitsot(inputs):
@@ -1102,10 +1108,8 @@ class Scan(Op, ScanMethodsMixin):
str(outer_nitsot.type.dtype) not in integer_dtypes
or outer_nitsot.ndim != 0
):
raise ValueError(
"For output %s you need to provide a " "scalar int !",
str(outer_nitsot),
)
raise ValueError(f"A scalar int is required for output {outer_nitsot}")
assert len(new_inputs) == len(inputs)
# The vector_seqs and vector_outs are just a workaround
@@ -1505,24 +1509,16 @@ class Scan(Op, ScanMethodsMixin):
# History, in the past, this was used for backward
# scan. Now we reverse the inputs outside of scan.
raise IndexError(
f"Scan was asked to run for negative number of step {int(n_steps)}"
)
raise IndexError(
f"Scan was asked to run for negative number of step {n_steps}"
)
elif n_steps == 0:
raise NotImplementedError(
"We didn't implemented yet the case where scan do 0 iteration"
)
raise NotImplementedError("n_steps == 0")
else:
for idx, seq in enumerate(inputs[1 : self.seqs_arg_offset]):
if seq.shape[0] < n_steps:
raise ValueError(
(
"Sequence is shorter then the required "
"number of steps : (n_steps, seq, "
"seq.shape):"
),
n_steps,
node.inputs[1 + idx],
seq.shape,
)
raise ValueError(
f"Sequence {idx} has shape {seq.shape} "
f"but the Scan's required number of steps is {n_steps}"
)
seqs.append(seq)
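At the user level this check plays out as follows (a sketch with assumed values; the error text comes from the new message above): a sequence shorter than `n_steps` triggers the `ValueError` at run time.

import numpy as np
import aesara
import aesara.tensor as at

xs = at.vector("xs")
ys, _ = aesara.scan(fn=lambda x: 2 * x, sequences=xs, n_steps=5)
f = aesara.function([xs], ys)
f(np.arange(5).astype(xs.dtype))  # OK: the sequence covers all 5 steps
# f(np.arange(3).astype(xs.dtype)) would raise:
# ValueError: Sequence 0 has shape (3,) but the Scan's required number of steps is 5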
@@ -2267,7 +2263,7 @@ class Scan(Op, ScanMethodsMixin):
if str(g_y.dtype) in integer_dtypes:
raise TypeError(
"Gradients may never be integers but g_y "
"has type " + str(g_y.type)
)
raise TypeError(
"Gradients may never be integers but g_y "
f"has type {g_y.type}"
)
out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s]
@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op
def get_version():
return 0.299
return 0.300
@cython.boundscheck(False)
def perform(
@@ -206,16 +206,18 @@ def perform(
"Scan was asked to run for negative number of step %d" %
n_steps)
elif n_steps == 0:
raise NotImplementedError(
"We didn't implemented yet the case where scan do 0 iteration")
raise NotImplementedError("n_steps == 0")
else:
for idx in range(n_seqs):
if args[<unsigned int>(1+idx)].shape[0] < n_steps:
raise ValueError(('Sequence is shorter than the required '
'number of steps : (n_steps, seq, '
'seq.shape):'), n_steps,
args[1+idx],
args[1+idx].shape)
raise ValueError((
"Sequence %s has shape %s "
"but the Scan's required number of steps is %s"
) % (
idx,
args[1+idx].shape,
n_steps,
))
# 2. Allocate memory for the outputs. Construct the list:
# store_steps -- map containing the length of each output
# pos -- map containing the current position of each output
@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.299 # must match constant returned in function get_version()
version = 0.300 # must match constant returned in function get_version()
need_reload = False