提交 73fbb214 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Fix path for BatchedDot

上级 f87853c5
......@@ -156,7 +156,7 @@ cpu_ops_moved_to_gpu = [
tensor.Reshape, tensor.flatten, tensor.Subtensor,
tensor.AdvancedSubtensor1, tensor.AdvancedIncSubtensor1,
tensor.IncSubtensor, tensor.Shape, tensor.Join,
tensor.Alloc, tensor.Eye, tensor.BatchedDot]
tensor.Alloc, tensor.Eye, tensor.blas.BatchedDot]
class InputToGpuOptimizer(Optimizer):
......@@ -614,7 +614,7 @@ def local_gpu_dot22(node):
@register_opt()
@local_optimizer([gpu_from_host, tensor.BatchedDot])
@local_optimizer([gpu_from_host, tensor.blas.BatchedDot])
def local_gpu_batched_dot(node):
"""
gpu_from_host(batched_dot) -> gpu_batched_dot(gpu_from_host)
......@@ -641,10 +641,10 @@ def local_gpu_batched_dot(node):
if isinstance(node.op, GpuFromHost):
host_input = node.inputs[0]
if host_input.owner and isinstance(host_input.owner.op,
tensor.BatchedDot):
tensor.blas.BatchedDot):
x, y = host_input.owner.inputs
return [gpu_batched_dot(x, y)]
if isinstance(node.op, tensor.BatchedDot):
if isinstance(node.op, tensor.blas.BatchedDot):
if any([(i.owner and isinstance(i.owner.op, HostFromGpu))
for i in node.inputs]):
x, y = node.inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论