提交 ff835e51 authored 作者: Frederic's avatar Frederic

Fix compilation crash reported in gh-1518 for gpujoin.

上级 a351e3db
...@@ -5,7 +5,7 @@ import sys ...@@ -5,7 +5,7 @@ import sys
import numpy import numpy
import theano import theano
from theano import Type, Apply from theano import gof, Type, Apply
from theano import tensor, scalar, config from theano import tensor, scalar, config
from theano.compat.six import StringIO from theano.compat.six import StringIO
from theano.scalar import Scalar from theano.scalar import Scalar
...@@ -2960,7 +2960,6 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2960,7 +2960,6 @@ class GpuJoin(tensor.Join, GpuOp):
str = """ str = """
const int axis = PyInt_AsLong((PyObject*)%(axis)s); const int axis = PyInt_AsLong((PyObject*)%(axis)s);
const int nd = %(nd)s; const int nd = %(nd)s;
int shape_%(input_1)s[nd];
int shape_out[nd]; int shape_out[nd];
int width_sum = 0; int width_sum = 0;
int errorcode; int errorcode;
...@@ -2973,11 +2972,6 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2973,11 +2972,6 @@ class GpuJoin(tensor.Join, GpuOp):
start = NULL; start = NULL;
stop = NULL; stop = NULL;
for(int i = 0; i<nd; i+=1)
{
shape_%(input_1)s[i] = CudaNdarray_HOST_DIMS(%(input_1)s)[i];
shape_out[i] = shape_%(input_1)s[i];
}
""" % locals() """ % locals()
# getting the shapes of all the involved tensors (input[1:]) # getting the shapes of all the involved tensors (input[1:])
...@@ -2986,7 +2980,7 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2986,7 +2980,7 @@ class GpuJoin(tensor.Join, GpuOp):
# shape_%(cdna)s[nd] is initialized before, to prevent following # shape_%(cdna)s[nd] is initialized before, to prevent following
# error: jump to label __label_9 crosses initialization of # error: jump to label __label_9 crosses initialization of
# shape_%(cdna)s[nd] # shape_%(cdna)s[nd]
for i, cdna in enumerate(inputs[2:]): for i, cdna in enumerate(gof.utils.uniq(inputs[1:])):
str += """ str += """
int shape_%(cdna)s[nd]; int shape_%(cdna)s[nd];
""" % locals() """ % locals()
...@@ -2998,8 +2992,14 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2998,8 +2992,14 @@ class GpuJoin(tensor.Join, GpuOp):
if(full_slice == NULL){ if(full_slice == NULL){
%(fail)s; %(fail)s;
} }
for(int i = 0; i<nd; i+=1)
{
shape_%(input_1)s[i] = CudaNdarray_HOST_DIMS(%(input_1)s)[i];
shape_out[i] = shape_%(input_1)s[i];
}
""" % locals() """ % locals()
for i, cdna in enumerate(inputs[2:]): for i, cdna in enumerate(gof.utils.uniq(inputs[2:])):
str += """ str += """
for(int i = 0; i<nd; i+=1) for(int i = 0; i<nd; i+=1)
{ {
...@@ -3090,7 +3090,7 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -3090,7 +3090,7 @@ class GpuJoin(tensor.Join, GpuOp):
return str return str
def c_code_cache_version(self): def c_code_cache_version(self):
return (4,) return (5,)
gpu_join = GpuJoin() gpu_join = GpuJoin()
......
...@@ -3294,11 +3294,29 @@ class T_Join_and_Split(unittest.TestCase): ...@@ -3294,11 +3294,29 @@ class T_Join_and_Split(unittest.TestCase):
""" """
Regression test for a crash that used to happen when rebroadcasting. Regression test for a crash that used to happen when rebroadcasting.
""" """
x = tensor.TensorType(floatX, [False, False, True])() x = tensor.TensorType(self.floatX, [False, False, True])()
u = tensor.TensorType(floatX, [False, False, True])() u = tensor.TensorType(self.floatX, [False, False, True])()
# This line used to crash. # This line used to crash.
z = tensor.concatenate([x, -u], axis=2) z = tensor.concatenate([x, -u], axis=2)
def test_concatenate_same(self):
"""
Test that we can concatenate the same tensor multiple time.
In the past it was broken on the GPU.
"""
rng = numpy.random.RandomState(seed=utt.fetch_seed())
T_shared = self.shared(rng.rand(3, 4).astype(self.floatX))
Tout = tensor.concatenate([T_shared, T_shared])
f = function([], Tout, mode=self.mode)
out = f()
if theano.config.mode != 'FAST_COMPILE':
assert [True for node in f.maker.fgraph.toposort() if isinstance(
node.op, self.join_op)]
assert numpy.allclose(out,
numpy.concatenate([T_shared.get_value(),
T_shared.get_value()]))
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and != """Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论