提交 cc041dc3 authored 作者: James Bergstra's avatar James Bergstra

removed automatic casts, and changed mean to cast input to FP, and added print op

上级 a06073d6
"""Pretty-printing graphs, and the 'Print' Op.
"""
import gof
from copy import copy
import sys
from gof import Op, Apply
class Print(Op):
"""This identity-like Op has the side effect of printing a message followed by its inputs
when it runs.
"""
def __init__(self,message=""):
self.message=message
self.view_map={0:[0]}
def make_node(self,xin):
xout = xin.type.make_result()
return Apply(op = self, inputs = [xin], outputs=[xout])
def perform(self,node,inputs,output_storage):
xin, = inputs
xout, = output_storage
xout[0] = xin
print self.message,xin
def grad(self,input,output_gradients):
return output_gradients
class PrinterState(gof.utils.scratchpad):
......@@ -232,3 +255,4 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint
......@@ -21,7 +21,7 @@ from .. import scalar as scal
from ..gof.python25 import partial
from .. import compile, printing
from ..printing import pprint
from ..printing import pprint, Print
### set up the external interface
......@@ -456,10 +456,11 @@ class _tensor_py_operators:
def __abs__(self): return abs_(self)
def __neg__(self): return neg(self)
#CASTS
def __int__(self): return AsInt(self).out
def __float__(self): return AsInt(self).out
def __complex__(self): return AsComplex(self).out
#CASTS
#### REMOVED THESE BECAUSE PYTHON appears to require __int__ to return an int. -JB 20081112
#def __int__(self): return convert_to_int32(self)
#def __float__(self): return convert_to_float64(self)
#def __complex__(self): return convert_to_complex128(self)
#COMPARISONS
def __lt__(self,other): return lt(self, other)
......@@ -1012,6 +1013,10 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def mean(input, axis = None):
"""WRITEME"""
if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps
# to prevents overflow
input = convert_to_float64(input)
s = sum(input, axis)
shp = shape(input)
if axis is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论