提交 f1acf82a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix prepare_node args.

上级 9a2dffdd
...@@ -1899,7 +1899,7 @@ class GpuDnnDropoutOp(DnnBase): ...@@ -1899,7 +1899,7 @@ class GpuDnnDropoutOp(DnnBase):
return Apply(self, [inp, descriptor, state], return Apply(self, [inp, descriptor, state],
[inp.type(), state.type(), gpudata_type()]) [inp.type(), state.type(), gpudata_type()])
def prepare_node(self, node, storage_map, compute_map): def prepare_node(self, node, storage_map, compute_map, impl):
assert self.inplace, "GpuDnnDropoutOp not inplace" assert self.inplace, "GpuDnnDropoutOp not inplace"
......
...@@ -1462,7 +1462,7 @@ class Argmax(Op): ...@@ -1462,7 +1462,7 @@ class Argmax(Op):
outputs = [tensor('int64', broadcastable, name='argmax')] outputs = [tensor('int64', broadcastable, name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def prepare_node(self, node, storage_map, compute_map): def prepare_node(self, node, storage_map, compute_map, impl):
if len(node.inputs) == 2: if len(node.inputs) == 2:
raise ValueError('You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format.') raise ValueError('You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format.')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论