提交 ece8b25b authored 作者: Adam Becker's avatar Adam Becker

make out_dtype param of make_node

上级 ea15e371
......@@ -287,35 +287,34 @@ class ArgTopKOp(theano.Op):
"""
__props__ = ('out_dtype', 'axis')
__props__ = ('axis',)
def __init__(self, axis=-1, out_dtype='int64'):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
def __init__(self, axis=-1):
assert isinstance(axis, int)
self.out_dtype = out_dtype
self.axis = axis
def __str__(self):
return '%(op)s{axis=%(axis)d, dtype=%(dtype)s}' % dict(
op=self.__class__.__name__, dtype=self.out_dtype, axis=self.axis)
return '%(op)s{axis=%(axis)d}' % dict(
op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k):
def make_node(self, inp, k, out_dtype='int64'):
inp = theano.tensor.as_tensor_variable(inp)
k = theano.tensor.as_tensor_variable(k)
bcast = inp.type.broadcastable
return theano.Apply(self, [inp, k], [
theano.tensor.TensorType(
dtype=self.out_dtype,
dtype=out_dtype,
broadcastable=bcast)()])
def perform(self, node, inputs, output_storage):
x, k = inputs
pz = output_storage[0]
print("Op's axis: %d" % self.axis)
pz[0] = _argtopk_py_impl(x, k, self.axis, self.out_dtype)
pz[0] = _argtopk_py_impl(x, k, self.axis, node.outputs[0].dtype)
def infer_shape(self, node, inp_shapes):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
_check_tensor_is_scalar(node.inputs[1])
shp = list(inp_shapes[0])
if not isinstance(self.axis, int):
......@@ -328,7 +327,7 @@ class ArgTopKOp(theano.Op):
raise IndexError(
'axis parameter out of range,'
' expected integer within [%d, %d]' % (-ndim, ndim - 1))
shp[self.axis] = node.inputs[1]
shp[self.axis] = np.abs(node.inputs[1])
return [tuple(shp)]
......@@ -359,4 +358,4 @@ def argtopk(x, k, axis=-1, out_dtype='int64'):
if axis is None:
x = theano.tensor.flatten(x)
axis = -1
return ArgTopKOp(axis=axis, out_dtype=out_dtype)(x, k)
return ArgTopKOp(axis=axis)(x, k, out_dtype=out_dtype)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论