Commit f4ce2464 authored by Frederic

Correctly cast the scalar during opt with dot22scalar.

Parent: 01159dae
...@@ -890,21 +890,28 @@ def res_is_a(node, op, maxclients=None): ...@@ -890,21 +890,28 @@ def res_is_a(node, op, maxclients=None):
and retval and retval
def _as_scalar(res, dtype=None):
    """Return `res` lifted to a 0-d scalar variable cast to `dtype`,
    or None when that cast would lose integer precision.

    `res` must be a TensorVariable.  When all of its broadcastable
    flags are True it is semantically a scalar: we strip any wrapping
    DimShuffle ops and drop the broadcast dimensions.  Integer scalars
    are cast to `dtype` (default: config.floatX) only when the upcast
    of the two dtypes is `dtype` itself, i.e. when the cast cannot
    change the result of the surrounding dot22 computation — the cast
    of the scalar may then be done before or after the dot22 with the
    same outcome.  Returns None for non-scalars or unsafe casts.
    """
    if dtype is None:
        dtype = config.floatX
    if not numpy.all(res.type.broadcastable):
        # Not a scalar-shaped tensor.
        return None
    # Unwrap any chain of DimShuffle ops to reach the underlying variable.
    while res.owner and isinstance(res.owner.op, T.DimShuffle):
        res = res.owner.inputs[0]
    # The unwrapped variable may still carry broadcast dimensions;
    # dimshuffle() with no pattern drops them all, yielding a 0-d tensor.
    rval = res.dimshuffle() if res.type.broadcastable else res
    if rval.type.dtype[:3] in ('int', 'uin'):
        # Only cast int/uint scalars when the upcast of res and dtype
        # is dtype itself, so the cast cannot change dtype (e.g. with
        # dtype float64 we safely cast int64 to float64).
        if theano.scalar.upcast(res.dtype, dtype) == dtype:
            return T.cast(rval, dtype)
        return None
    return rval
...@@ -1567,7 +1574,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1567,7 +1574,7 @@ def local_dot22_to_dot22scalar(node):
#return False #TODO fix #return False #TODO fix
dot22_idx = i_dot22.index(True) dot22_idx = i_dot22.index(True)
d = node.inputs[dot22_idx] d = node.inputs[dot22_idx]
i_scalar = [_as_scalar(x) for x in node.inputs] i_scalar = [_as_scalar(x, dtype=d.dtype) for x in node.inputs]
if not any(i_scalar): if not any(i_scalar):
i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs] i_mul = [x.owner and x.owner.op ==T.mul for x in node.inputs]
if not any(i_mul): if not any(i_mul):
...@@ -1581,10 +1588,10 @@ def local_dot22_to_dot22scalar(node): ...@@ -1581,10 +1588,10 @@ def local_dot22_to_dot22scalar(node):
mul_idx = i_mul.index(True)#we take the first mul! mul_idx = i_mul.index(True)#we take the first mul!
m = node.inputs[mul_idx] m = node.inputs[mul_idx]
if len(m.owner.inputs)==2 and any([_as_scalar(x) for x in m.owner.inputs]): if len(m.owner.inputs)==2 and any([_as_scalar(x, dtype=d.dtype) for x in m.owner.inputs]):
scalar_idx = -1 scalar_idx = -1
for i,x in enumerate(m.owner.inputs): for i,x in enumerate(m.owner.inputs):
if _as_scalar(x) and (theano.scalar.upcast(x.type.dtype,d.type.dtype) if _as_scalar(x, dtype=d.dtype) and (theano.scalar.upcast(x.type.dtype,d.type.dtype)
== d.type.dtype): == d.type.dtype):
scalar_idx = i scalar_idx = i
break break
...@@ -1594,7 +1601,7 @@ def local_dot22_to_dot22scalar(node): ...@@ -1594,7 +1601,7 @@ def local_dot22_to_dot22scalar(node):
'of the scalar cannot be upcasted to the matrix type', 'of the scalar cannot be upcasted to the matrix type',
node.inputs, [x.type for x in node.inputs]) node.inputs, [x.type for x in node.inputs])
return False return False
a = T.cast(_as_scalar(m.owner.inputs[scalar_idx]), d.type.dtype) a = T.cast(_as_scalar(m.owner.inputs[scalar_idx], dtype=d.dtype), d.type.dtype)
assert not a.type.ndim assert not a.type.ndim
dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a) dot=_dot22scalar(d.owner.inputs[0], d.owner.inputs[1], a)
......
def test_dot22scalar_cast():
    """
    Test that in `dot22_to_dot22scalar` we properly cast integers to floats.
    """
    # Note that this test was failing before d5ff6904.
    # With a float64 matrix, every integer scalar type upcasts safely,
    # so the dot22scalar optimization must always fire.
    A = T.dmatrix()
    for scalar_int_type in T.int_dtypes:
        y = T.scalar(dtype=scalar_int_type)
        f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
        ops = [x.op for x in f.maker.env.toposort()]
        assert _dot22scalar in ops
    # With a float32 matrix, int32/int64 scalars would upcast the result
    # to float64, so the optimization must NOT introduce dot22scalar there.
    A = T.fmatrix()
    for scalar_int_type in T.int_dtypes:
        y = T.scalar(dtype=scalar_int_type)
        f = theano.function([A, y], T.dot(A, A) * y, mode=mode_blas_opt)
        ops = [x.op for x in f.maker.env.toposort()]
        if scalar_int_type in ['int32', 'int64']:
            assert _dot22 in ops
        else:
            assert _dot22scalar in ops
def test_dot_w_self(): def test_dot_w_self():
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment