提交 d5cd4c00 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Check that GpuConv is in the function graph.

上级 f799ffcc
...@@ -358,8 +358,11 @@ def build_conv_nnet2_classif(use_gpu, isize, ksize, n_batch, ...@@ -358,8 +358,11 @@ def build_conv_nnet2_classif(use_gpu, isize, ksize, n_batch,
train = pfunc([x,y,lr], [loss], mode=mode, updates=[(p, p-g) for p,g in zip(params, gparams)]) train = pfunc([x,y,lr], [loss], mode=mode, updates=[(p, p-g) for p,g in zip(params, gparams)])
if verbose: if verbose:
for i, n in enumerate(train.maker.env.toposort()): theano.printing.debugprint(train)
print i, n if use_gpu:
# Check that GpuConv is used
topo = train.maker.env.toposort()
assert len([n for n in topo if isinstance(n.op, tcn.blas.GpuConv)]) > 0
shape_target = (n_batch,n_out) shape_target = (n_batch,n_out)
return train, params, shape_img, shape_target, mode return train, params, shape_img, shape_target, mode
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论