提交 cc041dc3 authored 作者: James Bergstra's avatar James Bergstra

removed automatic casts, and changed mean to cast input to FP, and added print op

上级 a06073d6
"""Pretty-printing graphs, and the 'Print' Op.
"""
import gof import gof
from copy import copy from copy import copy
import sys import sys
from gof import Op, Apply
class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs
when it runs.
"""
def __init__(self,message=""):
self.message=message
self.view_map={0:[0]}
def make_node(self,xin):
xout = xin.type.make_result()
return Apply(op = self, inputs = [xin], outputs=[xout])
def perform(self,node,inputs,output_storage):
xin, = inputs
xout, = output_storage
xout[0] = xin
print self.message,xin
def grad(self,input,output_gradients):
return output_gradients
class PrinterState(gof.utils.scratchpad): class PrinterState(gof.utils.scratchpad):
...@@ -232,3 +255,4 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n ...@@ -232,3 +255,4 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint pp = pprint
...@@ -21,7 +21,7 @@ from .. import scalar as scal ...@@ -21,7 +21,7 @@ from .. import scalar as scal
from ..gof.python25 import partial from ..gof.python25 import partial
from .. import compile, printing from .. import compile, printing
from ..printing import pprint from ..printing import pprint, Print
### set up the external interface ### set up the external interface
...@@ -456,10 +456,11 @@ class _tensor_py_operators: ...@@ -456,10 +456,11 @@ class _tensor_py_operators:
def __abs__(self): return abs_(self) def __abs__(self): return abs_(self)
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
#CASTS #CASTS
def __int__(self): return AsInt(self).out #### REMOVED THESE BECAUSE PYTHON appears to require __int__ to return an int. -JB 20081112
def __float__(self): return AsInt(self).out #def __int__(self): return convert_to_int32(self)
def __complex__(self): return AsComplex(self).out #def __float__(self): return convert_to_float64(self)
#def __complex__(self): return convert_to_complex128(self)
#COMPARISONS #COMPARISONS
def __lt__(self,other): return lt(self, other) def __lt__(self,other): return lt(self, other)
...@@ -1012,6 +1013,10 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum')) ...@@ -1012,6 +1013,10 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor @constructor
def mean(input, axis = None): def mean(input, axis = None):
"""WRITEME""" """WRITEME"""
if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps
# to prevents overflow
input = convert_to_float64(input)
s = sum(input, axis) s = sum(input, axis)
shp = shape(input) shp = shape(input)
if axis is None: if axis is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论