提交 c7fe4e81 authored 作者: David Warde-Farley's avatar David Warde-Farley

Merge pull request #2407 from mohammadpz/Composite_pretty_print

now pretty print works for scaler variables too.
...@@ -22,7 +22,7 @@ import numpy ...@@ -22,7 +22,7 @@ import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import PY3
from theano import gof from theano import gof, printing
from theano.gof import (Op, utils, Variable, Constant, Type, Apply, from theano.gof import (Op, utils, Variable, Constant, Type, Apply,
FunctionGraph) FunctionGraph)
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
...@@ -31,6 +31,8 @@ from theano.configparser import config ...@@ -31,6 +31,8 @@ from theano.configparser import config
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gradient import grad_undefined from theano.gradient import grad_undefined
from theano.printing import pprint
builtin_complex = complex builtin_complex = complex
builtin_int = int builtin_int = int
builtin_float = float builtin_float = float
...@@ -2166,6 +2168,14 @@ class Neg(UnaryScalarOp): ...@@ -2166,6 +2168,14 @@ class Neg(UnaryScalarOp):
return "%(z)s = -%(x)s;" % locals() return "%(z)s = -%(x)s;" % locals()
neg = Neg(same_out, name='neg') neg = Neg(same_out, name='neg')
pprint.assign(add, printing.OperatorPrinter('+', -2, 'either'))
pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either'))
pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left'))
pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either'))
pprint.assign(true_div, printing.OperatorPrinter('/', -1, 'left'))
pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left'))
pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right'))
pprint.assign(mod, printing.OperatorPrinter('%', -1, 'lef'))
class Inv(UnaryScalarOp): class Inv(UnaryScalarOp):
""" multiplicative inverse. Also called reciprocal""" """ multiplicative inverse. Also called reciprocal"""
...@@ -3099,7 +3109,8 @@ class Composite(ScalarOp): ...@@ -3099,7 +3109,8 @@ class Composite(ScalarOp):
for i, r in enumerate(self.fgraph.variables): for i, r in enumerate(self.fgraph.variables):
if r not in io and len(r.clients) > 1: if r not in io and len(r.clients) > 1:
r.name = 't%i' % i r.name = 't%i' % i
rval = "Composite{%s}" % str(self.fgraph) rval = "Composite{%s}" % ', '.join([pprint(output) for output
in self.fgraph.outputs])
self.name = rval self.name = rval
def init_fgraph(self): def init_fgraph(self):
......
...@@ -146,6 +146,28 @@ class test_composite(unittest.TestCase): ...@@ -146,6 +146,28 @@ class test_composite(unittest.TestCase):
fn = gof.DualLinker().accept(g).make_function() fn = gof.DualLinker().accept(g).make_function()
assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5] assert fn(1.0, 2.0, 3.0) == [6.0, 7.0, 0.5]
def test_composite_printing(self):
x, y, z = floats('xyz')
e0 = x + y + z
e1 = x + y * z
e2 = x / y
e3 = x // 5
e4 = -x
e5 = x - y
e6 = x ** y + (-z)
e7 = x % 3
C = Composite([x, y, z], [e0, e1, e2, e3, e4, e5, e6, e7])
c = C.make_node(x, y, z)
g = FunctionGraph([x, y, z], c.outputs)
fn = gof.DualLinker().accept(g).make_function()
assert str(g) == ('[*1 -> Composite{((i0 + i1) + i2),'
' (i0 + (i1 * i2)), (i0 / i1), '
'(i0 // Constant{5}), '
'(-i0), (i0 - i1), ((i0 ** i1) + (-i2)),'
' (i0 % Constant{3})}(x, y, z), '
'*1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7]')
def test_make_node_continue_graph(self): def test_make_node_continue_graph(self):
# This is a test for a bug (now fixed) that disabled the # This is a test for a bug (now fixed) that disabled the
# local_gpu_elemwise_0 optimization and printed an # local_gpu_elemwise_0 optimization and printed an
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论