提交 ece8b25b authored 作者: Adam Becker's avatar Adam Becker

make out_dtype param of make_node

上级 ea15e371
...@@ -287,35 +287,34 @@ class ArgTopKOp(theano.Op): ...@@ -287,35 +287,34 @@ class ArgTopKOp(theano.Op):
""" """
__props__ = ('out_dtype', 'axis') __props__ = ('axis',)
def __init__(self, axis=-1, out_dtype='int64'): def __init__(self, axis=-1):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
assert isinstance(axis, int) assert isinstance(axis, int)
self.out_dtype = out_dtype
self.axis = axis self.axis = axis
def __str__(self): def __str__(self):
return '%(op)s{axis=%(axis)d, dtype=%(dtype)s}' % dict( return '%(op)s{axis=%(axis)d}' % dict(
op=self.__class__.__name__, dtype=self.out_dtype, axis=self.axis) op=self.__class__.__name__, axis=self.axis)
def make_node(self, inp, k): def make_node(self, inp, k, out_dtype='int64'):
inp = theano.tensor.as_tensor_variable(inp) inp = theano.tensor.as_tensor_variable(inp)
k = theano.tensor.as_tensor_variable(k) k = theano.tensor.as_tensor_variable(k)
bcast = inp.type.broadcastable bcast = inp.type.broadcastable
return theano.Apply(self, [inp, k], [ return theano.Apply(self, [inp, k], [
theano.tensor.TensorType( theano.tensor.TensorType(
dtype=self.out_dtype, dtype=out_dtype,
broadcastable=bcast)()]) broadcastable=bcast)()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x, k = inputs x, k = inputs
pz = output_storage[0] pz = output_storage[0]
print("Op's axis: %d" % self.axis) print("Op's axis: %d" % self.axis)
pz[0] = _argtopk_py_impl(x, k, self.axis, self.out_dtype) pz[0] = _argtopk_py_impl(x, k, self.axis, node.outputs[0].dtype)
def infer_shape(self, node, inp_shapes): def infer_shape(self, node, inp_shapes):
# numpy always uses float64 as output dtype for arg*() routines
# however, we add this option as memory is more precious on gpu
_check_tensor_is_scalar(node.inputs[1]) _check_tensor_is_scalar(node.inputs[1])
shp = list(inp_shapes[0]) shp = list(inp_shapes[0])
if not isinstance(self.axis, int): if not isinstance(self.axis, int):
...@@ -328,7 +327,7 @@ class ArgTopKOp(theano.Op): ...@@ -328,7 +327,7 @@ class ArgTopKOp(theano.Op):
raise IndexError( raise IndexError(
'axis parameter out of range,' 'axis parameter out of range,'
' expected integer within [%d, %d]' % (-ndim, ndim - 1)) ' expected integer within [%d, %d]' % (-ndim, ndim - 1))
shp[self.axis] = node.inputs[1] shp[self.axis] = np.abs(node.inputs[1])
return [tuple(shp)] return [tuple(shp)]
...@@ -359,4 +358,4 @@ def argtopk(x, k, axis=-1, out_dtype='int64'): ...@@ -359,4 +358,4 @@ def argtopk(x, k, axis=-1, out_dtype='int64'):
if axis is None: if axis is None:
x = theano.tensor.flatten(x) x = theano.tensor.flatten(x)
axis = -1 axis = -1
return ArgTopKOp(axis=axis, out_dtype=out_dtype)(x, k) return ArgTopKOp(axis=axis)(x, k, out_dtype=out_dtype)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论