提交 f35f6ba9 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

......@@ -346,7 +346,7 @@ def local_IncSubtensor_serialize(node):
#
# add(x, incsubtensor(b, c), incsubtensor(b, d))
# -> incsubtensor(incsubtensor(add(x,b), c), d)
# -> incsubtensor(incsubtensor(add(x,b,b), c), d)
"""
def movable(i):
......@@ -354,7 +354,8 @@ def local_IncSubtensor_serialize(node):
return i.owner \
and isinstance(i.owner.op, T.IncSubtensor) \
and i.type == o_type \
and len(i.clients) == 1
and len(i.clients) == 1 \
and not i.owner.op.set_instead_of_inc
if node.op == T.add:
o_type = node.outputs[0].type
......@@ -383,7 +384,8 @@ def local_IncSubtensor_serialize(node):
@gof.local_optimizer([None])
def local_inplace_setsubtensor(node):
if isinstance(node.op, T.IncSubtensor) and not node.op.inplace:
new_op = T.IncSubtensor(node.op.idx_list, inplace=True)
new_op = T.IncSubtensor(node.op.idx_list, inplace=True, \
set_instead_of_inc=node.op.set_instead_of_inc)
new_node = new_op(*node.inputs)
return [new_node]
return False
......
......@@ -309,7 +309,9 @@ def permutation_helper(random_state, n, shape):
"""
# n should be a 0-dimension array
assert n.shape == ()
n = n.item()
# Note that it is important to convert `n` into an integer, because if it
# is a long, the numpy permutation function will crash on Windows.
n = int(n.item())
out_shape = list(shape)
out_shape.append(n)
......
......@@ -35,7 +35,7 @@ class ScalarSharedVariable(SharedVariable, _tensor_py_operators):
@shared_constructor
def scalar_constructor(value, name=None, strict=False, dtype=None):
"""SharedVariable constructor for scalar values. Defaults to int64 or float64.
"""SharedVariable constructor for scalar values. Default: int64 or float64.
:note: We implement this using 0-d tensors for now.
......@@ -50,12 +50,14 @@ def scalar_constructor(value, name=None, strict=False, dtype=None):
else:
dtype = type(value).__name__
type = TensorType(dtype=dtype, broadcastable=[])
tensor_type = TensorType(dtype=dtype, broadcastable=[])
try:
# don't pass the dtype to asarray because we want this to fail if strict is True and the
# types do not match
rval = ScalarSharedVariable(type=type, value=numpy.asarray(value), name=name, strict=strict)
# Do not pass the dtype to asarray because we want this to fail if
# strict is True and the types do not match.
rval = ScalarSharedVariable(type=tensor_type,
value=numpy.asarray(value),
name=name, strict=strict)
return rval
except:
traceback.print_exc()
......
......@@ -277,12 +277,12 @@ class T_RandomStreams(unittest.TestCase):
assert numpy.all(fn_val1 == numpy_val1)
def test_shuffle_row_elements(self):
"""Test that RandomStreams.shuffle_row_elements generates the right results"""
"""Ensure RandomStreams.shuffle_row_elements generates right results"""
# Check over two calls to see if the random state is correctly updated.
# On matrices, for each row, the elements of that row should be shuffled.
# Note that this differs from numpy.random.shuffle, where all the elements
# of the matrix are shuffled.
# On matrices, for each row, the elements of that row should be
# shuffled.
# Note that this differs from numpy.random.shuffle, where all the
# elements of the matrix are shuffled.
mm = Module()
mm.random = RandomStreams(234)
m_input = tensor.dmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论