提交 56e5b9c0 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix an opt being skipped due to wrong dtype exchange. This remove opt warning in a keras test.

上级 46000c76
......@@ -5717,11 +5717,19 @@ def local_opt_alloc(node):
only_process_constants=True)
assert val.size == 1
# check which type of op
casted = T.mul(*shapes).astype(str(input.dtype))
size = T.mul(*shapes)
if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * casted
val = val.reshape(1)[0] * size
else:
val = val.reshape(1)[0] ** casted
val = val.reshape(1)[0] ** size
# Sum can change the input dtype (upcast or bool
# -> float32) by default or by user request.
# We can ignore the acc_dtype, as there is only 1
# elemwise we will do and not a sequence, so there is no
# accumulation of errors.
# So mostly, we just need to cast the output to the old
# dtype.
val = val.astype(node.outputs[0].dtype)
return [val]
except NotScalarConstantError:
......@@ -5735,11 +5743,12 @@ def local_opt_alloc(node):
to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis]
if to_prod:
casted = T.mul(*to_prod).astype(str(input.dtype))
size = T.mul(*to_prod)
if isinstance(node.op, T.Sum):
val *= casted
val *= size
else:
val = val ** casted
val = val ** size
val = val.astype(node.outputs[0].dtype)
return [T.alloc(val,
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
......
......@@ -5580,6 +5580,19 @@ class T_local_opt_alloc(unittest.TestCase):
finally:
theano.config.warn_float64 = orig
@theano.configparser.change_flags(on_opt_error='raise')
def test_sum_bool_upcast(self):
s = theano.tensor.lscalar()
a = theano.tensor.alloc(np.asarray(True, dtype='bool'), s, s)
f = theano.function([s], a.sum())
f(5)
# test with user specified dtype
f = theano.function([s], a.sum(dtype='float32'))
f(5)
# test only 1 axis summed
f = theano.function([s], a.sum(axis=0, dtype='float32'))
f(5)
class T_local_reduce(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论