提交 b0c67648 authored 作者: james@X40's avatar james@X40

merge

......@@ -2054,8 +2054,10 @@ class Reshape(Op):
The number of dimensions to which to reshape to (ndim) must be known at graph
build time."""
view_map = {0: [0]} #output 0 is potentially aliased to inputs [0]
def __init__(self, ndim):
def __init__(self, ndim, name = None):
self.ndim = ndim
if name:
self.name = name
def __eq__(self, other):
return (type(other) is Reshape) and (other.ndim == self.ndim)
def __hash__(self):
......@@ -2075,10 +2077,10 @@ class Reshape(Op):
def grad(self, (x, shp), (g_out,)):
return [reshape(g_out, shape(x), ndim=x.ndim), None]
def reshape(x, newshape, ndim=None):
def reshape(x, newshape, ndim=None, name=None):
if ndim is None:
ndim = get_vector_length(newshape)
op = Reshape(ndim)
op = Reshape(ndim, name)
return op(x, newshape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论