提交 f386cdd1 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4820 from nouiz/local_gpua_careduce

Local gpua careduce
...@@ -373,6 +373,14 @@ class GraphToGPU(NavigatorOptimizer): ...@@ -373,6 +373,14 @@ class GraphToGPU(NavigatorOptimizer):
if new_ops: if new_ops:
node_created[lopt] += len(graph.ops([mapping[i] for i in node.inputs], outputs)) node_created[lopt] += len(graph.ops([mapping[i] for i in node.inputs], outputs))
if any([getattr(old_o, 'dtype', None) != getattr(new_o, 'dtype', None)
for old_o, new_o in zip(outputs, node.outputs)]):
_logger.warning(
"The optimization %s returned bad dtype. Skipping it."
" Write to theano-dev mailing list about this." %
str(lopt))
newnode = node.clone_with_new_inputs([mapping.get(i) for i in node.inputs])
outputs = newnode.outputs
for new_o, old_o in zip(outputs, node.outputs): for new_o, old_o in zip(outputs, node.outputs):
assert len(outputs) == len(node.outputs) assert len(outputs) == len(node.outputs)
...@@ -1009,10 +1017,9 @@ def local_gpua_careduce(op, context_name, inputs, outputs): ...@@ -1009,10 +1017,9 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
else: else:
return False return False
x, = inputs x, = inputs
greduce = op2( greduce = op2(
op.scalar_op, axis=op.axis, op.scalar_op, axis=op.axis,
dtype=getattr(op, 'dtype', None), dtype=getattr(op, 'dtype', outputs[0].dtype),
acc_dtype=getattr(op, 'acc_dtype', None)) acc_dtype=getattr(op, 'acc_dtype', None))
gvar = greduce(x) gvar = greduce(x)
# We need to have the make node called, otherwise the mask can # We need to have the make node called, otherwise the mask can
...@@ -1051,7 +1058,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs): ...@@ -1051,7 +1058,7 @@ def local_gpua_careduce(op, context_name, inputs, outputs):
greduce = op2( greduce = op2(
op.scalar_op, op.scalar_op,
axis=new_axis, reduce_mask=new_mask, axis=new_axis, reduce_mask=new_mask,
dtype=getattr(op, 'dtype', None), dtype=getattr(op, 'dtype', outputs[0].dtype),
acc_dtype=getattr(op, 'acc_dtype', None)) acc_dtype=getattr(op, 'acc_dtype', None))
reshaped_x = x.reshape(tensor.stack(new_in_shp)) reshaped_x = x.reshape(tensor.stack(new_in_shp))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论