提交 5ad3c667 authored 作者: Frederic's avatar Frederic

pep8

上级 1f833c24
...@@ -1795,6 +1795,7 @@ shape = Shape() ...@@ -1795,6 +1795,7 @@ shape = Shape()
_shape = shape #was used in the past, now use shape directly. _shape = shape #was used in the past, now use shape directly.
pprint.assign(_shape, printing.MemberPrinter('shape')) pprint.assign(_shape, printing.MemberPrinter('shape'))
class SpecifyShape(Op): class SpecifyShape(Op):
""" """
L{Op} put into the graph the user provided shape L{Op} put into the graph the user provided shape
...@@ -1808,14 +1809,18 @@ class SpecifyShape(Op): ...@@ -1808,14 +1809,18 @@ class SpecifyShape(Op):
@note: We currently don't support specifying partial shape information. @note: We currently don't support specifying partial shape information.
""" """
view_map = {0: [0]} view_map = {0: [0]}
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
def make_node(self, x, shape):
    """Build the Apply node attaching the user-provided `shape` to `x`.

    `x` is wrapped as a tensor variable only when it is not already a
    Variable; `shape` is always converted.  The single output reuses
    x's type, so callers see the same kind of variable back.
    """
    if isinstance(x, Variable):
        x_var = x
    else:
        x_var = as_tensor_variable(x)
    shape_var = as_tensor_variable(shape)
    return Apply(self, [x_var, shape_var], [x_var.type()])
...@@ -1823,22 +1828,22 @@ class SpecifyShape(Op): ...@@ -1823,22 +1828,22 @@ class SpecifyShape(Op):
def perform(self, node, inp, out_):
    """Check the runtime shape of the input and pass it through.

    The output aliases the input (view_map {0: [0]}); no copy is made.
    Raises AssertionError when the actual shape differs from the one
    the user promised.
    """
    x, expected_shape = inp
    out_storage, = out_
    assert numpy.all(x.shape == expected_shape), (
        "got shape", x.shape, "expected", expected_shape)
    out_storage[0] = x
def infer_shape(self, node, shapes):
    """Return the output shape, preferring the user-specified entries.

    For each dimension, try to extract a constant from the `shape`
    input (inputs[1]); when that fails (TypeError from
    get_constant_value on a non-constant), fall back to the symbolic
    entry of the shape input itself.
    """
    xshape, sshape = shapes  # sshape (shape of the shape vector) is unused
    new_shape = []
    for dim in range(node.inputs[0].ndim):
        try:
            s = get_constant_value(node.inputs[1][dim])
            new_shape.append(as_tensor_variable(s))
        except TypeError:
            # Not a compile-time constant: keep the symbolic entry.
            # Fixed: the old `except TypeError, e:` bound an unused `e`
            # with Python-2-only syntax; this form works everywhere.
            new_shape.append(node.inputs[1][dim])
    assert len(new_shape) == len(xshape)
    return [new_shape]
def grad(self, inp, grads): def grad(self, inp, grads):
...@@ -1847,9 +1852,10 @@ class SpecifyShape(Op): ...@@ -1847,9 +1852,10 @@ class SpecifyShape(Op):
# Should I set an SpecifyShape on gz? I think so # Should I set an SpecifyShape on gz? I think so
# But I don't do it now as we need to make an optimization # But I don't do it now as we need to make an optimization
# to remove that op from the graph to don't block other optimization # to remove that op from the graph to don't block other optimization
# Should I do an optimizer that will remove the SpecifyShape? I think Yes # Should I do an optimizer that will remove the SpecifyShape?
# I think Yes
return [gz, None] return [gz, None]
return [specify_shape(gz,s), None] return [specify_shape(gz, s), None]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if eval_points[0] is None: if eval_points[0] is None:
...@@ -1860,31 +1866,35 @@ class SpecifyShape(Op): ...@@ -1860,31 +1866,35 @@ class SpecifyShape(Op):
specify_shape = SpecifyShape() specify_shape = SpecifyShape()
class MaxAndArgmax(Op): class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis. """Calculate the max and argmax over a given axis.
""" """
nin=2 # tensor, axis nin = 2 # tensor, axis
nout=2 # max val, max idx nout = 2 # max val, max idx
E_axis = 'invalid axis' E_axis = 'invalid axis'
def __eq__(self,other): def __eq__(self, other):
return type(self)==type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
def make_node(self, x, axis=None): def make_node(self, x, axis=None):
x = _as_tensor_variable(x) x = _as_tensor_variable(x)
if isinstance(axis,int): if isinstance(axis, int):
axis = [axis] axis = [axis]
elif isinstance(axis,(tuple,list)): elif isinstance(axis, (tuple, list)):
assert len(axis)==1,"MaxAndArgmax don't support multiple axis. the max fct support it." assert len(axis) == 1, ("MaxAndArgmax don't support multiple"
#we make the axis all positive to make the infer_shape work with negative axis " axis. the max fct support it.")
if x.type.ndim>0 and axis is not None: # we make the axis all positive to make the infer_shape work
for id,a in enumerate(axis): # with negative axis
if not isinstance(a, TensorVariable) and a<0: if x.type.ndim > 0 and axis is not None:
if -a>x.type.ndim: for id, a in enumerate(axis):
if not isinstance(a, TensorVariable) and a < 0:
if -a > x.type.ndim:
raise ValueError('axis out of range') raise ValueError('axis out of range')
axis[id]=x.type.ndim+a axis[id] = x.type.ndim + a
if axis is None: if axis is None:
axis = _as_tensor_variable(range(x.type.ndim)) axis = _as_tensor_variable(range(x.type.ndim))
else: else:
...@@ -1893,9 +1903,10 @@ class MaxAndArgmax(Op): ...@@ -1893,9 +1903,10 @@ class MaxAndArgmax(Op):
inputs = [x, axis] inputs = [x, axis]
#TODO: figure things out if axis is a constant #TODO: figure things out if axis is a constant
broadcastable = [False] * (x.type.ndim - 1) broadcastable = [False] * (x.type.ndim - 1)
outputs = [tensor(x.type.dtype, broadcastable,name='max'), outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int32', broadcastable,name='argmax')] tensor('int32', broadcastable, name='argmax')]
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def perform(self, node, inp, outs): def perform(self, node, inp, outs):
x, axis = inp x, axis = inp
max, max_idx = outs max, max_idx = outs
...@@ -1906,27 +1917,29 @@ class MaxAndArgmax(Op): ...@@ -1906,27 +1917,29 @@ class MaxAndArgmax(Op):
def infer_shape(self, node, shapes):
    """Shapes of (max, argmax): the input shape with `axis` removed.

    Both outputs share the same shape tuple.  When axis is None the
    reduction is total and both outputs are scalars.
    """
    ishape, axis_shape = shapes  # axis_shape is not needed here
    axis = node.inputs[1]
    if axis is None:
        return [(), ()]
    ndim = len(node.inputs[0].type.broadcastable)
    kept = tuple(ishape[i] for i in range(ndim) if i != axis.data)
    return [kept, kept]
def R_op(self, inputs, eval_points):
    """R-operator for MaxAndArgmax.

    Only implemented for a matrix input with a constant axis of 0 or 1;
    the argmax output has no defined R_op (returned as None).
    """
    pt = eval_points[0]
    if pt is None:
        return [None, None]
    axis_input = inputs[1]
    if not isinstance(axis_input, theano.Constant):
        raise ValueError('R_op supported for arg_max only for '
                         'constant axis!')
    if axis_input.data > 1:
        raise ValueError('R_op supported for arg_max only when '
                         ' axis is 0 or 1')
    if inputs[0].ndim != 2:
        raise ValueError('R_op supported for arg_max only when '
                         ' input is a matrix')
    # Select the eval-point entries at the positions where the max was
    # attained along the reduced axis.
    max_vals, max_pos = self.make_node(*inputs).outputs
    if axis_input.data == 0:
        return [pt[max_pos, arange(pt.shape[1])], None]
    return [pt[arange(pt.shape[0]), max_pos], None]
...@@ -1963,8 +1976,9 @@ class MaxAndArgmax(Op): ...@@ -1963,8 +1976,9 @@ class MaxAndArgmax(Op):
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
_max_and_argmax = MaxAndArgmax() _max_and_argmax = MaxAndArgmax()
@_redefine_asRoutine(_max_and_argmax)
def max_and_argmax(a):
    # The body is irrelevant: _redefine_asRoutine rebinds this name to
    # the _max_and_argmax Op instance; only the name/signature matter.
    pass
...@@ -1979,13 +1993,14 @@ def max(x, axis=None): ...@@ -1979,13 +1993,14 @@ def max(x, axis=None):
:note: we return an error as numpy when we reduce a dim with a shape of 0 :note: we return an error as numpy when we reduce a dim with a shape of 0
""" """
if isinstance(axis,(list,tuple)) and len(axis)>1: if isinstance(axis, (list, tuple)) and len(axis) > 1:
return CAReduce(scal.maximum,axis)(x) return CAReduce(scal.maximum, axis)(x)
try: try:
const = get_constant_value(axis) const = get_constant_value(axis)
return CAReduce(scal.maximum,list(const))(x) return CAReduce(scal.maximum, list(const))(x)
except Exception: except Exception:
return max_and_argmax(x,axis)[0] return max_and_argmax(x, axis)[0]
@constructor @constructor
def argmax(x, axis=None): def argmax(x, axis=None):
...@@ -1998,7 +2013,8 @@ def argmax(x, axis=None): ...@@ -1998,7 +2013,8 @@ def argmax(x, axis=None):
# In python (using MaxAndArgmax.perform()) this leads to an wasteful # In python (using MaxAndArgmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once # implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine. # but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x,axis)[1] return max_and_argmax(x, axis)[1]
@constructor @constructor
def min(x, axis=None): def min(x, axis=None):
...@@ -2009,6 +2025,7 @@ def min(x, axis=None): ...@@ -2009,6 +2025,7 @@ def min(x, axis=None):
#Be careful about unsigned integers, complex #Be careful about unsigned integers, complex
raise NotImplementedError() raise NotImplementedError()
@constructor @constructor
def argmin(x, axis=None): def argmin(x, axis=None):
str_x_type = str(x.dtype) str_x_type = str(x.dtype)
...@@ -2018,6 +2035,7 @@ def argmin(x, axis=None): ...@@ -2018,6 +2035,7 @@ def argmin(x, axis=None):
#Be careful about unsigned integers, complex #Be careful about unsigned integers, complex
raise NotImplementedError() raise NotImplementedError()
@constructor @constructor
def smallest(*args): def smallest(*args):
"""Return the [elementwise] smallest of a variable number of arguments (like python's min).""" """Return the [elementwise] smallest of a variable number of arguments (like python's min)."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论