提交 b73adff1 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Spread the good word and add a connection_pattern.

上级 b20cca75
......@@ -1564,6 +1564,7 @@ class Assert(T.Op):
used in the function computing the graph, but it doesn't have to be
returned.
"""
__props__ = ('msg',)
view_map = {0: [0]}
check_input = False
......@@ -1583,24 +1584,18 @@ class Assert(T.Op):
assert numpy.all([c.type.ndim == 0 for c in cond])
return gof.Apply(self, [value] + cond, [value.type()])
def __str__(self):
return self.__class__.__name__
def perform(self, node, inputs, out_):
out, = out_
v = inputs[0]
out[0] = v
assert numpy.all(inputs[1:]), self.msg
def __eq__(self, other):
return type(self) == type(other) and self.msg == other.msg
def __hash__(self):
return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients):
return output_gradients + [DisconnectedType()()] * (len(input) - 1)
def connection_pattern(self, node):
return [[1]] + [[0]] * (len(node.inputs) - 1)
def c_code(self, node, name, inames, onames, sub):
value = inames[0]
out = onames[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论