Swapped node.inputs[0].data with theano.tensor.extract_constant(node.inputs[0])

上级 403ab069
...@@ -2942,7 +2942,7 @@ class GpuJoin(tensor.Join, GpuOp): ...@@ -2942,7 +2942,7 @@ class GpuJoin(tensor.Join, GpuOp):
out[0] = rval out[0] = rval
def c_code(self, node, name, inputs, out_, sub): def c_code(self, node, name, inputs, out_, sub):
if node.inputs[0].data not in [0, 1]: if theano.tensor.extract_constant(node.inputs[0]) not in [0, 1]:
raise NotImplementedError() raise NotImplementedError()
# only works for the first two axis # only works for the first two axis
if len(inputs) != 3: if len(inputs) != 3:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论