提交 8c95b5bf authored 作者: Frederic Bastien's avatar Frederic Bastien

give name to output variable and a better string reprensentation for MaxAndArgMax

上级 d32556f8
...@@ -638,7 +638,8 @@ class TensorType(Type): ...@@ -638,7 +638,8 @@ class TensorType(Type):
# Easy constructors # Easy constructors
def tensor(*args, **kwargs): def tensor(*args, **kwargs):
return TensorType(*args, **kwargs).make_variable() name = kwargs.get('name',None)
return TensorType(*args, **kwargs).make_variable(name=name)
def _multi(*fns): def _multi(*fns):
def f2(f, *names): def f2(f, *names):
...@@ -1307,8 +1308,8 @@ class MaxAndArgmax(Op): ...@@ -1307,8 +1308,8 @@ class MaxAndArgmax(Op):
axis = _as_tensor_variable(axis) axis = _as_tensor_variable(axis)
inputs = [x, axis] inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative
outputs = [tensor(x.type.dtype, broadcastable), outputs = [tensor(x.type.dtype, broadcastable,name='max'),
tensor('int32', broadcastable)] tensor('int32', broadcastable,name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, (x, axis), (max, max_idx)): def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.asarray(numpy.max(x, axis)) max[0] = numpy.asarray(numpy.max(x, axis))
...@@ -1336,6 +1337,8 @@ class MaxAndArgmax(Op): ...@@ -1336,6 +1337,8 @@ class MaxAndArgmax(Op):
xmax_pad = shape_padright(xmax) xmax_pad = shape_padright(xmax)
g_x = eq(xmax_pad, x) * g_max_pad g_x = eq(xmax_pad, x) * g_max_pad
return g_x, None return g_x, None
def __str__(self):
return self.__class__.__name__
_max_and_argmax = MaxAndArgmax() _max_and_argmax = MaxAndArgmax()
@_redefine_asRoutine(_max_and_argmax) @_redefine_asRoutine(_max_and_argmax)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论