提交 f9a2e45b authored 作者: Frederic's avatar Frederic 提交者: Arnaud Bergeron

Allow transfer type to hardcode the output dtype

上级 a579c1eb
......@@ -724,7 +724,7 @@ def same_out_float_only(type):
class transfer_type(gof.utils.object2):
def __init__(self, *transfer):
assert all(type(x) == int for x in transfer)
assert all(type(x) in [int, str] or x is None for x in transfer)
self.transfer = transfer
def __str__(self):
......@@ -736,6 +736,8 @@ class transfer_type(gof.utils.object2):
for i in self.transfer:
if i is None:
retval += [upcast]
elif isinstance(i, str):
retval += [i]
else:
retval += [types[i]]
return retval
......@@ -3410,7 +3412,10 @@ class Composite(ScalarOp):
return lambda inputs: r.data
node = r.owner
producers = [compose_impl(input) for input in node.inputs]
return lambda inputs: node.op.impl(*[p(inputs) for p in producers])
def f(inputs):
return node.op.impl(*[p(inputs) for p in producers])
return f
self._impls = [compose_impl(r) for r in self.fgraph.outputs]
def init_name(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论