提交 b7547f84 authored 作者: Sina Honari's avatar Sina Honari

applying changes to remove Flatten Op

上级 eced0049
......@@ -3322,23 +3322,23 @@ class GpuIncSubtensor(tensor.IncSubtensor, GpuOp):
return ()
#class GpuFlatten(gof.HideC, tensor.Reshape, GpuOp):
# """
# Implement Flatten on the gpu.
#
# """
#
# def make_node(self, x):
# warnings.warn(
# "GpuFlatten class is deprecated, "
# "please use gpu_flatten method instead.",
# DeprecationWarning,
# stacklevel=4)
# assert isinstance(x.type, CudaNdarrayType)
# rval = tensor.Reshape.make_node(self, x, [tensor.prod(x.shape)])
# host_out_broadcastable = rval.outputs[0].type.broadcastable
# out_type = CudaNdarrayType(broadcastable=host_out_broadcastable)
# return Apply(self, [x], [out_type()])
class GpuFlatten(gof.HideC, tensor.Reshape, GpuOp):
    """
    Implement Flatten on the gpu.

    Deprecated: emits a DeprecationWarning on use; callers should move to
    the gpu_flatten helper instead.
    """

    def make_node(self, x):
        warnings.warn(
            "GpuFlatten class is deprecated, "
            "please use gpu_flatten method instead.",
            DeprecationWarning,
            stacklevel=4)
        # Only CUDA ndarray inputs are supported by this Op.
        assert isinstance(x.type, CudaNdarrayType)
        # Flattening is a reshape to a single dimension holding the total
        # number of elements of x.
        flat_shape = [tensor.prod(x.shape)]
        host_node = tensor.Reshape.make_node(self, x, flat_shape)
        # Mirror the host output's broadcastable pattern on a GPU type.
        gpu_out_type = CudaNdarrayType(
            broadcastable=host_node.outputs[0].type.broadcastable)
        return Apply(self, [x], [gpu_out_type()])
......
差异被折叠。
......@@ -3877,24 +3877,24 @@ def local_useless_split(node):
################
# Flatten Opts #
################
#@register_canonicalize
#@register_stabilize
#@gof.local_optimizer([T.Flatten])
#def local_flatten_lift(node):
#    """
#    Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))
#
#    This optimization is needed by optimization
#    nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.
#
#    """
# if (isinstance(node.op, T.Flatten) and
# node.inputs[0].owner and
# isinstance(node.inputs[0].owner.op, T.Elemwise) and
# len(node.inputs[0].owner.inputs) == 1):
# f = node.op(node.inputs[0].owner.inputs[0])
# e = node.inputs[0].owner.op(f)
# return [e]
@register_canonicalize
@register_stabilize
@gof.local_optimizer([T.Flatten])
def local_flatten_lift(node):
    """
    Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))

    This optimization is needed by optimization
    nnet/sigm.py:log1msigm_to_softplus to get applied when there is a flatten.

    """
    # Only lift over a unary Elemwise: Flatten commutes with any
    # elementwise op applied to a single input.
    if (isinstance(node.op, T.Flatten) and
            node.inputs[0].owner and
            isinstance(node.inputs[0].owner.op, T.Elemwise) and
            len(node.inputs[0].owner.inputs) == 1):
        # Re-apply the same Flatten Op to the elemwise's input, then the
        # elemwise op to the flattened result.
        f = node.op(node.inputs[0].owner.inputs[0])
        e = node.inputs[0].owner.op(f)
        return [e]
##################
# Reshape opts #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论