提交 13a54c99 authored 作者: Frederic Bastien's avatar Frederic Bastien

Added test following code review and generalized the change done in that commit.

上级 06a256ea
......@@ -4011,7 +4011,7 @@ class Join(Op):
bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
ndim = len(bcastable)
# Axis can also be a constant
if isinstance(axis, Constant):
if not isinstance(axis, int):
try:
# Note : `get_constant_value` returns a ndarray not a
# int
......@@ -5357,4 +5357,3 @@ def outer(x, y):
return dot(
x.dimshuffle(0, 'x'),
y.dimshuffle('x', 0))
......@@ -2553,6 +2553,17 @@ class T_Join_and_Split(unittest.TestCase):
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
# Opt can remplace the int by a Theano constant
c = join(theano.tensor.constant(1), a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
# In case futur opt insert other useless stuff
c = join(theano.tensor.cast(theano.tensor.constant(1), dtype="int32"),
a, b)
assert c.type.broadcastable[0] and c.type.broadcastable[2]
assert not c.type.broadcastable[1]
f = function([a,b], c)
rng = numpy.random.RandomState(seed=utt.fetch_seed())
a_val = rng.rand(1, 4, 1).astype(config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论