提交 6aa4a034 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for custom message in Assert.

上级 4cf06d2b
......@@ -1363,6 +1363,9 @@ class Assert(T.Op):
"""
view_map = {0: [0]}
def __init__(self, msg="Theano Assert failed!"):
self.msg = msg
def make_node(self, value, *conds):
cond = [T.as_tensor_variable(c) for c in conds]
assert numpy.all([c.type.ndim == 0 for c in cond])
......@@ -1375,13 +1378,13 @@ class Assert(T.Op):
out, = out_
v = inputs[0]
out[0] = v
assert numpy.all(inputs[1:])
assert numpy.all(inputs[1:]), self.msg
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.msg == other.msg
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients):
return output_gradients
......@@ -1391,12 +1394,13 @@ class Assert(T.Op):
out = onames[0]
check = []
fail = sub['fail']
msg = self.msg.replace('"', '\\"').replace('\n', '\\n')
for idx in xrange(len(inames) - 1):
i = inames[idx + 1]
dtype = node.inputs[idx + 1].dtype
check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])'
'{PyErr_SetString(PyExc_AssertionError,"Theano'
' Assert failed!");%(fail)s}' % locals())
'{PyErr_SetString(PyExc_AssertionError,"%(msg)s");'
'%(fail)s}' % locals())
check = "\n".join(check)
return """
%(check)s
......@@ -1405,7 +1409,7 @@ class Assert(T.Op):
""" % locals()
def c_code_cache_version(self):
return (0, 1)
return (1, 0)
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论