提交 a75dd96e authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fixes comments from review.

上级 bcd4a76b
......@@ -3831,16 +3831,17 @@ class Compositef32(object):
for i in node.inputs:
if i not in mapping:
assert type(i) is ScalarConstant
if i.type == float16:
i = ScalarConstant(float32, i.data)
mapping[i] = i
if type(node.op) in self.special:
self.special[type(node.op)](node, mapping)
continue
# make sure we won't produce any float16.
assert not any(o.dtype == 'float16' for o in
node.op.output_types([mapping[i].type for i in node.inputs]))
new_node = node.clone_with_new_inputs(
[mapping[i] for i in node.inputs],
strict=False)
# make sure we don't produce any float16.
assert not any(o.dtype == 'float16' for o in new_node.outputs)
for o, no in zip(node.outputs, new_node.outputs):
mapping[o] = no
......
......@@ -25,7 +25,7 @@ from theano.scalar.basic import (floats, float16, float32, float64,
ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar, InRange,
cast)
cast, constant)
from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
......@@ -68,11 +68,8 @@ class test_ScalarOps(unittest.TestCase):
def has_f16(comp):
if any(i.type == float16 for i in comp.fgraph.inputs):
if any(v.type == float16 for v in comp.fgraph.variables):
return True
for n in comp.fgraph.apply_nodes:
if any(o.type == float16 for o in n.outputs):
return True
return False
......@@ -83,8 +80,9 @@ class test_composite(unittest.TestCase):
y = float32()
cz = Composite([x, y], [tanh(x + cast(y, 'float16'))])
c = Composite([w, x, y], [cz(x, y) - cz(x, y)**2 +
cast(x, 'int16') + cast(x, 'float32') +
cast(w, 'float16')])
cast(x, 'int16') + cast(x, 'float32') +
cast(w, 'float16') -
constant(np.float16(1.0))])
assert has_f16(c)
nc = c.clone_float32()
assert not has_f16(nc)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论