Commit 93eb73fd authored by Brandon T. Willard, committed by Brandon T. Willard

Refactor exceptions in Scan Op

Parent 3cccea1f
This source diff could not be displayed because it is too large.
@@ -89,6 +89,119 @@ from aesara.tensor.var import TensorVariable

 _logger = logging.getLogger("aesara.scan.op")

+err_msg1 = (
+    "When compiling the inner function of scan (the "
+    "function called by scan in each of its iterations) "
+    "the following error has been encountered: The "
+    "%s %s (argument number %d) has dtype "
+    "%s and %d dimension(s). The corresponding variable "
+    "in the inner function of scan %s "
+    "however has dtype %s and %d dimension(s). This "
+    "variable in the inner function of scan should "
+    "have the same dtype and one fewer dimension "
+    "compared to its corresponding variable in the initial "
+    "state (`outputs_info` in scan nomenclature). For example, "
+    "if the inner function of scan returns a vector "
+    "of size d and scan uses the values of "
+    "the previous time-step, then the initial state in scan "
+    "should be a matrix of shape (1, d). "
+    "The first dimension of this "
+    "matrix corresponds to the number of previous time-steps "
+    "that scan uses in each of its iterations. "
+    "If the two variables currently have the same "
+    "dimensionality, you can increase the "
+    "dimensionality of the variable in the initial state of scan "
+    "by using dimshuffle or shape_padleft. "
+)
+err_msg2 = (
+    "When compiling the inner function of scan the "
+    "following error has been encountered: The "
+    "initial state (`outputs_info` in scan nomenclature) "
+    "of variable %s (argument number %d) "
+    "has dtype %s, while the result of the inner function "
+    "(`fn`) has dtype %s. This can happen if the inner "
+    "function of scan results in an upcast or downcast."
+)
+err_msg3 = (
+    "When compiling the inner function of scan (the "
+    "function called by scan in each of its iterations) "
+    "the following error has been encountered: The "
+    "initial state (`outputs_info` in scan nomenclature) "
+    "of variable %s (argument number %d) has %d dimension(s), "
+    "while the corresponding variable in the result of the inner "
+    "function of scan (`fn`) has %d dimension(s) (it should "
+    "be one less than the initial state). For example, "
+    "if the inner function of scan returns a vector "
+    "of size d and scan uses the values of "
+    "the previous time-step, then the initial state in scan "
+    "should be a matrix of shape (1, d). "
+    "The first dimension of this "
+    "matrix corresponds to the number of previous time-steps "
+    "that scan uses in each of its iterations. "
+    "If the two variables currently have the same "
+    "dimensionality, you can increase the "
+    "dimensionality of the variable in the initial state of scan "
+    "by using dimshuffle or shape_padleft. "
+)
+
+
+def check_broadcast(v1, v2):
+    """Check the broadcast patterns of `v1` and `v2`.
+
+    Verifies that the broadcast pattern of the variable provided as
+    input to `scan` matches the broadcast pattern provided in
+    `outputs_info`. It raises an error when they don't match. The
+    typical case is when the user provides either the input or the
+    `outputs_info` (but not both) with a dimension fixed to 1,
+    which may wrongly be interpreted as broadcastable.
+
+    """
+    if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
+        return
+
+    msg = (
+        "The broadcast pattern of the output of scan (%s) is "
+        "inconsistent with the one provided in `outputs_info` "
+        "(%s). The output on axis %d is `%r`, but it is `%r` on "
+        "axis %d in `outputs_info`. This can happen if one of the "
+        "dimensions is fixed to 1 in the input, while it is still "
+        "variable in the output, or vice versa. You have to make "
+        "them consistent, e.g. using aesara.tensor."
+        "{patternbroadcast,unbroadcast,addbroadcast}."
+    )
+    size = min(len(v1.broadcastable), len(v2.broadcastable))
+    for n, (b1, b2) in enumerate(
+        zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
+    ):
+        if b1 != b2:
+            a1 = n + size - len(v1.broadcastable) + 1
+            a2 = n + size - len(v2.broadcastable) + 1
+            raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
+
+
+def copy_var_format(var, as_var):
+    """
+    This function ensures that ``var`` has the same dtype as ``as_var``, as
+    well as calling `filter_variable` to make sure they are both `TensorType`
+    or `GpuArrayType`.
+
+    It internally deals with the corner case where ``var.ndim == as_var.ndim + 1``.
+
+    """
+    if not hasattr(var, "dtype"):
+        return var
+
+    rval = var
+    if rval.type.dtype != as_var.type.dtype:
+        rval = rval.astype(as_var.type.dtype)
+
+    if rval.ndim == as_var.ndim:
+        rval = as_var.type.filter_variable(rval)
+    else:
+        tmp = as_var.type.clone(
+            broadcastable=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
+        )
+        rval = tmp.filter_variable(rval)
+
+    return rval
+
+
 @dataclasses.dataclass(frozen=True)
 class ScanInfo:
     tap_array: tuple
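
As an aside for readers of this hunk, here is a minimal sketch (not part of the commit; the variable names are made up) of the shape rule that `err_msg1` and `err_msg3` describe: with a single tap of `-1`, the initial state needs one extra leading time dimension compared to what `fn` returns, and keeping the dtype unchanged avoids the upcast/downcast case from `err_msg2`.

```python
import aesara
import aesara.tensor as at

x0 = at.vector("x0")  # what the inner function works with: shape (d,)

# The inner function returns a vector of size d and looks one step back,
# so the initial state must be a matrix of shape (1, d); `shape_padleft`
# adds the leading time dimension, exactly as err_msg1 recommends.
outputs, updates = aesara.scan(
    fn=lambda x_tm1: 2 * x_tm1,  # same dtype in and out, so err_msg2 stays quiet
    outputs_info=[{"initial": at.shape_padleft(x0), "taps": [-1]}],
    n_steps=4,
)
f = aesara.function([x0], outputs)  # f([0.0, 1.0, 2.0]) doubles x0 four times
```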
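And a sketch of what the two new helpers check, assuming they are importable from `aesara.scan.op` as this hunk suggests (the variables are hypothetical): `check_broadcast` compares broadcast patterns axis by axis, while `copy_var_format` reconciles dtype and the extra leading dimension of outer variables.

```python
import aesara.tensor as at
from aesara.scan.op import check_broadcast, copy_var_format

outer = at.row("outer")     # broadcastable == (True, False): axis 0 fixed to 1
inner = at.matrix("inner")  # broadcastable == (False, False)

# The trailing axes disagree on their broadcast flags, so this raises
# the TypeError built from `msg` above:
#   check_broadcast(outer, inner)

# `copy_var_format` handles the var.ndim == as_var.ndim + 1 corner case:
# an outer sequence carries a leading time dimension that its inner
# one-step slice lacks.
outer_seq = at.matrix("outer_seq")      # (n_steps, d)
inner_slice = at.vector("inner_slice")  # (d,)
formatted = copy_var_format(outer_seq, as_var=inner_slice)
```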
@@ -721,7 +834,9 @@ class Scan(Op, ScanMethodsMixin):
             the inner function)

         """
-        assert np.all(isinstance(i, Variable) for i in inputs)
+        if not all(isinstance(i, Variable) for i in inputs):
+            raise TypeError("Inputs must be `Variable` instances")

         # Check that the number of inputs to the Scan node corresponds to
         # the number of inputs of the inner function of scan
         n_outer_ins = len(inputs) - len(self.outer_nitsot(inputs)) - 1
@@ -733,123 +848,15 @@ class Scan(Op, ScanMethodsMixin):
             + len(self.inner_shared(self.inputs))
             + len(self.inner_non_seqs(self.inputs))
         )
-        assert n_outer_ins == n_inner_ins, (
-            "The number of inputs given to the inner function of scan"
-            " does not match the number of inputs given to scan."
-        )
-
-        # Force the inputs to be on the CPU
-        new_inputs = [as_tensor_variable(inputs[0])]
-
-        # assert dtype is consistent
-        err_msg1 = (
-            "When compiling the inner function of scan (the "
-            "function called by scan in each of its iterations) "
-            "the following error has been encountered: The "
-            "%s %s (argument number %d) has dtype "
-            "%s and %d dimension(s). The corresponding variable "
-            "in the inner function of scan %s "
-            "however has dtype %s and %d dimension(s). This "
-            "variable in the inner function of scan should "
-            "have the same dtype and one fewer dimension "
-            "compared to its corresponding variable in the initial "
-            "state (outputs_info in scan nomenclature). For example, "
-            "if the inner function of scan returns a vector "
-            "of size d and scan uses the values of "
-            "the previous time-step, then the initial state in scan "
-            "should be a matrix of shape (1, d). "
-            "The first dimension of this "
-            "matrix corresponds to the number of previous time-steps "
-            "that scan uses in each of its iterations. "
-            "In order to solve this issue if the two variable currently "
-            "have the same dimensionality, you can increase the "
-            "dimensionality of the varialbe in the initial state of scan "
-            "by using dimshuffle or shape_padleft. "
-        )
-        err_msg2 = (
-            "When compiling the inner function of scan the "
-            "following error has been encountered: The "
-            "initial state (`outputs_info` in scan nomenclature) "
-            "of variable %s (argument number %d) "
-            "has dtype %s, while the result of the inner function "
-            "(`fn`) has dtype %s. This can happen if the inner "
-            "function of scan results in an upcast or downcast."
-        )
-        err_msg3 = (
-            "When compiling the inner function of scan (the "
-            "function called by scan in each of its iterations) "
-            "the following error has been encountered: The "
-            "initial state (`outputs_info` in scan nomenclature) "
-            "of variable %s (argument number %d) has %d dimension(s), "
-            "while the corresponding variable in the result of the inner "
-            "function of scan (`fn`) has %d dimension(s) (it should "
-            "be one less than the initial state). For example, "
-            "if the inner function of scan returns a vector "
-            "of size d and scan uses the values of "
-            "the previous time-step, then the initial state in scan "
-            "should be a matrix of shape (1, d). "
-            "The first dimension of this "
-            "matrix corresponds to the number of previous time-steps "
-            "that scan uses in each of its iterations. "
-            "In order to solve this issue if the two varialbe currently "
-            "have the same dimensionality, you can increase the "
-            "dimensionality of the variable in the initial state of scan "
-            "by using dimshuffle or shape_padleft. "
-        )
-
-        def check_broadcast(v1, v2):
-            """Checks that the broadcast pattern of v1 and v2.
-
-            Controls that the broadcast pattern of the variable provided as
-            input to `scan` matches the broadcast pattern provided in
-            `output_info`. It raises an error when they don't match. The
-            typical case is when the user provides either the input or the
-            `output_info` (but not both) with a dimension fixed to 1,
-            which may wrongly be interpreted as broadcastable.
-
-            """
-            if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
-                return
-            msg = (
-                "The broadcast pattern of the output of scan (%s) is "
-                "inconsistent with the one provided in `output_info` "
-                "(%s). The output on axis %d is `%r`, but it is `%r` on "
-                "axis %d in `output_info`. This can happen if one of the "
-                "dimension is fixed to 1 in the input, while it is still "
-                "variable in the output, or vice-verca. You have to make "
-                "them consistent, e.g. using aesara.tensor."
-                "{patternbroadcast,unbroadcast,addbroadcast}."
-            )
-            size = min(len(v1.broadcastable), len(v2.broadcastable))
-            for n, (b1, b2) in enumerate(
-                zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
-            ):
-                if b1 != b2:
-                    a1 = n + size - len(v1.broadcastable) + 1
-                    a2 = n + size - len(v2.broadcastable) + 1
-                    raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
-
-        def format(var, as_var):
-            """
-            This functions ensures that ``out`` has the same dtype as
-            ``inp`` as well as calling filter_variable to make sure
-            they are both TensorType or GpuArrayType. It internally
-            deals with the corner case where inp.ndim + 1 = out.ndim
-
-            """
-            if not hasattr(var, "dtype"):
-                return var
-            rval = var
-            if rval.type.dtype != as_var.type.dtype:
-                rval = rval.astype(as_var.type.dtype)
-            if rval.ndim == as_var.ndim:
-                rval = as_var.type.filter_variable(rval)
-            else:
-                tmp = as_var.type.clone(
-                    broadcastable=(
-                        tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable)
-                    )
-                )
-                rval = tmp.filter_variable(rval)
-            return rval
+        if n_outer_ins != n_inner_ins:
+            raise ValueError(
+                "The number of inputs given to the inner function of scan"
+                " does not match the number of inputs given to scan."
+            )
+
+        # Force the inputs to be on the CPU
+        new_inputs = [as_tensor_variable(inputs[0])]

         # Check if input sequences and variables representing a slice of
         # them have the same dtype
@@ -858,7 +865,7 @@ class Scan(Op, ScanMethodsMixin):
             self.inner_seqs(self.inputs), self.outer_seqs(inputs)
         ):
             check_broadcast(outer_seq, inner_seq)
-            new_inputs.append(format(outer_seq, as_var=inner_seq))
+            new_inputs.append(copy_var_format(outer_seq, as_var=inner_seq))

         argoffset += len(self.outer_seqs(inputs))

         # Check that this 3 things have the same dtype for mit_mot:
@@ -872,7 +879,7 @@ class Scan(Op, ScanMethodsMixin):
         for idx, (itaps, otaps, _outer_mitmot) in enumerate(
             zip(self.mitmot_taps(), self.mitmot_out_taps(), self.outer_mitmot(inputs))
         ):
-            outer_mitmot = format(_outer_mitmot, as_var=inner_mitmot[ipos])
+            outer_mitmot = copy_var_format(_outer_mitmot, as_var=inner_mitmot[ipos])
             new_inputs.append(outer_mitmot)
             for k in range(len(itaps)):
                 if (
@@ -882,7 +889,7 @@ class Scan(Op, ScanMethodsMixin):
                     raise ValueError(
                         err_msg1
                         % (
-                            "initial state (outputs_info" " in scan nomenclature) ",
+                            "initial state (`outputs_info` in scan nomenclature) ",
                             str(outer_mitmot),
                             argoffset + idx,
                             outer_mitmot.type.dtype,
@@ -926,7 +933,7 @@ class Scan(Op, ScanMethodsMixin):
                 self.inner_mitsot_outs(self.outputs),
             )
         ):
-            outer_mitsot = format(_outer_mitsot, as_var=inner_mitsots[ipos])
+            outer_mitsot = copy_var_format(_outer_mitsot, as_var=inner_mitsots[ipos])
             new_inputs.append(outer_mitsot)

             for k in range(len(itaps)):
@@ -978,7 +985,7 @@ class Scan(Op, ScanMethodsMixin):
                 self.inner_sitsot_outs(self.outputs),
             )
         ):
-            outer_sitsot = format(_outer_sitsot, as_var=inner_sitsot)
+            outer_sitsot = copy_var_format(_outer_sitsot, as_var=inner_sitsot)
             new_inputs.append(outer_sitsot)
             if inner_sitsot.ndim != outer_sitsot.ndim - 1:
                 raise ValueError(
@@ -1025,7 +1032,7 @@ class Scan(Op, ScanMethodsMixin):
                 self.outer_shared(inputs),
             )
         ):
-            outer_shared = format(_outer_shared, as_var=inner_shared)
+            outer_shared = copy_var_format(_outer_shared, as_var=inner_shared)
             new_inputs.append(outer_shared)
             if (
                 hasattr(outer_shared, "dtype")
@@ -1071,7 +1078,7 @@ class Scan(Op, ScanMethodsMixin):
                         inner_shared.ndim,
                     )
                 )
-        # We do not need to call `format` on outer_nisot arguments.
+        # We do not need to call `copy_var_format` on outer_nitsot arguments.
         # outer_nitsot stands for no input tap single output tap. This means
         # these are states that do not feed anything back in the recurrent
         # computation, and hence they do not have an initial state. The scan
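
A hedged illustration of the nitsot case this comment describes (not part of the diff; the names are made up): passing `outputs_info=None` yields an output with no input taps and no initial state, so nothing is fed back and `copy_var_format` is never involved.

```python
import aesara
import aesara.tensor as at

xs = at.vector("xs")

# The squares are never fed back into the loop, so this output is a
# nitsot: no input taps, a single output tap, and no initial state.
squares, updates = aesara.scan(
    fn=lambda x: x ** 2,
    sequences=[xs],
    outputs_info=None,
)
f = aesara.function([xs], squares)
```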
@@ -1083,15 +1090,14 @@ class Scan(Op, ScanMethodsMixin):
         for inner_nonseq, _outer_nonseq in zip(
             self.inner_non_seqs(self.inputs), self.outer_non_seqs(inputs)
         ):
-            outer_nonseq = format(_outer_nonseq, as_var=inner_nonseq)
+            outer_nonseq = copy_var_format(_outer_nonseq, as_var=inner_nonseq)
             new_inputs.append(outer_nonseq)
             if inner_nonseq.type != outer_nonseq.type:
                 raise ValueError(
                     (
-                        "Argument %s given to scan node does not"
-                        " match its correspondence %s"
+                        f"Argument {outer_nonseq} given to the scan node does not"
+                        f" match its corresponding loop function variable {inner_nonseq}"
                     )
-                    % (str(outer_nonseq), str(inner_nonseq))
                 )

         for outer_nitsot in self.outer_nitsot(inputs):
@@ -1102,10 +1108,8 @@ class Scan(Op, ScanMethodsMixin):
                 str(outer_nitsot.type.dtype) not in integer_dtypes
                 or outer_nitsot.ndim != 0
             ):
-                raise ValueError(
-                    "For output %s you need to provide a " "scalar int !",
-                    str(outer_nitsot),
-                )
+                raise ValueError(f"A scalar int is required for output {outer_nitsot}")

         assert len(new_inputs) == len(inputs)

         # The vector_seqs and vector_outs are just a workaround
@@ -1505,24 +1509,16 @@ class Scan(Op, ScanMethodsMixin):
             # History, in the past, this was used for backward
             # scan. Now we reverse the inputs outside of scan.
             raise IndexError(
-                f"Scan was asked to run for negative number of step {int(n_steps)}"
+                f"Scan was asked to run for a negative number of steps {n_steps}"
             )
         elif n_steps == 0:
-            raise NotImplementedError(
-                "We didn't implemented yet the case where scan do 0 iteration"
-            )
+            raise NotImplementedError("n_steps == 0")
         else:
             for idx, seq in enumerate(inputs[1 : self.seqs_arg_offset]):
                 if seq.shape[0] < n_steps:
                     raise ValueError(
-                        (
-                            "Sequence is shorter then the required "
-                            "number of steps : (n_steps, seq, "
-                            "seq.shape):"
-                        ),
-                        n_steps,
-                        node.inputs[1 + idx],
-                        seq.shape,
+                        f"Sequence {idx} has shape {seq.shape} "
+                        f"but the Scan's required number of steps is {n_steps}"
                     )
                 seqs.append(seq)
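
The comment about backward scan above can be made concrete with a small sketch (illustrative only, not from this commit): instead of a negative step count, the sequence is reversed outside of scan.

```python
import aesara
import aesara.tensor as at

xs = at.vector("xs")

# A backward cumulative sum: reverse the input outside of scan rather
# than asking scan itself to run for a negative number of steps.
acc, updates = aesara.scan(
    fn=lambda x, a: a + x,
    sequences=[xs[::-1]],
    outputs_info=[at.zeros((), dtype=xs.dtype)],
)
```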
@@ -2267,7 +2263,7 @@ class Scan(Op, ScanMethodsMixin):
             if str(g_y.dtype) in integer_dtypes:
                 raise TypeError(
                     "Gradients may never be integers but g_y "
-                    "has type " + str(g_y.type)
+                    f"has type {g_y.type}"
                 )

         out_indices = [get_out_idx(self_outputs.index(y)) for y in y_s]
......
@@ -58,7 +58,7 @@ from aesara.link.utils import raise_with_op

 def get_version():
-    return 0.299
+    return 0.300


 @cython.boundscheck(False)
 def perform(
@@ -206,16 +206,18 @@ def perform(
                 "Scan was asked to run for negative number of step %d" %
                 n_steps)
     elif n_steps == 0:
-        raise NotImplementedError(
-            "We didn't implemented yet the case where scan do 0 iteration")
+        raise NotImplementedError("n_steps == 0")
     else:
         for idx in range(n_seqs):
             if args[<unsigned int>(1+idx)].shape[0] < n_steps:
-                raise ValueError(('Sequence is shorter than the required '
-                                  'number of steps : (n_steps, seq, '
-                                  'seq.shape):'), n_steps,
-                                 args[1+idx],
-                                 args[1+idx].shape)
+                raise ValueError((
+                    "Sequence %s has shape %s "
+                    "but the Scan's required number of steps is %s"
+                ) % (
+                    idx,
+                    args[1+idx].shape,
+                    n_steps,
+                ))

     # 2. Allocate memory for the outputs. Construct the list:
     #   store_steps -- map containing the length of each output
     #   pos -- map containing the current position of each output
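
To see this length check fire from user code, here is a sketch of the assumed behavior (names are made up): compile a scan with more steps than the sequence provides and the ValueError above is raised at run time.

```python
import numpy as np
import aesara
import aesara.tensor as at

xs = at.vector("xs")
ys, _ = aesara.scan(fn=lambda x: x + 1, sequences=[xs], n_steps=5)
f = aesara.function([xs], ys)

f(np.arange(3, dtype=xs.dtype))  # 3 < 5 -> raises the ValueError above
```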
......
@@ -21,7 +21,7 @@ if not config.cxx:

 _logger = logging.getLogger("aesara.scan.scan_perform")

-version = 0.299  # must match constant returned in function get_version()
+version = 0.300  # must match constant returned in function get_version()

 need_reload = False
......