提交 2a9156a9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add an exception for old argmax nodes.

上级 7061380c
...@@ -1462,6 +1462,10 @@ class Argmax(Op): ...@@ -1462,6 +1462,10 @@ class Argmax(Op):
outputs = [tensor('int64', broadcastable, name='argmax')] outputs = [tensor('int64', broadcastable, name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def prepare_node(self, node):
if len(node.inputs) == 2:
raise ValueError('You are trying to compile a graph with an old Argmax node. Either reoptimize your graph or rebuild it to get the new node format.')
def perform(self, node, inp, outs, params): def perform(self, node, inp, outs, params):
x, = inp x, = inp
axes = self.axis axes = self.axis
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论