提交 7f17dde5 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

merge.

......@@ -193,6 +193,8 @@ class Mode(object):
def including(self, *tags):
link, opt = self.get_linker_optimizer(self.provided_linker, self.provided_optimizer)
#N.B. opt might be a Query instance, not sure what else it might be...
# string? Optimizer? OptDB? who knows???
return self.__class__(linker=link, optimizer=opt.including(*tags))
def excluding(self, *tags):
......
......@@ -98,10 +98,10 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
clone_d = {}
# Updates as list and dictionary.
# They will also store the 'default_update' expressions applicable.
# The dictionary is used to look up the existence of the keys, and to store
# the final (cloned) update expressions.
# The list of pairs is used to iterate in a consistent order while adding
# They will both store the 'default_update' expressions (where applicable).
# The dictionary (update_d) is used to look up the existence of the keys, and to store
# the final [cloned] update expressions.
# The list of pairs (update_expr) is used to iterate in a consistent order while adding
# new pairs.
update_d = {}
update_expr = []
......@@ -109,10 +109,11 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
shared_inputs = []
def clone_v_get_shared_updates(v):
'''Clone a variable and its inputs, until all are in clone_d.
'''Clone a variable and its inputs recursively until all are in clone_d.
Also appends all shared variables met along the way to shared_inputs,
and their default_update (if applicable) to update_d and update_expr.
'''
# this method co-recurses with clone_a
assert v is not None
if v.owner:
clone_a(v.owner)
......@@ -137,6 +138,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
return clone_d.setdefault(v, v)
def clone_a(a):
# this method co-recurses with clone_v_get_shared_updates
if a is None:
return None
if a not in clone_d:
......@@ -174,12 +176,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
#set_of_param_variables = set(input_variables)
# It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also
# potentially ambiguous and dangerous. This restriction may be revisited in
# the future if there is a need for such a feature.
# used as function inputs. Although it is technically possible, it is also not clear
# when/how to use the value of that shared variable (is it a default? ignored?, if the
# shared variable changes, does that function default also change?).
if numpy.any([isinstance(v, SharedVariable) for v in input_variables]):
raise TypeError('Cannot use a shared variable (%s) as explicit input '
% v)
raise TypeError(('Cannot use a shared variable (%s) as explicit input.'
' Consider substituting a non-shared'
' variable via the `givens` parameter') % v)
# Fill update_d and update_expr with provided updates
for (store_into, update_val) in iter_over_pairs(updates):
......@@ -189,7 +192,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
raise ValueError('this shared variable already has an update expression',
(store_into, update_d[store_into]))
update_val = store_into.filter_update(update_val)
update_val = store_into.filter_update(update_val) # typically this might be a cast()
if update_val.type != store_into.type:
raise TypeError('an update must have the same type as the original shared variable',
(store_into, store_into.type,
......@@ -224,7 +227,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
cloned_outputs = Out(cloned_v, borrow=outputs.borrow)
#computed_list.append(cloned_v)
elif outputs is None:
cloned_outputs = [] # TODO: return None
cloned_outputs = [] # TODO: get Function.__call__ to return None
else:
raise TypeError('output must be a theano Variable or Out instance (or list of them)', outputs)
......
......@@ -905,36 +905,39 @@ class test_fusion(unittest.TestCase):
#g.owner.inputs[0] is out... make owner a weakref?
def test_log1p():
m = theano.compile.default_mode
if m == 'FAST_COMPILE':
m = 'FAST_RUN'
# check some basic cases
x = dvector()
f = function([x], T.log(1+(x)), mode='FAST_RUN')
f = function([x], T.log(1+(x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
f = (function([x], T.log(1+(-x))), mode='FAST_RUN')
f = function([x], T.log(1+(-x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace]
f = (function([x], -T.log(1+(-x))), mode='FAST_RUN')
f = function([x], -T.log(1+(-x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.neg, inplace.log1p_inplace, inplace.neg_inplace]
# check trickier cases (and use different dtype)
y = fmatrix()
f = (function([x,y], T.log(fill(y,1)+(x))), mode='FAST_RUN')
f = function([x,y], T.log(fill(y,1)+(x)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(0+(x) + fill(y,1.0) )), mode='FAST_RUN')
f = function([x,y], T.log(0+(x) + fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f = (function([x,y], T.log(2+(x) - fill(y,1.0) )), mode='FAST_RUN')
f = function([x,y], T.log(2+(x) - fill(y,1.0)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.DimShuffle([False], ['x', 0], True), T.log1p, T.fill]
f([1e-7, 10], [[0, 0], [0, 0]]) #debugmode will verify values
# should work for complex
z = zmatrix()
f = function([z], T.log(1+(z)), mode='FAST_RUN')
f = function([z], T.log(1+(z)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
# should work for int
z = imatrix()
f = function([z], T.log(1+(z)), mode='FAST_RUN')
f = function([z], T.log(1+(z)), mode=m)
assert [node.op for node in f.maker.env.toposort()] == [T.log1p]
if __name__ == '__main__':
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论