提交 6aa4a034 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for custom message in Assert.

上级 4cf06d2b
...@@ -1363,6 +1363,9 @@ class Assert(T.Op): ...@@ -1363,6 +1363,9 @@ class Assert(T.Op):
""" """
view_map = {0: [0]} view_map = {0: [0]}
def __init__(self, msg="Theano Assert failed!"):
self.msg = msg
def make_node(self, value, *conds): def make_node(self, value, *conds):
cond = [T.as_tensor_variable(c) for c in conds] cond = [T.as_tensor_variable(c) for c in conds]
assert numpy.all([c.type.ndim == 0 for c in cond]) assert numpy.all([c.type.ndim == 0 for c in cond])
...@@ -1375,13 +1378,13 @@ class Assert(T.Op): ...@@ -1375,13 +1378,13 @@ class Assert(T.Op):
out, = out_ out, = out_
v = inputs[0] v = inputs[0]
out[0] = v out[0] = v
assert numpy.all(inputs[1:]) assert numpy.all(inputs[1:]), self.msg
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other) and self.msg == other.msg
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients): def grad(self, input, output_gradients):
return output_gradients return output_gradients
...@@ -1391,12 +1394,13 @@ class Assert(T.Op): ...@@ -1391,12 +1394,13 @@ class Assert(T.Op):
out = onames[0] out = onames[0]
check = [] check = []
fail = sub['fail'] fail = sub['fail']
msg = self.msg.replace('"', '\\"').replace('\n', '\\n')
for idx in xrange(len(inames) - 1): for idx in xrange(len(inames) - 1):
i = inames[idx + 1] i = inames[idx + 1]
dtype = node.inputs[idx + 1].dtype dtype = node.inputs[idx + 1].dtype
check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])' check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])'
'{PyErr_SetString(PyExc_AssertionError,"Theano' '{PyErr_SetString(PyExc_AssertionError,"%(msg)s");'
' Assert failed!");%(fail)s}' % locals()) '%(fail)s}' % locals())
check = "\n".join(check) check = "\n".join(check)
return """ return """
%(check)s %(check)s
...@@ -1405,7 +1409,7 @@ class Assert(T.Op): ...@@ -1405,7 +1409,7 @@ class Assert(T.Op):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 1) return (1, 0)
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [input_shapes[0]] return [input_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论