提交 933aa55b authored 作者: Frederic Bastien's avatar Frederic Bastien
......@@ -70,15 +70,17 @@ def cloned_env(inputs, outputs):
env = gof.env.Env(inputs, outputs)
return env
def std_env(inputs, outputs, disown_inputs = False):
def std_env(inputs, outputs, disown_inputs = False,
use_destroy_handler = True):
inputs, outputs = gof.graph.clone(inputs, outputs)
_mark_indestructible(outputs)
env = gof.env.Env(inputs, outputs)
env.extend(gof.DestroyHandler())
if use_destroy_handler:
env.extend(gof.DestroyHandler())
env.extend(gof.ReplaceValidate())
env.validate()
for input in inputs:
input.destroyed_by_user = len(env.destroyers(input)) != 0
input.destroyed_by_user = use_destroy_handler and len(env.destroyers(input)) != 0
if not input.destroyed_by_user and not disown_inputs:
# prevent optimizations from destroying the inputs
input.tag.indestructible = True
......@@ -97,13 +99,15 @@ predefined_linkers = {
class FunctionFactory:
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False):
def __init__(self, inputs, outputs, linker = 'py', optimizer = std_opt, borrow_outputs = False, disown_inputs = False,
use_destroy_handler = True):
if len(inputs) != len(set(inputs)):
print >>sys.stderr, "Warning: duplicate inputs"
for r in list(inputs) + list(outputs):
if not isinstance(r, gof.Result):
raise TypeError("All inputs and outputs to FunctionFactory should be Result instances. Received:", type(r), r)
env = std_env(inputs, outputs, disown_inputs = disown_inputs)
env = std_env(inputs, outputs, disown_inputs = disown_inputs,
use_destroy_handler = use_destroy_handler)
if None is not optimizer:
optimizer(env)
env.validate()
......@@ -144,13 +148,15 @@ def function(inputs,
disown_inputs = False,
profiler = None,
unpack_single = True,
strict = 'if_destroyed'):
strict = 'if_destroyed',
use_destroy_handler = True):
ff = FunctionFactory(inputs,
outputs,
linker = linker,
optimizer = optimizer,
borrow_outputs = borrow_outputs,
disown_inputs = disown_inputs)
disown_inputs = disown_inputs,
use_destroy_handler = use_destroy_handler)
return ff.create(profiler = profiler,
unpack_single = unpack_single,
strict = strict)
......
......@@ -272,6 +272,12 @@ class Elemwise(Op):
outputs = [Tensor(dtype = dtype, broadcastable = broadcastable)() for dtype, broadcastable in zip(out_dtypes, out_broadcastables)]
return Apply(self, inputs, outputs)
def __eq__(self, other):
    """Two Elemwise ops are equal when they are of the same concrete type
    and wrap the same scalar op with the same in-place pattern."""
    if type(self) != type(other):
        return False
    return (self.scalar_op == other.scalar_op
            and self.inplace_pattern == other.inplace_pattern)
def __hash__(self):
    """Hash consistent with __eq__: combine the scalar op's hash with the
    in-place pattern's hash."""
    scalar_part = hash(self.scalar_op)
    pattern_part = hash(self.inplace_pattern)
    return scalar_part ^ pattern_part
def __str__(self):
if self.name is None:
if self.inplace_pattern:
......
......@@ -670,7 +670,7 @@ class Tan(UnaryScalarOp):
def impl(self, x):
    """Python-side evaluation of this op: the tangent of ``x``."""
    result = math.tan(x)
    return result
def grad(self, inputs, output_gradients):
    """Gradient of tan: d/dx tan(x) = 1 / cos(x)**2, so the input gradient
    is gz / sqr(cos(x)).

    Fixes two issues: the diff residue contained BOTH the old
    ``gz / (cos(x) ** 2)`` and the new ``gz / sqr(cos(x))`` return lines
    (the second was unreachable) — only the newer ``sqr`` form is kept.
    Tuple-unpacking parameters ``(x, ), (gz, )`` are Python-2-only
    (removed by PEP 3113); unpack inside the body instead.
    """
    (x,) = inputs
    (gz,) = output_gradients
    # Trailing comma: grad returns a tuple with one gradient per input.
    return gz / sqr(cos(x)),
def c_code(self, node, name, inputs, outputs, sub):
    """Return the C statement computing tan of the single input into the
    single output.

    ``inputs`` and ``outputs`` are one-element sequences of C variable
    names. Tuple-unpacking parameters (``(x, ), (z, )``) are
    Python-2-only syntax (removed by PEP 3113), so unpack in the body.
    """
    (x,) = inputs
    (z,) = outputs
    return "%(z)s = tan(%(x)s);" % locals()
tan = Tan(upgrade_to_float, name = 'tan')
......@@ -707,7 +707,7 @@ class Tanh(UnaryScalarOp):
def impl(self, x):
    """Python-side evaluation of this op: the hyperbolic tangent of ``x``."""
    result = math.tanh(x)
    return result
def grad(self, inputs, output_gradients):
    """Gradient of tanh: d/dx tanh(x) = 1 - tanh(x)**2, so the input
    gradient is gz * (1 - sqr(tanh(x))).

    Fixes two issues: the diff residue contained BOTH the old
    ``gz * (1 - tanh(x)**2)`` and the new ``gz * (1 - sqr(tanh(x)))``
    return lines (the second was unreachable) — only the newer ``sqr``
    form is kept. Tuple-unpacking parameters are Python-2-only (removed
    by PEP 3113); unpack inside the body instead.
    """
    (x,) = inputs
    (gz,) = output_gradients
    # Trailing comma: grad returns a tuple with one gradient per input.
    return gz * (1 - sqr(tanh(x))),
def c_code(self, node, name, inputs, outputs, sub):
    """Return the C statement computing tanh of the single input into the
    single output.

    ``inputs`` and ``outputs`` are one-element sequences of C variable
    names. Tuple-unpacking parameters (``(x, ), (z, )``) are
    Python-2-only syntax (removed by PEP 3113), so unpack in the body.
    """
    (x,) = inputs
    (z,) = outputs
    return "%(z)s = tanh(%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name = 'tanh')
......
......@@ -390,7 +390,8 @@ class _tensor_py_operators:
def __iter__(self):
    """Explicitly unsupported: raise instead of letting Python fall back to
    __getitem__-based iteration, which would otherwise let builtins.sum
    silently loop over a tensor.

    The diff residue contained both the old and the new raise statements;
    only the newer, more helpful message is kept.
    """
    # This prevents accidental iteration via builtin.sum(self)
    raise TypeError('Tensor does not support iteration. '
                    'Maybe you are using builtin.sum instead of theano.tensor.sum?')
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请先 注册 或者 登录 后发表评论