Commit f35eefb8 authored by Arnaud Bergeron

Add method and helper class to convert the internal fgraph of a
composite to float32. Comes with a test.

Parent 173826a7
...@@ -23,8 +23,8 @@ from six.moves import xrange ...@@ -23,8 +23,8 @@ from six.moves import xrange
import theano import theano
from theano.compat import imap, izip from theano.compat import imap, izip
from theano import gof, printing from theano import gof, printing
from theano.gof import (Op, utils, Variable, Constant, Type, Apply, from theano.gof import (Op, Optimizer, utils, Variable, Constant, Type, Apply,
FunctionGraph) FunctionGraph, toolbox)
from functools import partial from functools import partial
from theano.configparser import config from theano.configparser import config
...@@ -3673,6 +3673,11 @@ class Composite(ScalarOp): ...@@ -3673,6 +3673,11 @@ class Composite(ScalarOp):
n.op.prepare_node(n, None, None, impl) n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl) self.prepare_node_called.add(impl)
def clone_float32(self):
    """Return a new Composite whose internal graph uses float32 in
    place of float16.

    The conversion is performed on copies; ``self.fgraph`` and its
    nodes are left untouched.
    """
    converted_ins, converted_outs = composite_f32.apply(self.fgraph)
    return Composite(converted_ins, converted_outs)
def output_types(self, input_types): def output_types(self, input_types):
if tuple(input_types) != self.inputs_type: if tuple(input_types) != self.inputs_type:
raise TypeError("Wrong types for Composite. Expected %s, got %s." raise TypeError("Wrong types for Composite. Expected %s, got %s."
...@@ -3807,3 +3812,79 @@ class Composite(ScalarOp): ...@@ -3807,3 +3812,79 @@ class Composite(ScalarOp):
self.init_fgraph() self.init_fgraph()
self.init_py_impls() self.init_py_impls()
assert self._c_code assert self._c_code
class Compositef32(object):
    """Rebuild a scalar FunctionGraph with float16 promoted to float32.

    Walks the graph in topological order, cloning each node with
    converted inputs, and returns the new input/output variables.
    The original graph is never modified.
    """
    # Registry of scalar op classes that need custom conversion
    # handling (populated at module level, keyed by op class).
    special = {}

    def apply(self, fgraph):
        """Return ``(new_inputs, new_outputs)`` for the converted graph.

        ``fgraph`` itself and its nodes are left untouched.
        """
        # Maps each original variable to its float16-free replacement.
        mapping = {}
        for inp in fgraph.inputs:
            if inp.dtype == 'float16':
                mapping[inp] = get_scalar_type('float32')()
            else:
                mapping[inp] = inp

        for node in fgraph.toposort():
            # Constants don't appear in fgraph.inputs; map them to
            # themselves on first sight.
            for inp in node.inputs:
                if inp not in mapping:
                    assert type(inp) is ScalarConstant
                    mapping[inp] = inp

            handler = self.special.get(type(node.op))
            if handler is not None:
                handler(node, mapping)
                continue

            new_inputs = [mapping[inp] for inp in node.inputs]
            # Make sure we won't produce any float16.
            assert not any(out.dtype == 'float16'
                           for out in node.op.output_types(new_inputs))
            cloned = node.clone_with_new_inputs(new_inputs, strict=False)
            for old_out, new_out in zip(node.outputs, cloned.outputs):
                mapping[old_out] = new_out

        return ([mapping[inp] for inp in fgraph.inputs],
                [mapping[out] for out in fgraph.outputs])
# Module-level singleton used by Composite.clone_float32 to run the
# float16 -> float32 graph conversion.
composite_f32 = Compositef32()
def handle_cast(node, mapping):
    """Conversion rule for Cast nodes during float16 -> float32 rewrite.

    ``mapping`` maps each original variable to its already-converted
    (float16-free) counterpart.  This records the converted output of
    ``node`` in ``mapping``, eliding casts that become redundant once
    float16 intermediates have been promoted to float32.
    """
    inp = mapping[node.inputs[0]]
    out = node.outputs[0]
    node_ok = False
    if node.op.o_type == float16:
        if node.inputs[0].type == float32:
            # cast f32 -> f16, remove
            mapping[out] = inp
            return
        else:
            # cast to f16, convert to f32
            new_out = cast(inp, 'float32')
            # change the node for the following if
            node = new_out.owner
            mapping[out] = new_out
            node_ok = True
    # NOTE: `node` may have been rebound above to the new f32 cast, in
    # which case its single input is `inp` rather than the original.
    if node.inputs[0].type == float16:
        if node.op.o_type == inp.type:
            # cast f16 to new input type, remove
            mapping[out] = inp
            return
    if not node_ok:
        # Generic case: re-issue the cast on the converted input.
        new_node = node.clone_with_new_inputs([inp],
                                              strict=False)
        mapping[out] = new_node.outputs[0]


# Register the custom rule so Compositef32.apply dispatches Cast here.
Compositef32.special[Cast] = handle_cast
def handle_composite(node, mapping):
    """Conversion rule for nested Composite nodes: convert the inner
    op with clone_float32 and re-apply it to the converted inputs,
    recording the new outputs in ``mapping``.
    """
    f32_op = node.op.clone_float32()
    converted_inputs = [mapping[inp] for inp in node.inputs]
    new_outputs = f32_op(*converted_inputs, return_list=True)
    assert len(new_outputs) == len(node.outputs)
    for old_out, new_out in zip(node.outputs, new_outputs):
        mapping[old_out] = new_out


# Register the custom rule so Compositef32.apply dispatches Composite here.
Compositef32.special[Composite] = handle_composite
...@@ -20,11 +20,12 @@ from theano.gof import FunctionGraph ...@@ -20,11 +20,12 @@ from theano.gof import FunctionGraph
from theano import gof from theano import gof
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.scalar.basic import (floats, float32, float64, from theano.scalar.basic import (floats, float16, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar, InRange) and_, eq, neq, invert, mul, Scalar, InRange,
cast)
from theano.scalar.basic import ( from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad, true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh, rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
...@@ -66,7 +67,27 @@ class test_ScalarOps(unittest.TestCase): ...@@ -66,7 +67,27 @@ class test_ScalarOps(unittest.TestCase):
self.assertTrue(fn(a, b) == a%b, (a,)) self.assertTrue(fn(a, b) == a%b, (a,))
def has_f16(comp):
    """Return True iff `comp`'s internal graph touches float16
    anywhere — either at an input or at any apply-node output.
    """
    if any(inp.type == float16 for inp in comp.fgraph.inputs):
        return True
    return any(out.type == float16
               for node in comp.fgraph.apply_nodes
               for out in node.outputs)
class test_composite(unittest.TestCase): class test_composite(unittest.TestCase):
def test_composite_clone_float32(self):
    # Build a Composite that uses float16 in several ways: a float16
    # input, a cast down to float16, a nested Composite (cz), and
    # casts from float16 to other dtypes.
    w = int8()
    x = float16()
    y = float32()
    cz = Composite([x, y], [tanh(x + cast(y, 'float16'))])
    c = Composite([w, x, y], [cz(x, y) - cz(x, y)**2 +
                              cast(x, 'int16') + cast(x, 'float32') +
                              cast(w, 'float16')])
    # Sanity check: the original graph really contains float16 ...
    assert has_f16(c)
    # ... and the converted clone contains none.
    nc = c.clone_float32()
    assert not has_f16(nc)
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论