提交 f35f6ba9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

...@@ -346,7 +346,7 @@ def local_IncSubtensor_serialize(node): ...@@ -346,7 +346,7 @@ def local_IncSubtensor_serialize(node):
# #
# add(x, incsubtensor(b, c), incsubtensor(b, d)) # add(x, incsubtensor(b, c), incsubtensor(b, d))
# -> incsubtensor(incsubtensor(add(x,b), c), d) # -> incsubtensor(incsubtensor(add(x,b,b), c), d)
""" """
def movable(i): def movable(i):
...@@ -354,7 +354,8 @@ def local_IncSubtensor_serialize(node): ...@@ -354,7 +354,8 @@ def local_IncSubtensor_serialize(node):
return i.owner \ return i.owner \
and isinstance(i.owner.op, T.IncSubtensor) \ and isinstance(i.owner.op, T.IncSubtensor) \
and i.type == o_type \ and i.type == o_type \
and len(i.clients) == 1 and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc
if node.op == T.add: if node.op == T.add:
o_type = node.outputs[0].type o_type = node.outputs[0].type
...@@ -383,7 +384,8 @@ def local_IncSubtensor_serialize(node): ...@@ -383,7 +384,8 @@ def local_IncSubtensor_serialize(node):
@gof.local_optimizer([None]) @gof.local_optimizer([None])
def local_inplace_setsubtensor(node): def local_inplace_setsubtensor(node):
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace: if isinstance(node.op, T.IncSubtensor) and not node.op.inplace:
new_op = T.IncSubtensor(node.op.idx_list, inplace=True) new_op = T.IncSubtensor(node.op.idx_list, inplace=True, \
set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs) new_node = new_op(*node.inputs)
return [new_node] return [new_node]
return False return False
......
...@@ -309,7 +309,9 @@ def permutation_helper(random_state, n, shape): ...@@ -309,7 +309,9 @@ def permutation_helper(random_state, n, shape):
""" """
# n should be a 0-dimension array # n should be a 0-dimension array
assert n.shape == () assert n.shape == ()
n = n.item() # Note that it is important to convert `n` into an integer, because if it
# is a long, the numpy permutation function will crash on Windows.
n = int(n.item())
out_shape = list(shape) out_shape = list(shape)
out_shape.append(n) out_shape.append(n)
......
...@@ -35,7 +35,7 @@ class ScalarSharedVariable(SharedVariable, _tensor_py_operators): ...@@ -35,7 +35,7 @@ class ScalarSharedVariable(SharedVariable, _tensor_py_operators):
@shared_constructor @shared_constructor
def scalar_constructor(value, name=None, strict=False, dtype=None): def scalar_constructor(value, name=None, strict=False, dtype=None):
"""SharedVariable constructor for scalar values. Defaults to int64 or float64. """SharedVariable constructor for scalar values. Default: int64 or float64.
:note: We implement this using 0-d tensors for now. :note: We implement this using 0-d tensors for now.
...@@ -50,12 +50,14 @@ def scalar_constructor(value, name=None, strict=False, dtype=None): ...@@ -50,12 +50,14 @@ def scalar_constructor(value, name=None, strict=False, dtype=None):
else: else:
dtype = type(value).__name__ dtype = type(value).__name__
type = TensorType(dtype=dtype, broadcastable=[]) tensor_type = TensorType(dtype=dtype, broadcastable=[])
try: try:
# don't pass the dtype to asarray because we want this to fail if strict is True and the # Do not pass the dtype to asarray because we want this to fail if
# types do not match # strict is True and the types do not match.
rval = ScalarSharedVariable(type=type, value=numpy.asarray(value), name=name, strict=strict) rval = ScalarSharedVariable(type=tensor_type,
value=numpy.asarray(value),
name=name, strict=strict)
return rval return rval
except: except:
traceback.print_exc() traceback.print_exc()
......
...@@ -277,12 +277,12 @@ class T_RandomStreams(unittest.TestCase): ...@@ -277,12 +277,12 @@ class T_RandomStreams(unittest.TestCase):
assert numpy.all(fn_val1 == numpy_val1) assert numpy.all(fn_val1 == numpy_val1)
def test_shuffle_row_elements(self): def test_shuffle_row_elements(self):
"""Test that RandomStreams.shuffle_row_elements generates the right results""" """Ensure RandomStreams.shuffle_row_elements generates right results"""
# Check over two calls to see if the random state is correctly updated. # Check over two calls to see if the random state is correctly updated.
# On matrices, for each row, the elements of that row should be
# On matrices, for each row, the elements of that row should be shuffled. # shuffled.
# Note that this differs from numpy.random.shuffle, where all the elements # Note that this differs from numpy.random.shuffle, where all the
# of the matrix are shuffled. # elements of the matrix are shuffled.
mm = Module() mm = Module()
mm.random = RandomStreams(234) mm.random = RandomStreams(234)
m_input = tensor.dmatrix() m_input = tensor.dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论