提交 fe1083e8 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5077 from abergeron/opt_composite_f32

Add a method to properly convert the graph of Composite to float32.
...@@ -6,7 +6,7 @@ import theano ...@@ -6,7 +6,7 @@ import theano
from theano import Apply, scalar, config, Op from theano import Apply, scalar, config, Op
from six.moves import StringIO, xrange from six.moves import StringIO, xrange
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.scalar import Scalar from theano.scalar import Scalar, Composite
from theano.tensor.elemwise import (Elemwise, DimShuffle, CAReduceDtype) from theano.tensor.elemwise import (Elemwise, DimShuffle, CAReduceDtype)
try: try:
...@@ -49,6 +49,11 @@ class GpuElemwise(HideC, Elemwise): ...@@ -49,6 +49,11 @@ class GpuElemwise(HideC, Elemwise):
nout = property(lambda self: self.scalar_op.nout) nout = property(lambda self: self.scalar_op.nout)
_f16_ok = True _f16_ok = True
def __init__(self, scalar_op, *args, **kwargs):
    """Build a GPU elemwise, sanitizing Composite scalar ops first.

    A ``Composite`` scalar op may carry float16 variables in its internal
    graph; those are converted up front via ``clone_float32`` so the GPU
    kernel never sees float16 internals.  All other arguments are passed
    through to ``Elemwise.__init__`` unchanged.
    """
    op = scalar_op.clone_float32() if isinstance(scalar_op, Composite) else scalar_op
    Elemwise.__init__(self, op, *args, **kwargs)
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
...@@ -86,15 +91,7 @@ class GpuElemwise(HideC, Elemwise): ...@@ -86,15 +91,7 @@ class GpuElemwise(HideC, Elemwise):
except MethodNotDefined: except MethodNotDefined:
pass pass
if fake_node.op != self.scalar_op: node = Apply(self, inputs, outputs)
# If the new op is different due to type changes, we make a new
# op for it.
elem = GpuElemwise(fake_node.op, self.inplace_pattern, self.name,
self.nfunc_spec, self.openmp)
else:
elem = self
node = Apply(elem, inputs, outputs)
return node return node
def get_params(self, node): def get_params(self, node):
...@@ -119,8 +116,7 @@ class GpuElemwise(HideC, Elemwise): ...@@ -119,8 +116,7 @@ class GpuElemwise(HideC, Elemwise):
inps, outs, inps, outs,
dict(fail='return;')) dict(fail='return;'))
# Some ops like cast will reintroduce float16 in the internal graph. assert 'npy_float16' not in kop
kop = kop.replace('npy_float16', 'ga_float')
support_code = "" support_code = ""
try: try:
......
...@@ -3852,6 +3852,11 @@ class Composite(ScalarOp): ...@@ -3852,6 +3852,11 @@ class Composite(ScalarOp):
n.op.prepare_node(n, None, None, impl) n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl) self.prepare_node_called.add(impl)
def clone_float32(self):
    """Return a new ``Composite`` whose graph uses float32 in place of
    any float16 variables.

    The conversion is done by ``composite_f32.apply``; this op's own
    fgraph and nodes are left unmodified.
    """
    inputs, outputs = composite_f32.apply(self.fgraph)
    return Composite(inputs, outputs)
def output_types(self, input_types): def output_types(self, input_types):
if tuple(input_types) != self.inputs_type: if tuple(input_types) != self.inputs_type:
raise TypeError("Wrong types for Composite. Expected %s, got %s." raise TypeError("Wrong types for Composite. Expected %s, got %s."
...@@ -3986,3 +3991,82 @@ class Composite(ScalarOp): ...@@ -3986,3 +3991,82 @@ class Composite(ScalarOp):
self.init_fgraph() self.init_fgraph()
self.init_py_impls() self.init_py_impls()
assert self._c_code assert self._c_code
class Compositef32(object):
# This is a dict of scalar op classes that need special handling
special = {}
def apply(self, fgraph):
mapping = {}
topo = fgraph.toposort()
for i in fgraph.inputs:
if i.dtype == 'float16':
mapping[i] = get_scalar_type('float32')()
else:
mapping[i] = i
for node in topo:
# Patch up for constants
for i in node.inputs:
if i not in mapping:
assert type(i) is ScalarConstant
if i.type == float16:
ni = ScalarConstant(float32, i.data)
else:
ni = i
mapping[i] = ni
if type(node.op) in self.special:
self.special[type(node.op)](node, mapping)
continue
new_node = node.clone_with_new_inputs(
[mapping[inp] for inp in node.inputs],
strict=False)
# make sure we don't produce any float16.
assert not any(o.dtype == 'float16' for o in new_node.outputs)
for o, no in zip(node.outputs, new_node.outputs):
mapping[o] = no
new_ins = [mapping[inp] for inp in fgraph.inputs]
new_outs = [mapping[out] for out in fgraph.outputs]
return new_ins, new_outs
composite_f32 = Compositef32()
def handle_cast(node, mapping):
    """Compositef32 handler for ``Cast`` nodes.

    ``mapping`` maps original variables to their float32 replacements.
    This records a replacement for the cast's single output, removing
    casts that only exist because of float16 and rewriting casts *to*
    float16 as casts to float32.
    """
    # The (already converted) replacement for the cast's input.
    inp = mapping[node.inputs[0]]
    out = node.outputs[0]
    # Set when `node` has already been replaced below and must not be
    # cloned again at the end.
    node_ok = False
    if node.op.o_type == float16:
        if node.inputs[0].type == float32:
            # cast f32 -> f16, remove
            mapping[out] = inp
            return
        else:
            # cast to f16, convert to f32
            new_out = cast(inp, 'float32')
            # change the node for the following if
            node = new_out.owner
            mapping[out] = new_out
            node_ok = True
    if node.inputs[0].type == float16:
        # NOTE: only reachable when `node` is still the original node,
        # since the replacement cast built above takes a non-f16 input.
        if node.op.o_type == inp.type:
            # cast f16 to new input type, remove
            mapping[out] = inp
            return
    if not node_ok:
        # Generic case: same cast, applied to the converted input.
        new_node = node.clone_with_new_inputs([inp],
                                              strict=False)
        mapping[out] = new_node.outputs[0]
Compositef32.special[Cast] = handle_cast
def handle_composite(node, mapping):
    """Compositef32 handler for nested ``Composite`` nodes.

    Clones the inner op to float32 and re-applies it to the mapped
    (converted) inputs, recording replacements for every output.
    """
    f32_op = node.op.clone_float32()
    mapped_inputs = [mapping[inp] for inp in node.inputs]
    new_outputs = f32_op(*mapped_inputs, return_list=True)
    assert len(new_outputs) == len(node.outputs)
    mapping.update(zip(node.outputs, new_outputs))
Compositef32.special[Composite] = handle_composite
...@@ -20,11 +20,12 @@ from theano.gof import FunctionGraph ...@@ -20,11 +20,12 @@ from theano.gof import FunctionGraph
from theano import gof from theano import gof
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.scalar.basic import (floats, float32, float64, from theano.scalar.basic import (floats, float16, float32, float64,
ints, int8, int32, complex64, ints, int8, int32, complex64,
ComplexError, IntDiv, TrueDiv, ComplexError, IntDiv, TrueDiv,
Composite, add, div_proxy, Composite, add, div_proxy,
and_, eq, neq, invert, mul, Scalar, InRange) and_, eq, neq, invert, mul, Scalar, InRange,
cast, constant)
from theano.scalar.basic import ( from theano.scalar.basic import (
true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad, true_div, inv, log, log2, log10, log1p, exp, exp2, expm1, sqrt, deg2rad,
rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh, rad2deg, cos, arccos, sin, arcsin, tan, arctan, arctan2, cosh, arccosh,
...@@ -58,15 +59,33 @@ class test_ScalarOps(unittest.TestCase): ...@@ -58,15 +59,33 @@ class test_ScalarOps(unittest.TestCase):
as Python. That is what we want. as Python. That is what we want.
""" """
x, y = ints('xy') x, y = ints('xy')
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x%y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x % y])).make_function()
for a, b in ((0, 1), (1, 1), (0, -1), (1, -1), (-1, -1), for a, b in ((0, 1), (1, 1), (0, -1), (1, -1), (-1, -1),
(1, 2), (-1, 2), (1, -2), (-1, -2), (1, 2), (-1, 2), (1, -2), (-1, -2),
(5, 3), (-5, 3), (5, -3), (-5, -3) (5, 3), (-5, 3), (5, -3), (-5, -3)
): ):
self.assertTrue(fn(a, b) == a%b, (a,)) self.assertTrue(fn(a, b) == a % b, (a,))
def has_f16(comp):
    """Return True if any variable in ``comp``'s fgraph is float16.

    Idiom fix: ``any(...)`` already yields the boolean we want, so the
    redundant ``if ...: return True / return False`` wrapper is dropped.
    """
    return any(v.type == float16 for v in comp.fgraph.variables)
class test_composite(unittest.TestCase): class test_composite(unittest.TestCase):
def test_composite_clone_float32(self):
    """clone_float32 must eliminate every float16 variable, including
    those introduced by a nested Composite, explicit casts, and a
    float16 constant.
    """
    w = int8()
    x = float16()
    y = float32()
    inner = Composite([x, y], [tanh(x + cast(y, 'float16'))])
    outer = Composite([w, x, y],
                      [inner(x, y) - inner(x, y)**2 +
                       cast(x, 'int16') + cast(x, 'float32') +
                       cast(w, 'float16') -
                       constant(np.float16(1.0))])
    assert has_f16(outer)
    converted = outer.clone_float32()
    assert not has_f16(converted)
def test_straightforward(self): def test_straightforward(self):
x, y, z = inputs() x, y, z = inputs()
...@@ -126,7 +145,7 @@ class test_composite(unittest.TestCase): ...@@ -126,7 +145,7 @@ class test_composite(unittest.TestCase):
C = Composite([x, y, z], [e0, e1, e2, e3, e4, e5, e6, e7]) C = Composite([x, y, z], [e0, e1, e2, e3, e4, e5, e6, e7])
c = C.make_node(x, y, z) c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs) g = FunctionGraph([x, y, z], c.outputs)
fn = gof.DualLinker().accept(g).make_function() gof.DualLinker().accept(g).make_function()
assert str(g) == ('[*1 -> Composite{((i0 + i1) + i2),' assert str(g) == ('[*1 -> Composite{((i0 + i1) + i2),'
' (i0 + (i1 * i2)), (i0 / i1), ' ' (i0 + (i1 * i2)), (i0 / i1), '
...@@ -195,13 +214,13 @@ class test_logical(unittest.TestCase): ...@@ -195,13 +214,13 @@ class test_logical(unittest.TestCase):
def test_or(self): def test_or(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x|y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x | y])).make_function()
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
self.assertTrue(fn(a, b) == (a|b), (a, b)) self.assertTrue(fn(a, b) == (a | b), (a, b))
def test_xor(self): def test_xor(self):
x, y, z = ints('xyz') x, y, z = ints('xyz')
fn = gof.DualLinker().accept(FunctionGraph([x, y], [x^y])).make_function() fn = gof.DualLinker().accept(FunctionGraph([x, y], [x ^ y])).make_function()
for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)): for a, b in ((0, 1), (0, 0), (1, 0), (1, 1)):
self.assertTrue(fn(a, b) == (a ^ b), (a, b)) self.assertTrue(fn(a, b) == (a ^ b), (a, b))
...@@ -335,8 +354,9 @@ class test_upgrade_to_float(object): ...@@ -335,8 +354,9 @@ class test_upgrade_to_float(object):
# Automatically define all individual unary tests # Automatically define all individual unary tests
for unary_op, x_range in self.unary_ops_vals: for unary_op, x_range in self.unary_ops_vals:
test_name = 'test_%s' % unary_op.name test_name = 'test_%s' % unary_op.name
# Make a lambda function so we can name the test
test = lambda: self._test_unary(unary_op, x_range) def test():
self._test_unary(unary_op, x_range)
test.description = test_name test.description = test_name
yield test yield test
...@@ -344,8 +364,9 @@ class test_upgrade_to_float(object): ...@@ -344,8 +364,9 @@ class test_upgrade_to_float(object):
# Automatically define all individual binary tests # Automatically define all individual binary tests
for binary_op, x_range, y_range in self.binary_ops_vals: for binary_op, x_range, y_range in self.binary_ops_vals:
test_name = 'test_%s' % binary_op.name test_name = 'test_%s' % binary_op.name
# Make a lambda function so we can name the test
test = lambda: self._test_binary(binary_op, x_range, y_range) def test():
self._test_binary(binary_op, x_range, y_range)
test.description = test_name test.description = test_name
yield test yield test
...@@ -371,16 +392,15 @@ class test_div(unittest.TestCase): ...@@ -371,16 +392,15 @@ class test_div(unittest.TestCase):
d = float64() d = float64()
f = float32() f = float32()
#print (a//b).owner.op assert isinstance((a // b).owner.op, IntDiv)
assert isinstance((a//b).owner.op, IntDiv) assert isinstance((b // a).owner.op, IntDiv)
assert isinstance((b//a).owner.op, IntDiv) assert isinstance((b / d).owner.op, TrueDiv)
assert isinstance((b/d).owner.op, TrueDiv) assert isinstance((b / f).owner.op, TrueDiv)
assert isinstance((b/f).owner.op, TrueDiv) assert isinstance((f / a).owner.op, TrueDiv)
assert isinstance((f/a).owner.op, TrueDiv) assert isinstance((d / b).owner.op, TrueDiv)
assert isinstance((d/b).owner.op, TrueDiv) assert isinstance((d / f).owner.op, TrueDiv)
assert isinstance((d/f).owner.op, TrueDiv) assert isinstance((f / c).owner.op, TrueDiv)
assert isinstance((f/c).owner.op, TrueDiv) assert isinstance((a / c).owner.op, TrueDiv)
assert isinstance((a/c).owner.op, TrueDiv)
def test_grad_gt(): def test_grad_gt():
...@@ -388,7 +408,7 @@ def test_grad_gt(): ...@@ -388,7 +408,7 @@ def test_grad_gt():
y = float32(name='y') y = float32(name='y')
z = x > y z = x > y
g = theano.gradient.grad(z, y) g = theano.gradient.grad(z, y)
assert g.eval({ y : 1. }) == 0. assert g.eval({y: 1.}) == 0.
def test_grad_switch(): def test_grad_switch():
......
...@@ -76,7 +76,6 @@ whitelist_flake8 = [ ...@@ -76,7 +76,6 @@ whitelist_flake8 = [
"tensor/signal/tests/__init__.py", "tensor/signal/tests/__init__.py",
"scalar/__init__.py", "scalar/__init__.py",
"scalar/tests/__init__.py", "scalar/tests/__init__.py",
"scalar/tests/test_basic.py",
"sandbox/__init__.py", "sandbox/__init__.py",
"sandbox/tests/test_theano_object.py", "sandbox/tests/test_theano_object.py",
"sandbox/tests/test_scan.py", "sandbox/tests/test_scan.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论