提交 721e1885 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

added setdefault op

上级 6a159276
...@@ -142,9 +142,9 @@ class Container(object): ...@@ -142,9 +142,9 @@ class Container(object):
self.storage = storage self.storage = storage
self.readonly = readonly self.readonly = readonly
self.strict = strict self.strict = strict
def __get(self): def __get__(self):
return self.storage[0] return self.storage[0]
def __set(self, value): def __set__(self, value):
if self.readonly: if self.readonly:
raise Exception("Cannot set readonly storage: %s" % self.name) raise Exception("Cannot set readonly storage: %s" % self.name)
try: try:
...@@ -158,8 +158,8 @@ class Container(object): ...@@ -158,8 +158,8 @@ class Container(object):
except Exception, e: except Exception, e:
e.args = e.args + (('Container name "%s"' % self.name),) e.args = e.args + (('Container name "%s"' % self.name),)
raise raise
data = property(__get, __set) data = property(__get__, __set__)
value = property(__get, __set) value = property(__get__, __set__)
def __str__(self): def __str__(self):
return "<" + str(self.storage[0]) + ">" return "<" + str(self.storage[0]) + ">"
def __repr__(self): def __repr__(self):
......
...@@ -359,9 +359,11 @@ class TensorType(Type): ...@@ -359,9 +359,11 @@ class TensorType(Type):
def c_extract(self, name, sub): def c_extract(self, name, sub):
"""Override `CLinkerOp.c_extract` """ """Override `CLinkerOp.c_extract` """
# TODO: make the error message print out the dtype of the
# input received.
return """ return """
%(name)s = NULL; %(name)s = NULL;
type_num_%(name)s = %(type_num)s; type_num_%(name)s = ((PyArrayObject*)py_%(name)s)->descr->type_num; //we expect %(type_num)s
if (py_%(name)s == Py_None) { if (py_%(name)s == Py_None) {
// We can either fail here or set %(name)s to NULL and rely on Ops using // We can either fail here or set %(name)s to NULL and rely on Ops using
// tensors to handle the NULL case, but if they fail to do so they'll end up // tensors to handle the NULL case, but if they fail to do so they'll end up
...@@ -373,7 +375,7 @@ class TensorType(Type): ...@@ -373,7 +375,7 @@ class TensorType(Type):
PyErr_SetString(PyExc_ValueError, "expected an ndarray"); PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%(fail)s %(fail)s
} }
else if (((PyArrayObject*)py_%(name)s)->descr->type_num != %(type_num)s) { else if (type_num_%(name)s != %(type_num)s) {
PyErr_SetString(PyExc_ValueError, "expected %(type_num)s"); PyErr_SetString(PyExc_ValueError, "expected %(type_num)s");
%(fail)s %(fail)s
} }
...@@ -1353,6 +1355,15 @@ class Repeat(gof.Op): ...@@ -1353,6 +1355,15 @@ class Repeat(gof.Op):
repeat = Repeat() repeat = Repeat()
class SetDefault(gof.Op):
    """Return ``x``, or a copy of ``default`` when ``x`` is None.

    Both inputs must have the same type; the output has that type too.
    """

    # BUGFIX(review): the output aliases input 0 (`x` is returned as-is
    # when it is not None), whereas input 1 (`default`) is always copied,
    # so the view must be declared on input 0, not input 1.
    view_map = {0: [0]}

    def make_node(self, x, default):
        # The two inputs must agree on type so the output type is unambiguous.
        assert x.type == default.type
        return gof.Apply(self, [x, default], [default.type()])

    def perform(self, node, inputs, output_storage):
        # BUGFIX(review): the original signature used tuple parameter
        # unpacking (`(x, default)`, `(out, )`), which was removed in
        # Python 3 (PEP 3113); unpack explicitly instead — valid in 2 and 3.
        x, default = inputs
        out, = output_storage
        # Copy `default` so the caller's value is never aliased by the output.
        out[0] = default.copy() if x is None else x


setdefault = SetDefault()
########################## ##########################
...@@ -1759,6 +1770,7 @@ class Split(Op): ...@@ -1759,6 +1770,7 @@ class Split(Op):
"""Join the gradients along the axis that was used to split x.""" """Join the gradients along the axis that was used to split x."""
return [join(axis, *g_outputs), None, None] return [join(axis, *g_outputs), None, None]
class Join(Op): class Join(Op):
""" """
Concatenate multiple `TensorVariable`s along some axis. Concatenate multiple `TensorVariable`s along some axis.
...@@ -2383,6 +2395,7 @@ class Outer(Op): ...@@ -2383,6 +2395,7 @@ class Outer(Op):
return "outer" return "outer"
outer = Outer() outer = Outer()
######################### #########################
# Gradient # Gradient
######################### #########################
......
...@@ -388,10 +388,12 @@ def local_softmax_with_bias(node): ...@@ -388,10 +388,12 @@ def local_softmax_with_bias(node):
vectors = [] vectors = []
non_vectors = [] non_vectors = []
for x_in in x.owner.inputs: for x_in in x.owner.inputs:
if list(x_in.type.broadcastable) == [True, False] \ if list(x_in.type.broadcastable) == [True, False]:
and isinstance(x_in.owner.op, tensor.DimShuffle): if x_in.owner and isinstance(x_in.owner.op, tensor.DimShuffle):
assert len(x_in.owner.inputs)==1 assert len(x_in.owner.inputs)==1
vectors.append(x_in.owner.inputs[0]) vectors.append(x_in.owner.inputs[0])
else:
vectors.append(tensor.DimShuffle((True, False), (1,))(x_in))
else: else:
non_vectors.append(x_in) non_vectors.append(x_in)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论