提交 f8bbacbd authored 作者: carriepl's avatar carriepl

Split local_log_softmax_dnn into two separate op lifters

上级 487cf52c
...@@ -1459,7 +1459,7 @@ def local_softmax_dnn(node): ...@@ -1459,7 +1459,7 @@ def local_softmax_dnn(node):
@register_opt('cudnn') @register_opt('cudnn')
@local_optimizer([GpuElemwise, LogSoftmax]) @local_optimizer([GpuElemwise])
def local_log_softmax_dnn(node): def local_log_softmax_dnn(node):
if version() < 3000: if version() < 3000:
# No log-softmax before cudnn v3 # No log-softmax before cudnn v3
...@@ -1474,8 +1474,16 @@ def local_log_softmax_dnn(node): ...@@ -1474,8 +1474,16 @@ def local_log_softmax_dnn(node):
new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode) new_softmax = GpuDnnSoftmax('log', softmax_node.op.mode)
return [new_softmax(softmax_node.inputs[0])] return [new_softmax(softmax_node.inputs[0])]
elif (isinstance(node.op, LogSoftmax) and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, HostFromGpu)): @register_opt('cudnn')
@op_lifter([LogSoftmax])
def local_logsoftmax_to_dnn(node, ctx_name):
if not dnn_available(ctx_name) or version() < 3:
# No log-softmax before cudnn v3
return
if (isinstance(node.op, LogSoftmax) and node.inputs[0].owner and
isinstance(node.inputs[0].owner.op, HostFromGpu)):
# Transform the input in the format expected by GpuDnnSoftmax # Transform the input in the format expected by GpuDnnSoftmax
inp = node.inputs[0].owner.inputs[0] inp = node.inputs[0].owner.inputs[0]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论