提交 d58d482a authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraph methods add_output, remove_node, remove_input, remove_output

上级 124ed5df
差异被折叠。
差异被折叠。
......@@ -46,19 +46,20 @@ def MyVariable2(name):
class MyOp(Op):
def __init__(self, name, dmap=None, x=None):
def __init__(self, name, dmap=None, x=None, n_outs=1):
self.name = name
if dmap is None:
dmap = {}
self.destroy_map = dmap
self.x = x
self.n_outs = n_outs
def make_node(self, *inputs):
inputs = list(map(is_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType()()]
outputs = [MyType()() for i in range(self.n_outs)]
return Apply(self, inputs, outputs)
def perform(self, node, inputs, outputs):
......@@ -71,18 +72,19 @@ class MyOp(Op):
return self.name
def __eq__(self, other):
# rval = (self is other) or (isinstance(other, MyOp) and self.x is not None and self.x == other.x and self.name == other.name)
rval = (self is other) or (
isinstance(other, MyOp) and self.x is not None and self.x == other.x
isinstance(other, MyOp)
and self.x is not None
and self.x == other.x
and self.n_outs == other.n_outs
)
return rval
def __hash__(self):
# return hash(self.x if self.x is not None else id(self)) ^ hash(self.name)
if self.x is not None:
return hash(self.x)
return hash((self.x, self.n_outs))
else:
return id(self)
return hash((id(self), self.n_outs))
class MyOpCastType2(MyOp):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论