提交 46fbfeb6 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4450 from nouiz/deepworx-pr1

Remove upcast to float64 in the graph.
......@@ -5210,11 +5210,12 @@ def local_opt_alloc(node):
val = get_scalar_constant_value(input)
assert val.size == 1
# check which type of op
casted = T.mul(*shapes).astype(str(input.dtype))
if isinstance(node.op, T.Sum):
val = val.reshape(1)[0] * T.mul(*shapes)
val = val.reshape(1)[0] * casted
else:
val = val.reshape(1)[0] ** T.mul(*shapes)
return [T.cast(val, dtype=node.outputs[0].dtype)]
val = val.reshape(1)[0] ** casted
return [val]
except NotScalarConstantError:
pass
......@@ -5226,11 +5227,12 @@ def local_opt_alloc(node):
to_prod = [shapes[i] for i in xrange(len(shapes))
if i in node.op.axis]
if to_prod:
casted = T.mul(*to_prod).astype(str(input.dtype))
if isinstance(node.op, T.Sum):
val *= T.mul(*to_prod)
val *= casted
else:
val = val ** T.mul(*to_prod)
return [T.alloc(T.cast(val, dtype=node.outputs[0].dtype),
val = val ** casted
return [T.alloc(val,
*[shapes[i] for i in xrange(len(shapes))
if i not in node.op.axis])]
except NotScalarConstantError:
......
......@@ -5271,8 +5271,8 @@ class T_local_sum_prod(unittest.TestCase):
assert numpy.allclose(f(input), input.sum())
assert len(f.maker.fgraph.apply_nodes) == 1
def test_local_sum_prod_alloc(self):
# test local_opt_alloc
a = T.dtensor3()
input = numpy.asarray(numpy.arange(2 * 3 * 4).reshape(2, 3, 4),
dtype='float64')
......@@ -5376,6 +5376,30 @@ class T_local_sum_prod(unittest.TestCase):
config.on_opt_error = backup
class T_local_opt_alloc(unittest.TestCase):
def test_sum_upcast(self):
s = theano.tensor.lscalar()
a = theano.tensor.alloc(numpy.asarray(5, dtype='float32'), s, s)
orig = theano.config.warn_float64
theano.config.warn_float64 = "raise"
try:
f = theano.function([s], a.sum())
f(5)
finally:
theano.config.warn_float64 = orig
def test_prod_upcast(self):
s = theano.tensor.lscalar()
a = theano.tensor.alloc(numpy.asarray(5, dtype='float32'), s, s)
orig = theano.config.warn_float64
theano.config.warn_float64 = "raise"
try:
f = theano.function([s], a.prod())
f(5)
finally:
theano.config.warn_float64 = orig
class T_local_reduce(unittest.TestCase):
def setUp(self):
self.mode = theano.compile.get_default_mode().including(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论