提交 dc8686bb authored 作者: Frederic Bastien's avatar Frederic Bastien

Make the local_opt_alloc optimization handle float16 inputs, and test it

上级 79ccac56
...@@ -5717,11 +5717,13 @@ def local_opt_alloc(node): ...@@ -5717,11 +5717,13 @@ def local_opt_alloc(node):
val = val.reshape(1)[0] val = val.reshape(1)[0]
# check which type of op # check which type of op
size = T.mul(*shapes) size = T.mul(*shapes)
if input.dtype == "float32": if input.dtype in ["float16", "float32"]:
# shapes are ints and normally int64. # shapes are ints and normally int64.
# We don't want to have a float64 upcast here # We don't want to have a float64 upcast
# if input is a float32. # We don't want to downcast to float16
size = size.astype(input.dtype) # as we fear it could lose too much precision
# that will be amplified by the mul/pow below.
size = size.astype('float32')
if (node.op.axis is None or if (node.op.axis is None or
node.op.axis == tuple(range(input.ndim))): node.op.axis == tuple(range(input.ndim))):
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
......
...@@ -5558,9 +5558,11 @@ class T_local_sum_prod(unittest.TestCase): ...@@ -5558,9 +5558,11 @@ class T_local_sum_prod(unittest.TestCase):
class T_local_opt_alloc(unittest.TestCase): class T_local_opt_alloc(unittest.TestCase):
dtype = 'float32'
def test_sum_upcast(self): def test_sum_upcast(self):
s = theano.tensor.lscalar() s = theano.tensor.lscalar()
a = theano.tensor.alloc(np.asarray(5, dtype='float32'), s, s) a = theano.tensor.alloc(np.asarray(5, dtype=self.dtype), s, s)
orig = theano.config.warn_float64 orig = theano.config.warn_float64
theano.config.warn_float64 = "raise" theano.config.warn_float64 = "raise"
try: try:
...@@ -5571,7 +5573,7 @@ class T_local_opt_alloc(unittest.TestCase): ...@@ -5571,7 +5573,7 @@ class T_local_opt_alloc(unittest.TestCase):
def test_prod_upcast(self): def test_prod_upcast(self):
s = theano.tensor.lscalar() s = theano.tensor.lscalar()
a = theano.tensor.alloc(np.asarray(5, dtype='float32'), s, s) a = theano.tensor.alloc(np.asarray(5, dtype=self.dtype), s, s)
orig = theano.config.warn_float64 orig = theano.config.warn_float64
theano.config.warn_float64 = "raise" theano.config.warn_float64 = "raise"
try: try:
...@@ -5587,11 +5589,16 @@ class T_local_opt_alloc(unittest.TestCase): ...@@ -5587,11 +5589,16 @@ class T_local_opt_alloc(unittest.TestCase):
f = theano.function([s], a.sum()) f = theano.function([s], a.sum())
f(5) f(5)
# test with user specified dtype # test with user specified dtype
f = theano.function([s], a.sum(dtype='float32')) f = theano.function([s], a.sum(dtype=self.dtype))
f(5) f(5)
# test only 1 axis summed # test only 1 axis summed
f = theano.function([s], a.sum(axis=0, dtype='float32')) f = theano.function([s], a.sum(axis=0, dtype=self.dtype))
f(5) f(5)
print(self.dtype)
class T_local_opt_alloc_f16(T_local_opt_alloc):
    # Inherits every test method from T_local_opt_alloc unchanged; only the
    # dtype used to build the alloc'ed input (read as `self.dtype` by the
    # inherited tests) is overridden, so the same checks run on float16.
    dtype = 'float16'
class T_local_reduce(unittest.TestCase): class T_local_reduce(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论