Commit 4ab51bb5 authored by Razvan Pascanu

Fix for the replace-join-with-gpu_join bug.

Consecutive calls to the Join op (during the optimization phase) receive the axis as a constant. Previously, in such cases Theano would not use that constant to define the output's broadcastable pattern.
Parent a7ab337e
@@ -3895,6 +3895,15 @@ class Join(Op):
         # the loops.
         bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
         ndim = len(bcastable)
+        # Axis can also be a constant
+        if isinstance(axis, Constant):
+            try:
+                # Note: `get_constant_value` returns an ndarray, not
+                # an int
+                axis = int(get_constant_value(axis))
+            except TypeError:
+                pass
         if isinstance(axis, int):
             # Basically, broadcastable -> length 1, but the converse does not
             # hold. So we permit e.g. T/F/T joins, and if they fail at runtime
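The idea behind the patch can be illustrated with a small standalone sketch. This is not Theano's actual implementation: `join_broadcastable` is a hypothetical helper, and a 0-d `numpy.ndarray` stands in for the `Constant` axis that `get_constant_value` would unwrap. The point is that once the constant axis is coerced to a plain `int`, the `isinstance(axis, int)` branch that computes the broadcastable pattern can actually run.

```python
import numpy as np


def join_broadcastable(axis, input_shapes):
    """Hypothetical sketch: compute a join output's broadcastable
    pattern when the axis may arrive as a constant (0-d ndarray)."""
    # Start pessimistic: no dimension is assumed broadcastable.
    bcastable = [False] * len(input_shapes[0])

    # A constant axis arrives wrapped (here: a 0-d ndarray, standing
    # in for Theano's Constant). Coerce it to int so the branch
    # below is taken instead of silently skipped.
    if isinstance(axis, np.ndarray):
        axis = int(axis)

    if isinstance(axis, int):
        # Broadcastable -> length 1, but not conversely: a dimension
        # other than the join axis stays broadcastable only if it has
        # length 1 in every input.
        for d in range(len(bcastable)):
            if d != axis:
                bcastable[d] = all(shape[d] == 1 for shape in input_shapes)
    return bcastable
```

Without the coercion step, a constant axis would fall through both `isinstance` checks and the pattern would stay all-`False`, which is the behavior the commit fixes.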