Commit 094af0c1 authored by Arnaud Bergeron

Fix some optimizations that are not able to infer the context when the lift request comes from the clients.
Parent 1d69eac9
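The core of the change is the first hunk: op_lifter now tags every input of the node with the context before calling the maker, so a maker that only looks at its inputs can still find the context even when it was derived from the node's clients rather than from the inputs themselves. A minimal, self-contained sketch of that ordering, assuming nothing from Theano (the class and function names below are illustrative only):

# Illustrative stand-ins for Variable and a context-inferring maker.
class _Tag(object):
    pass

class _Var(object):
    def __init__(self):
        self.tag = _Tag()

def _maker(inputs):
    # Infer the context from the inputs' tags, as a lifted-op maker might.
    for i in inputs:
        ctx = getattr(i.tag, 'context_name', None)
        if ctx is not None:
            return ctx
    return None

inputs = [_Var(), _Var()]
context_name = 'dev0'               # e.g. derived from the node's clients
for i in inputs:                    # what op_lifter now does before the maker
    i.tag.context_name = context_name
assert _maker(inputs) == 'dev0'     # without the prior tagging this is None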
@@ -129,18 +129,16 @@ def op_lifter(OP, cuda_only=False):
                          get_context(context_name).kind != 'cuda')):
                     return False
+                # tag the inputs with the context in case
+                # the context was derived from the outputs
+                for i in node.inputs:
+                    i.tag.context_name = context_name
                 new_op = maker(node, context_name)
                 # This is needed as sometimes new_op inherits from OP.
                 if new_op and new_op != node.op:
                     if isinstance(new_op, theano.Op):
-                        # tag the inputs with the context in case
-                        # the context was derived from the outputs
-                        def tag(i, ctx):
-                            i.tag.context_name = ctx
-                            return i
-                        inputs = [tag(i, context_name) for i in node.inputs]
                         return [safe_to_cpu(o) for o in
-                                new_op(*inputs, return_list=True)]
+                                new_op(*node.inputs, return_list=True)]
                     elif isinstance(new_op, (tuple, list)):
                         return [safe_to_cpu(o) for o in new_op]
                     else:  # suppose it is a variable on the GPU
@@ -593,11 +591,11 @@ def local_gpua_advanced_incsubtensor(node, context_name):
     compute_capability = device_properties(active_device_no)['major']
     if (compute_capability < 2 or x.ndim != 2 or y.ndim != 2):
-        return [GpuAdvancedIncSubtensor1(
-            set_instead_of_inc=set_instead_of_inc)(x, y, ilist)]
+        return GpuAdvancedIncSubtensor1(
+            set_instead_of_inc=set_instead_of_inc)
     else:
-        return [GpuAdvancedIncSubtensor1_dev20(
-            set_instead_of_inc=set_instead_of_inc)(x, y, ilist)]
+        return GpuAdvancedIncSubtensor1_dev20(
+            set_instead_of_inc=set_instead_of_inc)


 @register_opt('fast_compile')
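Because op_lifter now applies the returned Op to node.inputs itself (the theano.Op branch in the first hunk), local_gpua_advanced_incsubtensor can return the bare Op instead of the already-applied result, and the new Apply node is built from inputs that already carry the context tag. A rough sketch of the two return styles, using a placeholder Op rather than the real GpuAdvancedIncSubtensor1:

# Placeholder Op class, only to contrast the two return styles.
class _FakeGpuIncSubtensor(object):
    def __init__(self, set_instead_of_inc):
        self.set_instead_of_inc = set_instead_of_inc

    def __call__(self, *inputs, **kwargs):
        # Stands in for Op.__call__ building the Apply node.
        return list(inputs)

def lifter_before(x, y, ilist):
    # Old style: apply the Op here, before the inputs were tagged.
    return [_FakeGpuIncSubtensor(set_instead_of_inc=False)(x, y, ilist)]

def lifter_after(x, y, ilist):
    # New style: return the bare Op; op_lifter applies it to the tagged inputs.
    return _FakeGpuIncSubtensor(set_instead_of_inc=False)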
@@ -621,7 +619,6 @@ def local_gpua_careduce(node, context_name):
             node.op.scalar_op, axis=node.op.axis,
             dtype=getattr(node.op, 'dtype', None),
             acc_dtype=getattr(node.op, 'acc_dtype', None))
-        x.tag.context_name = context_name
         gvar = greduce(x)
         # We need to have the make node called, otherwise the mask can
         # be None
...
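The last hunk follows from the same change: local_gpua_careduce no longer needs to set x.tag.context_name by hand, since every input of the node being lifted is already tagged by op_lifter before the maker runs.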