提交 a75410d8 authored 作者: Olivier Breuleux's avatar Olivier Breuleux

improvements to pprint

上级 e3838b3c
from .. import tensor as T
from .. import scalar as S
from .. import gof
from copy import copy
class PrinterState(gof.utils.scratchpad):
def __init__(self, props = {}, **more_props):
......@@ -121,7 +123,8 @@ class DimShufflePrinter:
def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == 'x':
return "[%s]" % self.__p(new_order[1:], pstate, r)
return "%s" % self.__p(new_order[1:], pstate, r)
# return "[%s]" % self.__p(new_order[1:], pstate, r)
if list(new_order) == range(r.type.ndim):
return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))):
......@@ -138,6 +141,28 @@ class DimShufflePrinter:
raise TypeError("Can only print DimShuffle.")
class SubtensorPrinter:
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print Subtensor.")
elif isinstance(r.owner.op, T.Subtensor):
idxs = r.owner.op.idx_list
inputs = list(r.owner.inputs)
input = inputs.pop()
sidxs = []
inbrack_pstate = pstate.clone(precedence = -1000)
for entry in idxs:
if isinstance(entry, int):
sidxs.append(str(entry))
elif isinstance(entry, S.Scalar):
sidxs.append(inbrack_pstate.pprinter.process(inputs.pop()))
return "%s[%s]" % (pstate.clone(precedence = 1000).pprinter.process(input),
", ".join(sidxs))
else:
raise TypeError("Can only print Subtensor.")
class DefaultPrinter:
def __init__(self):
......@@ -238,6 +263,12 @@ def pprinter():
pp.assign(T.tanh, FunctionPrinter('tanh'))
pp.assign(T.transpose_inplace, MemberPrinter('T'))
pp.assign(T._abs, PatternPrinter(('|%(0)s|', -1000)))
pp.assign(T.sgn, FunctionPrinter('sgn'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 0, FunctionPrinter('seros'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Filler) and r.owner.op.value == 1, FunctionPrinter('ones'))
pp.assign(lambda pstate, r: r.owner and isinstance(r.owner.op, T.Subtensor), SubtensorPrinter())
pp.assign(T.shape, MemberPrinter('shape'))
pp.assign(T.fill, FunctionPrinter('fill'))
return pp
pp = pprinter()
......
......@@ -598,7 +598,7 @@ class Sgn(UnaryScalarOp):
#casting is done by compiler
#TODO: use copysign
return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0.0 : 1.0 : -1.0;" % locals()
sgn = Sgn(same_out, name = 'abs')
sgn = Sgn(same_out, name = 'sgn')
class Inv(UnaryScalarOp):
def impl(self, x):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论