提交 a75dd96e · 作者：Arnaud Bergeron

Fixes comments from review.

上级 bcd4a76b
...@@ -3831,16 +3831,17 @@ class Compositef32(object): ...@@ -3831,16 +3831,17 @@ class Compositef32(object):
for i in node.inputs: for i in node.inputs:
if i not in mapping: if i not in mapping:
assert type(i) is ScalarConstant assert type(i) is ScalarConstant
if i.type == float16:
i = ScalarConstant(float32, i.data)
mapping[i] = i mapping[i] = i
if type(node.op) in self.special: if type(node.op) in self.special:
self.special[type(node.op)](node, mapping) self.special[type(node.op)](node, mapping)
continue continue
# make sure we won't produce any float16.
assert not any(o.dtype == 'float16' for o in
node.op.output_types([mapping[i].type for i in node.inputs]))
new_node = node.clone_with_new_inputs( new_node = node.clone_with_new_inputs(
[mapping[i] for i in node.inputs], [mapping[i] for i in node.inputs],
strict=False) strict=False)
# make sure we don't produce any float16.
assert not any(o.dtype == 'float16' for o in new_node.outputs)
for o, no in zip(node.outputs, new_node.outputs): for o, no in zip(node.outputs, new_node.outputs):
mapping[o] = no mapping[o] = no
......
...@@ -25,7 +25,7 @@ from theano.scalar.basic import (floats, float16, float32, float64, ...@@ -25,7 +25,7 @@ from theano.scalar.basic import (floats, float16, float32, float64,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar, InRange, and_, eq, neq, invert, mul, Scalar, InRange,
cast) cast, constant)
from theano.scalar.basic import ( from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad, true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh, rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
...@@ -68,11 +68,8 @@ class test_ScalarOps(unittest.TestCase): ...@@ -68,11 +68,8 @@ class test_ScalarOps(unittest.TestCase):
def has_f16(comp): def has_f16(comp):
if any(i.type == float16 for i in comp.fgraph.inputs): if any(v.type == float16 for v in comp.fgraph.variables):
return True return True
for n in comp.fgraph.apply_nodes:
if any(o.type == float16 for o in n.outputs):
return True
return False return False
...@@ -83,8 +80,9 @@ class test_composite(unittest.TestCase): ...@@ -83,8 +80,9 @@ class test_composite(unittest.TestCase):
y = float32() y = float32()
cz = Composite([x, y], [tanh(x + cast(y, 'float16'))]) cz = Composite([x, y], [tanh(x + cast(y, 'float16'))])
c = Composite([w, x, y], [cz(x, y) - cz(x, y)**2 + c = Composite([w, x, y], [cz(x, y) - cz(x, y)**2 +
cast(x, 'int16') + cast(x, 'float32') + cast(x, 'int16') + cast(x, 'float32') +
cast(w, 'float16')]) cast(w, 'float16') -
constant(np.float16(1.0))])
assert has_f16(c) assert has_f16(c)
nc = c.clone_float32() nc = c.clone_float32()
assert not has_f16(nc) assert not has_f16(nc)
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论