Commit cd26cc10 authored by f0k

Minor cleanups suggested by @nouiz

Parent 17f7b9ec
@@ -839,7 +839,7 @@ class LocalMetaOptimizer(LocalOptimizer):
         return self._tracks

     def transform(self, node):
-        # safety check: not sure if needed, but all optimizers do it
+        # safety check: depending on registration, tracks may have been ignored
         if self._tracks is not None:
             if not isinstance(node.op, tuple(self._tracks)):
                 return
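The reworded comment reflects how meta-optimizers are dispatched: registration does not guarantee that only tracked ops reach transform(), so the isinstance guard stays. A minimal sketch of that guard, using hypothetical stand-in classes rather than real Theano ops:

    # Hypothetical stand-ins; only the guard logic mirrors the diff above.
    class GpuConv(object):
        pass

    class Node(object):
        def __init__(self, op):
            self.op = op

    tracks = [GpuConv]            # ops this optimizer declares interest in
    node = Node(op=object())      # a node whose op is outside the tracks
    if tracks is not None and not isinstance(node.op, tuple(tracks)):
        print("skip: node.op is not among the tracked ops")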
@@ -852,8 +852,7 @@ class LocalMetaOptimizer(LocalOptimizer):
                 pass
             elif hasattr(input.tag, 'test_value'):
                 givens[input] = theano.shared(
-                    numpy.require(input.tag.test_value,
-                                  dtype=input.dtype),
+                    input.type.filter(input.tag.test_value),
                     input.name, borrow=True)
             else:
                 missing.add(input)
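The switch from numpy.require() to input.type.filter() means the test value is validated against the variable's type (ndim and broadcastable pattern), not merely cast to its dtype. A minimal sketch of the difference, assuming standard TensorType.filter semantics:

    import numpy
    import theano.tensor as T

    x = T.matrix('x')
    value = numpy.ones((2, 3))
    good = x.type.filter(value, allow_downcast=True)  # cast and ndim-checked

    try:
        # a vector is silently accepted by numpy.require(..., dtype=...),
        # but rejected by filter() because x expects two dimensions
        x.type.filter(numpy.ones(5), allow_downcast=True)
    except TypeError:
        pass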
@@ -155,21 +155,6 @@ gpu_seqopt.register('InputToGpuOptimizer', InputToGpuOptimizer(),
                     'merge')  # TODO: how to make it mandatory for gpu_seqopt?

-class LocalCudaMetaOptimizer(LocalMetaOptimizer):
-    """Base class for CUDA-based LocalMetaOptimizers"""
-
-    def __init__(self, *args):
-        super(LocalCudaMetaOptimizer, self).__init__(*args)
-
-    def time_call(self, fn):
-        # Override time_call() to do device synchronization
-        theano.sandbox.cuda.synchronize()
-        start = time.time()
-        fn()
-        theano.sandbox.cuda.synchronize()
-        return time.time() - start
-

 @local_optimizer([gpu_from_host, host_from_gpu])
 def local_cut_gpu_host_gpu(node):
     if tensor.opt.opt.check_chain(node, gpu_from_host, host_from_gpu):
@@ -1362,6 +1347,18 @@ conv_groupopt.register('local_conv_gemm', local_conv_gemm, 30,
                        'fast_compile', 'fast_run')

+class LocalCudaMetaOptimizer(LocalMetaOptimizer):
+    """Base class for CUDA-based LocalMetaOptimizers"""
+
+    def time_call(self, fn):
+        # Override time_call() to do device synchronization
+        theano.sandbox.cuda.synchronize()
+        start = time.time()
+        fn()
+        theano.sandbox.cuda.synchronize()
+        return time.time() - start
+

 # Convolution Meta-optimizer

 class ConvMetaOptimizer(LocalCudaMetaOptimizer):
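Moving the class next to its only user also drops the redundant __init__ override. The time_call() override itself exists because CUDA kernel launches return asynchronously; without synchronizing before and after, time.time() would mostly measure launch overhead rather than kernel runtime. A hedged sketch of the pattern, with the synchronization primitive passed in:

    import time

    def timed(fn, synchronize):
        synchronize()             # drain work queued before the measurement
        start = time.time()
        fn()
        synchronize()             # wait until fn's kernels have finished
        return time.time() - start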
@@ -1386,6 +1383,9 @@ class ConvMetaOptimizer(LocalCudaMetaOptimizer):
                     (shape is not None) and
                     not any(s is None for s in shape)):
                 result[var] = theano.shared(
+                    # TODO: Use var.type.filter when cuda_ndarray.filter supports non-strict casts
+                    # var.type.filter(numpy.random.randn(*shape),
+                    #                 allow_downcast=True),
                     numpy.require(numpy.random.randn(*shape),
                                   dtype=var.dtype),
                     var.name, borrow=True)
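Until cuda_ndarray.filter supports non-strict casts, the meta-optimizer keeps fabricating benchmark inputs via numpy.require(), which pins the dtype but performs no type validation. A small sketch of that fallback (shape and dtype assumed known at optimization time):

    import numpy
    import theano

    shape, dtype = (2, 3), 'float32'
    value = numpy.require(numpy.random.randn(*shape), dtype=dtype)
    shared_input = theano.shared(value, 'x', borrow=True)
    assert shared_input.get_value(borrow=True).dtype == numpy.dtype(dtype)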