提交 8c95b5bf authored 作者: Frederic Bastien's avatar Frederic Bastien

give name to output variable and a better string reprensentation for MaxAndArgMax

上级 d32556f8
......@@ -638,7 +638,8 @@ class TensorType(Type):
# Easy constructors
def tensor(*args, **kwargs):
return TensorType(*args, **kwargs).make_variable()
name = kwargs.get('name',None)
return TensorType(*args, **kwargs).make_variable(name=name)
def _multi(*fns):
def f2(f, *names):
......@@ -1307,8 +1308,8 @@ class MaxAndArgmax(Op):
axis = _as_tensor_variable(axis)
inputs = [x, axis]
broadcastable = [False] * (x.type.ndim - 1) #TODO: be less conservative
outputs = [tensor(x.type.dtype, broadcastable),
tensor('int32', broadcastable)]
outputs = [tensor(x.type.dtype, broadcastable,name='max'),
tensor('int32', broadcastable,name='argmax')]
return Apply(self, inputs, outputs)
def perform(self, node, (x, axis), (max, max_idx)):
max[0] = numpy.asarray(numpy.max(x, axis))
......@@ -1336,6 +1337,8 @@ class MaxAndArgmax(Op):
xmax_pad = shape_padright(xmax)
g_x = eq(xmax_pad, x) * g_max_pad
return g_x, None
def __str__(self):
return self.__class__.__name__
_max_and_argmax = MaxAndArgmax()
@_redefine_asRoutine(_max_and_argmax)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论