提交 d823134d authored 作者: Frederic's avatar Frederic

Some more fix.

上级 c649d668
...@@ -1291,7 +1291,7 @@ def local_conv_gemm(node): ...@@ -1291,7 +1291,7 @@ def local_conv_gemm(node):
print "WARNING, YOU ARE USING BUGGED CODE!" print "WARNING, YOU ARE USING BUGGED CODE!"
img, kern = node.inputs img, kern = node.inputs
img = gpu_contiguous(img) img = gpu_contiguous(img)
#kern = kern[:, :, ::-1, ::-1] kern = kern[:, :, ::-1, ::-1]
kern = gpu_contiguous(kern) kern = gpu_contiguous(kern)
return [GpuConvMM(node.op.border_mode)(img, kern)] return [GpuConvMM(node.op.border_mode)(img, kern)]
......
...@@ -636,7 +636,8 @@ def test_valid(): ...@@ -636,7 +636,8 @@ def test_valid():
# print_=print_, ones=ones, rtol=1.1e-5) # print_=print_, ones=ones, rtol=1.1e-5)
mode = theano_mode.including("conv_gemm") mode = theano_mode.including("conv_gemm")
# import pdb;pdb.set_trace()
# Remove case not implemented
shapes = [shp for shp in shapes if shp[1][2] == shp[1][3]] shapes = [shp for shp in shapes if shp[1][2] == shp[1][3]]
shapes = [shp for shp in shapes if shp[0][2] == shp[0][3]] shapes = [shp for shp in shapes if shp[0][2] == shp[0][3]]
exec_conv(version, shapes, verbose, random, 'valid', exec_conv(version, shapes, verbose, random, 'valid',
...@@ -644,7 +645,7 @@ def test_valid(): ...@@ -644,7 +645,7 @@ def test_valid():
theano_mode=mode, cls=cuda.blas.GpuConvMM) theano_mode=mode, cls=cuda.blas.GpuConvMM)
def test_full(gemm=False): def test_full():
seed_rng() seed_rng()
shapes = get_basic_shapes() shapes = get_basic_shapes()
shapes += get_shapes2() shapes += get_shapes2()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论