提交 55b96ed6 authored 作者: Frederic Bastien's avatar Frederic Bastien

added tensor.round fct and associated test.

上级 025fb494
...@@ -398,6 +398,11 @@ def upcast_out(*types): ...@@ -398,6 +398,11 @@ def upcast_out(*types):
return Scalar(dtype = Scalar.upcast(*types)), return Scalar(dtype = Scalar.upcast(*types)),
def same_out(type): def same_out(type):
return type, return type,
def same_out_float_only(type):
if type not in float_types:
raise TypeError('only float type are supported')
return type,
class transfer_type(gof.utils.object2): class transfer_type(gof.utils.object2):
def __init__(self, *transfer): def __init__(self, *transfer):
assert all(type(x) == int for x in transfer) assert all(type(x) == int for x in transfer)
...@@ -1147,6 +1152,18 @@ class IRound(UnaryScalarOp): ...@@ -1147,6 +1152,18 @@ class IRound(UnaryScalarOp):
return "%(z)s = round(%(x)s);" % locals() return "%(z)s = round(%(x)s);" % locals()
iround = IRound(int_out_nocomplex) iround = IRound(int_out_nocomplex)
class Round(UnaryScalarOp):
def impl(self, x):
return theano._asarray(numpy.round(x), dtype = 'int64')
def c_code(self, node, name, (x, ), (z, ), sub):
if node.outputs[0].type.dtype == 'float32':
return "%(z)s = fround(%(x)s);" % locals()
elif node.outputs[0].type.dtype == 'float64':
return "%(z)s = round(%(x)s);" % locals()
else:
Exception("The output should be float32 or float64")
round = Round(same_out_float_only)
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
def impl(self, x): def impl(self, x):
return -x return -x
......
...@@ -1578,6 +1578,10 @@ def floor(a): ...@@ -1578,6 +1578,10 @@ def floor(a):
def iround(a): def iround(a):
"""int(round(a))""" """int(round(a))"""
@_scal_elemwise
def round(a):
"""round(a)"""
@_scal_elemwise @_scal_elemwise
def sqr(a): def sqr(a):
"""square of a""" """square of a"""
......
...@@ -128,6 +128,10 @@ def floor_inplace(a): ...@@ -128,6 +128,10 @@ def floor_inplace(a):
def iround_inplace(a): def iround_inplace(a):
"""int(round(a)) (inplace on `a`)""" """int(round(a)) (inplace on `a`)"""
@_scal_inplace
def round_inplace(a):
"""round(a) (inplace on `a`)"""
@_scal_inplace @_scal_inplace
def sqr_inplace(a): def sqr_inplace(a):
"""square of `a` (inplace on `a`)""" """square of `a` (inplace on `a`)"""
......
...@@ -448,6 +448,15 @@ IRoundInplaceTester = makeBroadcastTester(op = inplace.iround_inplace, ...@@ -448,6 +448,15 @@ IRoundInplaceTester = makeBroadcastTester(op = inplace.iround_inplace,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
inplace = True) inplace = True)
RoundTester = makeBroadcastTester(op = round,
expected = numpy.round,
good = _good_broadcast_unary_normal_float)
RoundInplaceTester = makeBroadcastTester(op = inplace.round_inplace,
expected = numpy.round,
good = _good_broadcast_unary_normal_float,
inplace = True)
SqrTester = makeBroadcastTester(op = sqr, SqrTester = makeBroadcastTester(op = sqr,
expected = numpy.square, expected = numpy.square,
good = _good_broadcast_unary_normal, good = _good_broadcast_unary_normal,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论