提交 e5b8c40c authored 作者: Olivier Breuleux's avatar Olivier Breuleux

merge

差异被折叠。
...@@ -183,6 +183,9 @@ class Result(object): ...@@ -183,6 +183,9 @@ class Result(object):
""" """
return False return False
def c_literal(self):
raise AbstractFunctionError()
def c_declare(self, name, sub): def c_declare(self, name, sub):
""" """
Declares variables that will be instantiated by L{c_extract}. Declares variables that will be instantiated by L{c_extract}.
......
...@@ -85,6 +85,22 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -85,6 +85,22 @@ def grad_sources_inputs(sources, graph_inputs):
output_arg = g_outputs output_arg = g_outputs
input_arg = op.inputs input_arg = op.inputs
try:
dinputs = [x[0] for x in op.destroy_map().values()]
except AttributeError:
dinputs = []
# input_arg = [input in dinputs and input.copy() or input for input in input_arg]
new_input_arg = []
for input in input_arg:
if input in dinputs:
new_input_arg.append(input.copy())
else:
new_input_arg.append(input)
input_arg = new_input_arg
op_grad = op.grad(input_arg, output_arg) op_grad = op.grad(input_arg, output_arg)
if not isinstance(op_grad, (list,tuple)): if not isinstance(op_grad, (list,tuple)):
raise ValueError(_msg_retType, op.__class__) raise ValueError(_msg_retType, op.__class__)
......
...@@ -291,7 +291,7 @@ class Identity(UnaryScalarOp): ...@@ -291,7 +291,7 @@ class Identity(UnaryScalarOp):
return x return x
def c_code(self, (x, ), (z, ), sub): def c_code(self, (x, ), (z, ), sub):
return "%(z)s = %(x)s;" % locals() return "%(z)s = %(x)s;" % locals()
def grad(self, (x, y), (gz, )): def grad(self, (x, ), (gz, )):
return gz, return gz,
class Neg(UnaryScalarOp): class Neg(UnaryScalarOp):
......
...@@ -67,6 +67,10 @@ class Tensor(BaseTensor): ...@@ -67,6 +67,10 @@ class Tensor(BaseTensor):
#SLICING #SLICING
def __getitem__(self, item): return subtensor(self, item) def __getitem__(self, item): return subtensor(self, item)
def __getslice__(self, *args): return subtensor(self, slice(*args)) def __getslice__(self, *args): return subtensor(self, slice(*args))
#COPYING
def copy(self): return tensor_copy(self)
s2t.Tensor = Tensor s2t.Tensor = Tensor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论