提交 aed7bd08 authored 作者: John Salvatier's avatar John Salvatier

fixed inc_subtensor bug with reshape

上级 23040281
...@@ -5084,6 +5084,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -5084,6 +5084,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
""" """
# First of all, y cannot have a higher dimension than x, # First of all, y cannot have a higher dimension than x,
# nor have non-broadcastable dimensions where x is broadcastable. # nor have non-broadcastable dimensions where x is broadcastable.
x = as_tensor_variable(x) x = as_tensor_variable(x)
y = as_tensor_variable(y) y = as_tensor_variable(y)
...@@ -5154,7 +5155,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False, ...@@ -5154,7 +5155,7 @@ def inc_subtensor(x, y, inplace=False, set_instead_of_inc=False,
# Try to apply inc_subtensor on inner_x. # Try to apply inc_subtensor on inner_x.
# If it works, there is no need to reshape, as the inc_subtensor # If it works, there is no need to reshape, as the inc_subtensor
# will have the same shape as inner_x, which is what we want. # will have the same shape as inner_x, which is what we want.
inner_incsubtensor = inc_subtensor(inner_x, y, inner_incsubtensor = inc_subtensor(inner_x, y.flatten(),
inplace=inplace, inplace=inplace,
set_instead_of_inc=set_instead_of_inc, set_instead_of_inc=set_instead_of_inc,
tolerate_inplace_aliasing=tolerate_inplace_aliasing) tolerate_inplace_aliasing=tolerate_inplace_aliasing)
......
...@@ -3756,7 +3756,8 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -3756,7 +3756,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
if inplace_increment is None: if inplace_increment is None:
raise inplace_increment_missing raise inplace_increment_missing
a = inc_subtensor(self.v[self.ix2], self.v[self.ix2]) subt = self.v[self.ix2]
a = inc_subtensor(subt,subt)
assert a.type == self.v.type, (a.type, self.v.type) assert a.type == self.v.type, (a.type, self.v.type)
f = theano.function([self.v, self.ix2], a, allow_input_downcast=True) f = theano.function([self.v, self.ix2], a, allow_input_downcast=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论