提交 246ed4db authored 作者: James Bergstra's avatar James Bergstra

debug: cuda.opt local_cut_gpu_host_gpu

上级 eb0860ea
......@@ -2,7 +2,7 @@ import sys
import theano
import numpy
from theano import scalar as scal
from theano import tensor, compile
from theano import tensor, compile, gof
from theano.gof import local_optimizer, EquilibriumDB, SequenceDB, Optimizer, toolbox, DestroyHandler
from theano.sandbox.cuda.basic_ops import *
......@@ -62,10 +62,16 @@ gpu_seqopt.register('InputToGpuOptimizer', InputToGpuOptimizer(),
@local_optimizer([])
def local_cut_gpu_host_gpu(node):
copy_op=None
if tensor.opt.opt.check_chain(node, gpu_from_host, host_from_gpu):
return [node.inputs[0].owner.inputs[0]]
copy_op = GpuElemwise(scal.identity, {})
if tensor.opt.opt.check_chain(node, host_from_gpu, gpu_from_host):
return [node.inputs[0].owner.inputs[0]]
copy_op =tensor.copy
if copy_op:
#copy_op = lambda x:x
rval = copy_op(node.inputs[0].owner.inputs[0])
assert isinstance(rval, gof.Variable), "rval is not a variable"
return [rval]
return False
gpu_cut_copies.register('cut_gpu_host_transfers', local_cut_gpu_host_gpu,
'fast_run', 'inplace', 'gpu')
......
......@@ -1063,11 +1063,18 @@ def local_fill_cut(node):
If c.type == a.type.
"""
# this optimization is essentially for getting broadcasting to replace fill.
# This is always possible when using a Compound Elemwise operation,
# but it is not always possible without one (consider filling a large matrix with a scalar,
# and then adding another scalar. The only numbers that count are the two scalars, but we
# can't ignore the large matrix because it gives the shape of the result.
if not opt.check_chain(node, T.Elemwise):
return False
output = node.outputs[0]
try:
#reference is some input with the same type as the input but that is not produced by a fill
reference = [input
for input in node.inputs
if input.type == output.type and (not input.owner or input.owner.op != T.fill)][0]
......@@ -1086,7 +1093,13 @@ def local_fill_cut(node):
if new_inputs == node.inputs:
return False
return node.op.make_node(*new_inputs).outputs
print 'NEW INPUTS', new_inputs
rval = node.op(*new_inputs)
if isinstance(rval, gof.Variable):
return rval.owner.outputs
else:
return rval[0].owner.outputs
register_canonicalize(local_fill_cut)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论