提交 dc5617c1 authored 作者: lamblin's avatar lamblin

Merge pull request #1341 from pascanur/recent_scan_bugs

Recent scan bugs
...@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer): ...@@ -1159,7 +1159,7 @@ class ScanMerge(gof.Optimizer):
Questionable, we should also consider profile ? Questionable, we should also consider profile ?
""" """
rep = set_nodes[0] rep = set_nodes[0]
if not rep.op.as_while and node.op.as_while: if rep.op.as_while != node.op.as_while:
return False return False
nsteps = node.inputs[0] nsteps = node.inputs[0]
......
...@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None): ...@@ -48,7 +48,10 @@ def safe_new(x, tag='', dtype=None):
nw_name = None nw_name = None
if isinstance(x, theano.Constant): if isinstance(x, theano.Constant):
if dtype and x.dtype != dtype: if dtype and x.dtype != dtype:
return x.clone().astype(dtype) casted_x = x.astype(dtype)
nwx = x.__class__(casted_x.type, x.data, x.name)
nwx.tag = copy(x.tag)
return nwx
else: else:
return x.clone() return x.clone()
# Note, as_tensor_variable will convert the Scalar into a # Note, as_tensor_variable will convert the Scalar into a
...@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None): ...@@ -70,6 +73,8 @@ def safe_new(x, tag='', dtype=None):
# ndarrays # ndarrays
pass pass
nw_x = x.type() nw_x = x.type()
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype).type()
nw_x.name = nw_name nw_x.name = nw_name
# Preserve test values so that the 'compute_test_value' option can be used. # Preserve test values so that the 'compute_test_value' option can be used.
# The test value is deep-copied to ensure there can be no interactions # The test value is deep-copied to ensure there can be no interactions
...@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None): ...@@ -82,9 +87,6 @@ def safe_new(x, tag='', dtype=None):
# This means `x` has no test value. # This means `x` has no test value.
pass pass
if dtype and nw_x.dtype != dtype:
nw_x = nw_x.astype(dtype)
return nw_x return nw_x
......
...@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase): ...@@ -3452,6 +3452,39 @@ class T_Scan(unittest.TestCase):
assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True), assert numpy.allclose(test(x, tensor.sum((x+1)**2), mention_y=True),
1.21000003815) 1.21000003815)
def test_grad_find_input(self):
w = theano.shared(numpy.array(0, dtype='float32'), name='w')
init = tensor.fscalar('init')
out, _ = theano.scan(
fn=lambda prev: w,
outputs_info=init,
n_steps=2,
)
tensor.grad(out[-1], w)
def test_scan_merge_nodes(self):
inps = tensor.vector()
state = tensor.scalar()
y1, _ = theano.scan(lambda x,y: x*y,
sequences = inps,
outputs_info = state,
n_steps = 5)
y2, _ = theano.scan(lambda x,y : (x+y, theano.scan_module.until(x>0)),
sequences = inps,
outputs_info = state,
n_steps = 5)
scan_node1 = y1.owner.inputs[0].owner
assert isinstance(scan_node1.op, theano.scan_module.scan_op.Scan)
scan_node2 = y2.owner.inputs[0].owner
assert isinstance(scan_node2.op, theano.scan_module.scan_op.Scan)
opt_obj = theano.scan_module.scan_opt.ScanMerge()
# Test the method belongs_to of this class. Specifically see if it
# detects the two scan_nodes as not being similar
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])
def test_speed(): def test_speed():
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论