提交 83e5f3ef authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix exceptions in aesara.link.vm

上级 75239913
...@@ -166,7 +166,8 @@ class VM: ...@@ -166,7 +166,8 @@ class VM:
def __init__(self, fgraph, nodes, thunks, pre_call_clear): def __init__(self, fgraph, nodes, thunks, pre_call_clear):
if len(nodes) != len(thunks): if len(nodes) != len(thunks):
raise ValueError() raise ValueError("`nodes` and `thunks` must be the same length")
self.fgraph = fgraph self.fgraph = fgraph
self.nodes = nodes self.nodes = nodes
self.thunks = thunks self.thunks = thunks
...@@ -188,7 +189,7 @@ class VM: ...@@ -188,7 +189,7 @@ class VM:
what exactly this means and how it is done. what exactly this means and how it is done.
""" """
raise NotImplementedError("override me") raise NotImplementedError()
def clear_storage(self): def clear_storage(self):
""" """
...@@ -199,7 +200,7 @@ class VM: ...@@ -199,7 +200,7 @@ class VM:
calls. calls.
""" """
raise NotImplementedError("override me") raise NotImplementedError()
def update_profile(self, profile): def update_profile(self, profile):
""" """
...@@ -282,7 +283,9 @@ class LoopGC(VM): ...@@ -282,7 +283,9 @@ class LoopGC(VM):
# Some other part of Aesara query that information # Some other part of Aesara query that information
self.allow_gc = True self.allow_gc = True
if not (len(nodes) == len(thunks) == len(post_thunk_clear)): if not (len(nodes) == len(thunks) == len(post_thunk_clear)):
raise ValueError() raise ValueError(
"`nodes`, `thunks` and `post_thunk_clear` are not the same lengths"
)
def __call__(self): def __call__(self):
if self.time_thunks: if self.time_thunks:
...@@ -1138,13 +1141,9 @@ class VMLinker(LocalLinker): ...@@ -1138,13 +1141,9 @@ class VMLinker(LocalLinker):
# So if they didn't specify that its lazy or not, it isn't. # So if they didn't specify that its lazy or not, it isn't.
# If this member isn't present, it will crash later. # If this member isn't present, it will crash later.
thunks[-1].lazy = False thunks[-1].lazy = False
except Exception as e: except Exception:
e.args = ( raise_with_op(fgraph, node)
"The following error happened while" " compiling the node",
node,
"\n",
) + e.args
raise
t1 = time.time() t1 = time.time()
if self.profile: if self.profile:
......
...@@ -11,16 +11,26 @@ from aesara.compile.mode import Mode, get_mode ...@@ -11,16 +11,26 @@ from aesara.compile.mode import Mode, get_mode
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.ifelse import ifelse from aesara.ifelse import ifelse
from aesara.link.c.basic import OpWiseCLinker from aesara.link.c.basic import OpWiseCLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.vm import Loop, VMLinker from aesara.link.utils import map_storage
from aesara.link.vm import VM, Loop, LoopGC, VMLinker
from aesara.tensor.math import cosh, sin, tanh from aesara.tensor.math import cosh, sin, tanh
from aesara.tensor.type import dvector, lscalar, scalar, scalars, vector, vectors from aesara.tensor.type import dvector, lscalar, scalar, scalars, vector, vectors
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant
class SomeOp(Op):
def perform(self, node, inputs, outputs):
pass
def make_node(self, x):
return Apply(self, [x], [x.type()])
class TestCallbacks: class TestCallbacks:
# Test the `VMLinker`'s callback argument, which can be useful for debugging. # Test the `VMLinker`'s callback argument, which can be useful for debugging.
...@@ -494,3 +504,58 @@ def test_VMLinker_make_vm_no_cvm(): ...@@ -494,3 +504,58 @@ def test_VMLinker_make_vm_no_cvm():
f = function([a], a, mode=Mode(optimizer=None, linker=linker)) f = function([a], a, mode=Mode(optimizer=None, linker=linker))
assert isinstance(f.fn, Loop) assert isinstance(f.fn, Loop)
def test_VMLinker_exception():
class BadOp(Op):
def perform(self, node, inputs, outputs):
pass
def make_node(self, x):
return Apply(self, [x], [x.type()])
def make_thunk(self, *args, **kwargs):
raise Exception("bad Op")
a = scalar()
linker = VMLinker(allow_gc=False, use_cloop=True)
z = BadOp()(a)
with pytest.raises(Exception, match=".*Apply node that caused the error.*"):
function([a], z, mode=Mode(optimizer=None, linker=linker))
def test_VM_exception():
class SomeVM(VM):
def __call__(self):
pass
a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)])
with pytest.raises(ValueError, match="`nodes` and `thunks`.*"):
SomeVM(fg, fg.apply_nodes, [], [])
def test_LoopGC_exception():
a = scalar()
fg = FunctionGraph(outputs=[SomeOp()(a)])
# Create valid(ish) `VM` arguments
nodes = fg.toposort()
input_storage, output_storage, storage_map = map_storage(
fg, nodes, None, None, None
)
compute_map = {}
for k in storage_map:
compute_map[k] = [k.owner is None]
thunks = [
node.op.make_thunk(node, storage_map, compute_map, True) for node in nodes
]
with pytest.raises(ValueError, match="`nodes`, `thunks` and `post_thunk_clear`.*"):
LoopGC(fg, fg.apply_nodes, thunks, [], [])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论