提交 712660e7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Make reshape ndim kwarg only

This prevents surprises when passing two scalars, which are interpreted differently in the numpy API
上级 f27ac453
...@@ -708,7 +708,10 @@ class Repeat(Op): ...@@ -708,7 +708,10 @@ class Repeat(Op):
shape = [x.shape[k] for k in range(x.ndim)] shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats) shape.insert(axis, repeats)
return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), DisconnectedType()()] return [
gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
DisconnectedType()(),
]
elif repeats.ndim == 1: elif repeats.ndim == 1:
# For this implementation, we would need to specify the length # For this implementation, we would need to specify the length
# of repeats in order to split gz in the right way to sum # of repeats in order to split gz in the right way to sum
......
...@@ -2196,7 +2196,7 @@ def _tensordot_as_dot(a, b, axes, dot, batched): ...@@ -2196,7 +2196,7 @@ def _tensordot_as_dot(a, b, axes, dot, batched):
b_reshaped = b.reshape(b_shape) b_reshaped = b.reshape(b_shape)
out_reshaped = dot(a_reshaped, b_reshaped) out_reshaped = dot(a_reshaped, b_reshaped)
out = out_reshaped.reshape(outshape, outndim) out = out_reshaped.reshape(outshape, ndim=outndim)
# Make sure the broadcastable pattern of the result is correct, # Make sure the broadcastable pattern of the result is correct,
# since some shape information can be lost in the reshapes. # since some shape information can be lost in the reshapes.
if out.type.broadcastable != outbcast: if out.type.broadcastable != outbcast:
......
...@@ -155,7 +155,7 @@ def transform_take(a, indices, axis): ...@@ -155,7 +155,7 @@ def transform_take(a, indices, axis):
ndim = a.ndim + indices.ndim - 1 ndim = a.ndim + indices.ndim - 1
return transform_take(a, indices.flatten(), axis).reshape(shape, ndim) return transform_take(a, indices.flatten(), axis).reshape(shape, ndim=ndim)
def is_full_slice(x): def is_full_slice(x):
......
...@@ -580,7 +580,7 @@ def kron(a, b): ...@@ -580,7 +580,7 @@ def kron(a, b):
f"You passed {int(a.ndim)} and {int(b.ndim)}." f"You passed {int(a.ndim)} and {int(b.ndim)}."
) )
o = atm.outer(a, b) o = atm.outer(a, b)
o = o.reshape(at.concatenate((a.shape, b.shape)), a.ndim + b.ndim) o = o.reshape(at.concatenate((a.shape, b.shape)), ndim=a.ndim + b.ndim)
shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim))) shf = o.dimshuffle(0, 2, 1, *list(range(3, o.ndim)))
if shf.ndim == 3: if shf.ndim == 3:
shf = o.dimshuffle(1, 0, 2) shf = o.dimshuffle(1, 0, 2)
......
...@@ -283,7 +283,7 @@ class _tensor_py_operators: ...@@ -283,7 +283,7 @@ class _tensor_py_operators:
# "Variable) due to Python restriction. You can use " # "Variable) due to Python restriction. You can use "
# "PyTensorVariable.shape[0] instead.") # "PyTensorVariable.shape[0] instead.")
def reshape(self, shape, ndim=None): def reshape(self, shape, *, ndim=None):
"""Return a reshaped view/copy of this variable. """Return a reshaped view/copy of this variable.
Parameters Parameters
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论