提交 83e5f3ef authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix exceptions in aesara.link.vm

上级 75239913
......@@ -166,7 +166,8 @@ class VM:
def __init__(self, fgraph, nodes, thunks, pre_call_clear):
if len(nodes) != len(thunks):
raise ValueError()
raise ValueError("`nodes` and `thunks` must be the same length")
self.fgraph = fgraph
self.nodes = nodes
self.thunks = thunks
......@@ -188,7 +189,7 @@ class VM:
what exactly this means and how it is done.
"""
raise NotImplementedError("override me")
raise NotImplementedError()
def clear_storage(self):
"""
......@@ -199,7 +200,7 @@ class VM:
calls.
"""
raise NotImplementedError("override me")
raise NotImplementedError()
def update_profile(self, profile):
"""
......@@ -282,7 +283,9 @@ class LoopGC(VM):
# Some other part of Aesara query that information
self.allow_gc = True
if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError()
raise ValueError(
"`nodes`, `thunks` and `post_thunk_clear` are not the same lengths"
)
def __call__(self):
if self.time_thunks:
......@@ -1138,13 +1141,9 @@ class VMLinker(LocalLinker):
# So if they didn't specify that its lazy or not, it isn't.
# If this member isn't present, it will crash later.
thunks[-1].lazy = False
except Exception as e:
e.args = (
"The following error happened while" " compiling the node",
node,
"\n",
) + e.args
raise
except Exception:
raise_with_op(fgraph, node)
t1 = time.time()
if self.profile:
......
......@@ -11,16 +11,26 @@ from aesara.compile.mode import Mode, get_mode
from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX
from aesara.link.vm import Loop, VMLinker
from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, VMLinker
from aesara.tensor.math import cosh, sin, tanh
from aesara.tensor.type import dvector, lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant
class SomeOp(Op):
def perform(self, node, inputs, outputs):
pass
def make_node(self, x):
return Apply(self, [x], [x.type()])
class TestCallbacks:
# Test the `VMLinker`'s callback argument, which can be useful for debugging.
......@@ -494,3 +504,58 @@ def test_VMLinker_make_vm_no_cvm():
f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, Loop)
def test_VMLinker_exception():
class BadOp(Op):
def perform(self, node, inputs, outputs):
pass
def make_node(self, x):
return Apply(self, [x], [x.type()])
def make_thunk(self, *args, **kwargs):
raise Exception("bad Op")
a = scalar()
linker = VMLinker(allow_gc=False, use_cloop=True)
z = BadOp()(a)
with pytest.raises(Exception, match=".*Apply node that caused the error.*"):
function([a], z, mode=Mode(optimizer=None, linker=linker))
def test_VM_exception():
class SomeVM(VM):
def __call__(self):
pass
a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)])
with pytest.raises(ValueError, match="`nodes` and `thunks`.*"):
SomeVM(fg, fg.apply_nodes, [], [])
def test_LoopGC_exception():
a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)])
# Create valid(ish) `VM` arguments
nodes = fg.toposort()
input_storage, output_storage, storage_map = map_storage(
fg, nodes, None, None, None
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = [
node.op.make_thunk(node, storage_map, compute_map, True) for node in nodes
]
with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"):
LoopGC(fg, fg.apply_nodes, thunks, [], [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论