提交 14c7373e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix attribute error in Cython version of Scan's exception handling

上级 17ba075e
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -64,7 +64,7 @@ from aesara.link.utils import raise_with_op ...@@ -64,7 +64,7 @@ from aesara.link.utils import raise_with_op
def get_version(): def get_version():
return 0.298 return 0.299
@cython.boundscheck(False) @cython.boundscheck(False)
def perform( def perform(
...@@ -153,9 +153,7 @@ def perform( ...@@ -153,9 +153,7 @@ def perform(
starting point of implementing this function in C ( we need to take starting point of implementing this function in C ( we need to take
all the code around the call of this function and put in C inside all the code around the call of this function and put in C inside
that code) that code)
fnct: python object fnct: Function
Only used to attach some timings for the profile mode ( can be
skiped if we don't care about Aesara's profile mode)
destroy_map destroy_map
Array of boolean saying if an output is computed inplace Array of boolean saying if an output is computed inplace
args: list of ndarrays (and random states) args: list of ndarrays (and random states)
...@@ -404,7 +402,7 @@ def perform( ...@@ -404,7 +402,7 @@ def perform(
# done by raise_with_op is not implemented in C. # done by raise_with_op is not implemented in C.
if hasattr(fn, 'thunks'): if hasattr(fn, 'thunks'):
# For the CVM # For the CVM
raise_with_op(fn.maker.fgraph, raise_with_op(fnct.maker.fgraph,
fn.nodes[fn.position_of_error], fn.nodes[fn.position_of_error],
fn.thunks[fn.position_of_error]) fn.thunks[fn.position_of_error])
else: else:
...@@ -412,7 +410,7 @@ def perform( ...@@ -412,7 +410,7 @@ def perform(
# We don't have access from python to all the # We don't have access from python to all the
# temps values So for now, we just don't print # temps values So for now, we just don't print
# the extra shapes/strides info # the extra shapes/strides info
raise_with_op(fn.maker.fgraph, fn.nodes[fn.position_of_error]) raise_with_op(fnct.maker.fgraph, fn.nodes[fn.position_of_error])
else: else:
# old-style linkers raise their own exceptions # old-style linkers raise their own exceptions
raise raise
...@@ -612,11 +610,11 @@ def perform( ...@@ -612,11 +610,11 @@ def perform(
# do not get applied # do not get applied
if i < n_steps: if i < n_steps:
# Cython can not handle negative indices ( because of a # Cython can not handle negative indices ( because of a
# derictive at the beginning of the function that says not # directive at the beginning of the function that says not
# to do boundschecks). The directive is used to make the # to do boundschecks). The directive is used to make the
# code faster, so this workaround is better then removing # code faster, so this workaround is better then removing
# the directive. # the directive.
sh0 = outs[idx][0].shape[0] sh0 = outs[idx][0].shape[0]
outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)] outs[idx][0] = outs[idx][0][:sh0-(n_steps - i)]
...@@ -639,15 +637,5 @@ def perform( ...@@ -639,15 +637,5 @@ def perform(
if hasattr(fn, 'update_profile'): if hasattr(fn, 'update_profile'):
fn.update_profile(profile) fn.update_profile(profile)
### Old Profile Mode
#if hasattr(fnct.maker.mode,'fct_call_time'):
# fnct.maker.mode.fct_call_time[fnct] += t_fn
# fnct.maker.mode.fct_call[fnct] += n_steps
#fnct.maker.mode.call_time += t_fn
#fnct.maker.mode.fn_time += t_fn
# DEBUG PRINT :
self.t_call = t_call self.t_call = t_call
self.t_fn = t_fn self.t_fn = t_fn
# print 'Cython > timing', t_call, t_fn, 'in percentage', 100.*t_fn/t_call
...@@ -21,7 +21,7 @@ if not config.cxx: ...@@ -21,7 +21,7 @@ if not config.cxx:
_logger = logging.getLogger("aesara.scan.scan_perform") _logger = logging.getLogger("aesara.scan.scan_perform")
version = 0.298 # must match constant returned in function get_version() version = 0.299 # must match constant returned in function get_version()
need_reload = False need_reload = False
......
...@@ -37,8 +37,9 @@ from aesara.gradient import ( ...@@ -37,8 +37,9 @@ from aesara.gradient import (
hessian, hessian,
jacobian, jacobian,
) )
from aesara.graph.basic import clone_replace, graph_inputs from aesara.graph.basic import Apply, clone_replace, graph_inputs
from aesara.graph.fg import MissingInputError from aesara.graph.fg import MissingInputError
from aesara.graph.op import Op
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scan.basic import scan from aesara.scan.basic import scan
from aesara.scan.op import Scan from aesara.scan.op import Scan
...@@ -4519,6 +4520,32 @@ class TestScan: ...@@ -4519,6 +4520,32 @@ class TestScan:
assert detect_large_outputs.large_count == 3 assert detect_large_outputs.large_count == 3
@pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test."
)
def test_cvm_exception_handling():
class MyOp(Op):
def make_node(self, input):
return Apply(self, [input], [vector()])
def perform(self, node, inputs, outputs):
raise Exception("blah")
myop = MyOp()
def scan_fn():
return myop(aet.as_tensor(1))
mode = Mode(optimizer=None, linker="cvm")
res, _ = scan(scan_fn, n_steps=4, mode=mode)
res_fn = function([], res, mode=mode)
with pytest.raises(Exception, match="blah"):
res_fn()
@pytest.mark.skipif( @pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test." not config.cxx, reason="G++ not available, so we need to skip this test."
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论