提交 14c7373e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix attribute error in Cython version of Scan's exception handling

上级 17ba075e
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -64,7 +64,7 @@ from aesara.link.utils import raise_with_op
def get_version():
return 0.298
return 0.299
@cython.boundscheck(False)
def perform(
......@@ -153,9 +153,7 @@ def perform(
starting point of implementing this function in C ( we need to take
all the code around the call of this function and put in C inside
that code)
fnct: python object
Only used to attach some timings for the profile mode ( can be
skiped if we don't care about Aesara's profile mode)
fnct: Function
destroy_map
Array of boolean saying if an output is computed inplace
args: list of ndarrays (and random states)
......@@ -404,7 +402,7 @@ def perform(
# done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'):
# For the CVM
raise_with_op(fn.maker.fgraph,
raise_with_op(fnct.maker.fgraph,
fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error])
else:
......@@ -412,7 +410,7 @@ def perform(
# We don't have access from python to all the
# temps values So for now, we just don't print
# the extra shapes/strides info
raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error])
raise_with_op(fnct.maker.fgraph, fn.nodes[fn.position_of_error])
else:
# old-style linkers raise their own exceptions
raise
......@@ -612,11 +610,11 @@ def perform(
# do not get applied
if i < n_steps:
# Cython can not handle negative indices ( because of a
# derictive at the beginning of the function that says not
# to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing
# the directive.
# Cython can not handle negative indices ( because of a
# directive at the beginning of the function that says not
# to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing
# the directive.
sh0 = outs[idx][0].shape[0]
outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)]
......@@ -639,15 +637,5 @@ def perform(
if hasattr(fn, 'update_profile'):
fn.update_profile(profile)
### Old Profile Mode
#if hasattr(fnct.maker.mode,'fct_call_time'):
# fnct.maker.mode.fct_call_time[fnct] += t_fn
# fnct.maker.mode.fct_call[fnct] += n_steps
#fnct.maker.mode.call_time += t_fn
#fnct.maker.mode.fn_time += t_fn
# DEBUG PRINT :
self.t_call = t_call
self.t_fn = t_fn
# print 'Cython > timing', t_call, t_fn, 'in percentage', 100.*t_fn/t_call
......@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.298 # must match constant returned in function get_version()
version = 0.299 # must match constant returned in function get_version()
need_reload = False
......
......@@ -37,8 +37,9 @@ from aesara.gradient import (
hessian,
jacobian,
)
from aesara.graph.basic import clone_replace, graph_inputs
from aesara.graph.basic import Apply, clone_replace, graph_inputs
from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray
from aesara.scan.basic import scan
from aesara.scan.op import Scan
......@@ -4519,6 +4520,32 @@ class TestScan:
assert detect_large_outputs.large_count == 3
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cvm_exception_handling():
class MyOp(Op):
def make_node(self, input):
return Apply(self, [input], [vector()])
def perform(self, node, inputs, outputs):
raise Exception("blah")
myop = MyOp()
def scan_fn():
return myop(aet.as_tensor(1))
mode = Mode(optimizer=None, linker="cvm")
res, _ = scan(scan_fn, n_steps=4, mode=mode)
res_fn = function([], res, mode=mode)
with pytest.raises(Exception, match="blah"):
res_fn()
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论