提交 0d8583ed authored 作者: Olivier Delalleau's avatar Olivier Delalleau

New attempt to fix the mixed scalar-array output types (previous one failed).

This one is not working either (committing it mostly to keep track of it in the history).
上级 53b503a3
...@@ -464,6 +464,9 @@ class Elemwise(Op): ...@@ -464,6 +464,9 @@ class Elemwise(Op):
else: else:
array_inputs.append((input_idx, input)) array_inputs.append((input_idx, input))
shadow = self.scalar_op.make_node(*[Scalar(dtype=dtype)() for dtype in input_dtypes])
out_dtypes = [o.type.dtype for o in shadow.outputs]
if (scalar_inputs and if (scalar_inputs and
array_inputs and array_inputs and
theano.config.cast_policy in ('numpy', 'numpy+floatX')): theano.config.cast_policy in ('numpy', 'numpy+floatX')):
...@@ -471,32 +474,53 @@ class Elemwise(Op): ...@@ -471,32 +474,53 @@ class Elemwise(Op):
# they are fundamentally different. This is specified in # they are fundamentally different. This is specified in
# http://docs.scipy.org/doc/numpy/reference/ufuncs.html # http://docs.scipy.org/doc/numpy/reference/ufuncs.html
# in the 'casting rules' section. # in the 'casting rules' section.
# It seems difficult to find a generic mechanism that would work
# for any elemwise Op. In the following we use a heuristic that
# should work for simple Ops, but may break in the future for more
# complex Ops (in which case we may need to implement a way for
# these Ops to override this heuristic).
# The heuristic consists in detecting a situation where we suspect
# some scalar input upcasted an array, by comparing the highest
# type of the outputs with the highest type of the input arrays.
# If it happens that the former is of higher type than the latter,
# then we go through all scalar inputs and if they are of a higher
# type than the highest type of the input arrays, we pretend they
# actually are of the same type (the idea is that we suspect they
# are responsible for the upcasting, so by downcasting them we hope
# to get rid of this upcasting).
array_dtype = scalar.upcast(*[a[1].dtype for a in array_inputs]) array_dtype = scalar.upcast(*[a[1].dtype for a in array_inputs])
for input_idx, input in scalar_inputs: out_dtype = scalar.upcast(*out_dtypes)
# Replace this scalar input's type with the one that numpy def is_higher(dtype_a, dtype_b):
# would use when adding this scalar to the array. return (dtype_a != dtype_b and
# Note that currently numpy's behavior is not consistent, which scalar.upcast(dtype_a, dtype_b) == dtype_a)
# is a bug (will be fixed in 1.6). See for details if is_higher(out_dtype, array_dtype):
# http://projects.scipy.org/numpy/ticket/1827 # We are in the situation described above.
# As a result, we pick the highest precision data type that modified_scalar_inputs = False
# numpy may decide to use (although we may prefer float32 over for input_idx, input in scalar_inputs:
# float64). if scalar.upcast(input.dtype, array_dtype) == out_dtype:
n_inputs = [ # This scalar may be responsible for the upcasting.
numpy.array(0, dtype=input_dtypes[input_idx]), input_dtypes[input_idx] = array_dtype
numpy.array([0], dtype=array_dtype)] modified_scalar_inputs = True
n_types = [(n_inputs[0] + n_inputs[1]).dtype, if modified_scalar_inputs:
(n_inputs[1] + n_inputs[0]).dtype] # Update 'shadow' and 'out_dtypes'.
n_highest_type = scalar.upcast(*map(str, n_types)) shadow = self.scalar_op.make_node(
if (n_highest_type == 'float64' and *[Scalar(dtype=dtype)() for dtype in input_dtypes])
theano.config.cast_policy == 'numpy+floatX' and out_dtypes = [o.type.dtype for o in shadow.outputs]
theano.config.floatX == 'float32' and # The whole point of all this is to try to avoid upcasting
array_dtype != 'float64' and # the dtype of the input arrays. The following assert makes
input_dtypes[input_idx] != 'float64'): # sure this goal was achieved. Note however that it might
# Prefer float 32 instead. # fail for some Ops that purposedly upcast arrays, in which
n_highest_type = 'float32' # case it would probably be better to use a different
input_dtypes[input_idx] = n_highest_type # mechanism for such Ops.
out_dtype = scalar.upcast(*out_dtypes)
shadow = self.scalar_op.make_node(*[Scalar(dtype=dtype)() for dtype in input_dtypes]) assert not is_higher(out_dtype, array_dtype)
else:
# Same as above: safety assert to make sure our heuristics
# did its job. It may fail in the future for some Ops that
# would require a different mechanism.
import pdb; pdb.set_trace()
raise AssertionError(
'Heuristic failure - see Elemwise.make_node')
target_length = max([input.type.ndim for input in inputs]) target_length = max([input.type.ndim for input in inputs])
...@@ -529,7 +553,6 @@ class Elemwise(Op): ...@@ -529,7 +553,6 @@ class Elemwise(Op):
for ob, ib in zip(out_broadcastables[overwriter], inputs[overwritten].type.broadcastable): for ob, ib in zip(out_broadcastables[overwriter], inputs[overwritten].type.broadcastable):
if ib and not ob: if ib and not ob:
raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.") raise ValueError("Operation cannot be done inplace on an input with broadcasted dimensions.")
out_dtypes = [o.type.dtype for o in shadow.outputs]
if any(inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()): if any(inputs[i].type.dtype != out_dtypes[o] for o, i in inplace_pattern.items()):
raise TypeError("Cannot do an inplace operation on incompatible data types.", raise TypeError("Cannot do an inplace operation on incompatible data types.",
([i.type.dtype for i in inputs], out_dtypes, inplace_pattern)) ([i.type.dtype for i in inputs], out_dtypes, inplace_pattern))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论