提交 9f6f3579 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Change the test in op_lifter to be really OP-specific.

上级 9f48c753
...@@ -50,7 +50,7 @@ def op_lifter(OP): ...@@ -50,7 +50,7 @@ def op_lifter(OP):
""" """
def f(maker): def f(maker):
def local_opt(node): def local_opt(node):
if isinstance(node.op, OP): if type(node.op) is OP:
# This does not support nodes that have more than one output. # This does not support nodes that have more than one output.
assert len(node.outputs) == 1 assert len(node.outputs) == 1
# either one of our inputs is on the gpu or # either one of our inputs is on the gpu or
...@@ -128,8 +128,6 @@ def local_gpualloc(node): ...@@ -128,8 +128,6 @@ def local_gpualloc(node):
def local_gpureshape(node): def local_gpureshape(node):
op = node.op op = node.op
name = op.name name = op.name
if type(node.op) is not tensor.Reshape:
return None
if name: if name:
name = 'Gpu' + name name = 'Gpu' + name
res = GpuReshape(op.ndim, op.name) res = GpuReshape(op.ndim, op.name)
...@@ -140,8 +138,6 @@ def local_gpureshape(node): ...@@ -140,8 +138,6 @@ def local_gpureshape(node):
@op_lifter(tensor.Flatten) @op_lifter(tensor.Flatten)
def local_gpuflatten(node): def local_gpuflatten(node):
op = node.op op = node.op
if type(node.op) is not tensor.Flatten:
return None
if op.outdim != 1: if op.outdim != 1:
return None return None
res = GpuReshape(op.outdim, None) res = GpuReshape(op.outdim, None)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论