提交 5f75ecdc authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Remove useless optimizer now that mixed-type abstractconv can't happen.

上级 3f31dc24
...@@ -14,10 +14,8 @@ from theano.gof.optdb import LocalGroupDB ...@@ -14,10 +14,8 @@ from theano.gof.optdb import LocalGroupDB
from theano.scalar.basic import Scalar, Pow, Cast from theano.scalar.basic import Scalar, Pow, Cast
from theano.scan_module import scan_utils, scan_op, scan_opt from theano.scan_module import scan_utils, scan_op, scan_opt
from theano.tensor import as_tensor_variable
from theano.tensor.nnet.conv import ConvOp from theano.tensor.nnet.conv import ConvOp
from theano.tensor.nnet.abstract_conv import (BaseAbstractConv2d, from theano.tensor.nnet.abstract_conv import (AbstractConv2d,
AbstractConv2d,
AbstractConv2d_gradWeights, AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs) AbstractConv2d_gradInputs)
...@@ -819,26 +817,6 @@ def local_lift_abstractconv2d(node, context_name): ...@@ -819,26 +817,6 @@ def local_lift_abstractconv2d(node, context_name):
context_name=context_name) context_name=context_name)
return [node.op(*inps)] return [node.op(*inps)]
# This will deal with ops that don't have an explicit transfer but
# have one of their inputs on the GPU already and the other not on the
# GPU (to avoid endlessly replacing things).
@register_opt('fast_compile')
@local_optimizer([AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs])
def local_gpu_abstractconv2d(node):
if isinstance(node.op, BaseAbstractConv2d):
if ((isinstance(node.inputs[0].type, GpuArrayType) or
isinstance(node.inputs[1].type, GpuArrayType)) and
not (isinstance(node.inputs[0].type, GpuArrayType) or
isinstance(node.inputs[1].type, GpuArrayType))):
inps = list(node.inputs)
ctx_name = infer_context_name(inps[0], inps[1])
inps[0] = as_gpuarray_variable(inps[0], context_name=ctx_name)
inps[1] = as_gpuarray_variable(inps[1], context_name=ctx_name)
return as_tensor_variable(node.op(*inps))
# Register this here so that it goes after the abstract lifting # Register this here so that it goes after the abstract lifting
register_opt()(conv_groupopt) register_opt()(conv_groupopt)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论