提交 21d3f510 authored 作者: Mohammad Pezeshki's avatar Mohammad Pezeshki

now pretty print works for scaler variables too.

上级 370953c5
...@@ -22,7 +22,7 @@ import numpy ...@@ -22,7 +22,7 @@ import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano import gof from theano import gof, printing
from theano.gof import (Op, utils, Variable, Constant, Type, Apply, from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph) FunctionGraph)
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
...@@ -31,6 +31,8 @@ from theano.configparser import config ...@@ -31,6 +31,8 @@ from theano.configparser import config
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.printing import pprint
builtin_complex = complex builtin_complex = complex
builtin_int = int builtin_int = int
builtin_float = float builtin_float = float
...@@ -2166,6 +2168,13 @@ class Neg(UnaryScalarOp): ...@@ -2166,6 +2168,13 @@ class Neg(UnaryScalarOp):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name='neg') neg = Neg(same_out, name='neg')
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either'))
pprint.assign(true_div, printing.OperatorPrinter('/', -1, 'left'))
pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left'))
pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
""" multiplicative inverse. Also called reciprocal""" """ multiplicative inverse. Also called reciprocal"""
...@@ -3099,7 +3108,7 @@ class Composite(ScalarOp): ...@@ -3099,7 +3108,7 @@ class Composite(ScalarOp):
for i, r in enumerate(self.fgraph.variables): for i, r in enumerate(self.fgraph.variables):
if r not in io and len(r.clients) > 1: if r not in io and len(r.clients) > 1:
r.name = 't%i' % i r.name = 't%i' % i
rval = "Composite{%s}" % str(self.fgraph) rval = "Composite{%s}" % pprint(self.fgraph.outputs[0])
self.name = rval self.name = rval
def init_fgraph(self): def init_fgraph(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论