提交 5ad3c667 authored 作者: Frederic's avatar Frederic

pep8

上级 1f833c24
......@@ -1795,6 +1795,7 @@ shape = Shape()
_shape = shape #was used in the past, now use shape directly.
pprint.assign(_shape, printing.MemberPrinter('shape'))
class SpecifyShape(Op):
"""
L{Op} put into the graph the user provided shape
......@@ -1808,14 +1809,18 @@ class SpecifyShape(Op):
@note: We currently don't support specifying partial shape information.
"""
view_map = {0: [0]}
def __hash__(self):
return hash(type(self))
def __eq__(self, other):
return type(self) == type(other)
def __str__(self):
return self.__class__.__name__
def make_node(self, x, shape):
if not isinstance(x,Variable):
if not isinstance(x, Variable):
x = as_tensor_variable(x)
shape = as_tensor_variable(shape)
return Apply(self, [x, shape], [x.type()])
......@@ -1823,22 +1828,22 @@ class SpecifyShape(Op):
def perform(self, node, inp, out_):
x, shape = inp
out, = out_
assert numpy.all(x.shape==shape), ("got shape", x.shape,
assert numpy.all(x.shape == shape), ("got shape", x.shape,
"expected", shape)
out[0] = x
def infer_shape(self, node, shapes):
xshape, sshape = shapes
new_shape=[]
new_shape = []
for dim in xrange(node.inputs[0].ndim):
try:
s=get_constant_value(node.inputs[1][dim])
s=as_tensor_variable(s)
s = get_constant_value(node.inputs[1][dim])
s = as_tensor_variable(s)
new_shape.append(s)
except TypeError, e:
new_shape.append(node.inputs[1][dim])
assert len(new_shape)==len(xshape)
assert len(new_shape) == len(xshape)
return [new_shape]
def grad(self, inp, grads):
......@@ -1847,9 +1852,10 @@ class SpecifyShape(Op):
# Should I set an SpecifyShape on gz? I think so
# But I don't do it now as we need to make an optimization
# to remove that op from the graph to don't block other optimization
# Should I do an optimizer that will remove the SpecifyShape? I think Yes
# Should I do an optimizer that will remove the SpecifyShape?
# I think Yes
return [gz, None]
return [specify_shape(gz,s), None]
return [specify_shape(gz, s), None]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
......@@ -1860,31 +1866,35 @@ class SpecifyShape(Op):
specify_shape = SpecifyShape()
class MaxAndArgmax(Op):
"""Calculate the max and argmax over a given axis.
"""
nin=2 # tensor, axis
nout=2 # max val, max idx
nin = 2 # tensor, axis
nout = 2 # max val, max idx
E_axis = 'invalid axis'
def __eq__(self,other):
return type(self)==type(other)
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
def make_node(self, x, axis=None):
x = _as_tensor_variable(x)
if isinstance(axis,int):
if isinstance(axis, int):
axis = [axis]
elif isinstance(axis,(tuple,list)):
assert len(axis)==1,"MaxAndArgmax don't support multiple axis. the max fct support it."
#we make the axis all positive to make the infer_shape work with negative axis
if x.type.ndim>0 and axis is not None:
for id,a in enumerate(axis):
if not isinstance(a, TensorVariable) and a<0:
if -a>x.type.ndim:
elif isinstance(axis, (tuple, list)):
assert len(axis) == 1, ("MaxAndArgmax don't support multiple"
" axis. the max fct support it.")
# we make the axis all positive to make the infer_shape work
# with negative axis
if x.type.ndim > 0 and axis is not None:
for id, a in enumerate(axis):
if not isinstance(a, TensorVariable) and a < 0:
if -a > x.type.ndim:
raise ValueError('axis out of range')
axis[id]=x.type.ndim+a
axis[id] = x.type.ndim + a
if axis is None:
axis = _as_tensor_variable(range(x.type.ndim))
else:
......@@ -1893,9 +1903,10 @@ class MaxAndArgmax(Op):
inputs = [x, axis]
#TODO: figure things out if axis is a constant
broadcastable = [False] * (x.type.ndim - 1)
outputs = [tensor(x.type.dtype, broadcastable,name='max'),
tensor('int32', broadcastable,name='argmax')]
outputs = [tensor(x.type.dtype, broadcastable, name='max'),
tensor('int32', broadcastable, name='argmax')]
return Apply(self, inputs, outputs)
def perform(self, node, inp, outs):
x, axis = inp
max, max_idx = outs
......@@ -1906,27 +1917,29 @@ class MaxAndArgmax(Op):
def infer_shape(self, node, shapes):
ishape, axis_shape = shapes
axis=node.inputs[1]
axis = node.inputs[1]
if axis is None:
return [(),()]
rval = tuple([ishape[i] for (i,b) in enumerate(node.inputs[0].type.broadcastable) if i !=axis.data])
return [rval,rval]
return [(), ()]
rval = tuple([ishape[i] for (i, b) in enumerate(
node.inputs[0].type.broadcastable) if i != axis.data])
return [rval, rval]
def R_op(self, inputs, eval_points):
if eval_points[0] is None:
return [None, None]
if not isinstance(inputs[1], theano.Constant):
raise ValueError( ('R_op supported for arg_max only for '
'constant axis!'))
raise ValueError(('R_op supported for arg_max only for '
'constant axis!'))
if inputs[1].data > 1:
raise ValueError( ('R_op supported for arg_max only when '
' axis is 0 or 1'))
raise ValueError(('R_op supported for arg_max only when '
' axis is 0 or 1'))
if inputs[0].ndim != 2:
raise ValueError( ('R_op supported for arg_max only when '
' input is a matrix'))
raise ValueError(('R_op supported for arg_max only when '
' input is a matrix'))
max_vals, max_pos = self.make_node(*inputs).outputs
if inputs[1].data == 0:
return [eval_points[0][max_pos, arange(eval_points[0].shape[1])], None]
return [eval_points[0][max_pos,
arange(eval_points[0].shape[1])], None]
else:
return [eval_points[0][arange(eval_points[0].shape[0]),
max_pos], None]
......@@ -1963,8 +1976,9 @@ class MaxAndArgmax(Op):
def __str__(self):
return self.__class__.__name__
_max_and_argmax = MaxAndArgmax()
@_redefine_asRoutine(_max_and_argmax)
def max_and_argmax(a):
pass
......@@ -1979,13 +1993,14 @@ def max(x, axis=None):
:note: we return an error as numpy when we reduce a dim with a shape of 0
"""
if isinstance(axis,(list,tuple)) and len(axis)>1:
return CAReduce(scal.maximum,axis)(x)
if isinstance(axis, (list, tuple)) and len(axis) > 1:
return CAReduce(scal.maximum, axis)(x)
try:
const = get_constant_value(axis)
return CAReduce(scal.maximum,list(const))(x)
return CAReduce(scal.maximum, list(const))(x)
except Exception:
return max_and_argmax(x,axis)[0]
return max_and_argmax(x, axis)[0]
@constructor
def argmax(x, axis=None):
......@@ -1998,7 +2013,8 @@ def argmax(x, axis=None):
# In python (using MaxAndArgmax.perform()) this leads to an wasteful
# implementation that goes through the data twice instead of once
# but when Argmax.c_impl() is in place, it should be fine.
return max_and_argmax(x,axis)[1]
return max_and_argmax(x, axis)[1]
@constructor
def min(x, axis=None):
......@@ -2009,6 +2025,7 @@ def min(x, axis=None):
#Be careful about unsigned integers, complex
raise NotImplementedError()
@constructor
def argmin(x, axis=None):
str_x_type = str(x.dtype)
......@@ -2018,6 +2035,7 @@ def argmin(x, axis=None):
#Be careful about unsigned integers, complex
raise NotImplementedError()
@constructor
def smallest(*args):
"""Return the [elementwise] smallest of a variable number of arguments (like python's min)."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论