提交 968f6f9a authored 作者: James Bergstra's avatar James Bergstra

merge

...@@ -341,6 +341,8 @@ class Value(Variable): ...@@ -341,6 +341,8 @@ class Value(Variable):
if value is not None: if value is not None:
raise ValueError("Value instances cannot have an owner.") raise ValueError("Value instances cannot have an owner.")
owner = property(lambda self: None, __set_owner) owner = property(lambda self: None, __set_owner)
# Expose `self.data` under the additional name `value`.  Only a getter is
# passed to property(), so this attribute offers read access only.
value = property(lambda self: self.data,
doc='read-only data access method')
# index is not defined, because the `owner` attribute must necessarily be None # index is not defined, because the `owner` attribute must necessarily be None
......
...@@ -359,7 +359,8 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n ...@@ -359,7 +359,8 @@ pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is n
pp = pprint pp = pprint
def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'), compact=True, mode=None, format='png'): def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.png'),
compact=True, mode=None, format='png', with_ids=False):
""" """
print to a file in png format the graph of op of a compile theano fct. print to a file in png format the graph of op of a compile theano fct.
...@@ -390,14 +391,15 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -390,14 +391,15 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
g=pd.Dot() g=pd.Dot()
var_str={} var_str={}
all_strings = set()
def var_name(var): def var_name(var):
if var in var_str: if var in var_str:
return var_str[var] return var_str[var]
if var.name is not None: if var.name is not None:
varstr = var.name+" "+str(var.type) varstr = 'name='+var.name+" "+str(var.type)
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
dstr = str(var.data) dstr = 'val='+str(var.data)
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30: if len(dstr) > 30:
...@@ -408,12 +410,17 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -408,12 +410,17 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
else: else:
#a var id is needed as otherwise var with the same type will be merged in the graph. #a var id is needed as otherwise var with the same type will be merged in the graph.
varstr = str(var.type) varstr = str(var.type)
varstr += ' ' + str(len(var_str)) if (varstr in all_strings) or with_ids:
varstr += ' id=' + str(len(var_str))
var_str[var]=varstr var_str[var]=varstr
all_strings.add(varstr)
return varstr return varstr
topo = fct.maker.env.toposort() topo = fct.maker.env.toposort()
apply_name_cache = {}
def apply_name(node): def apply_name(node):
if node in apply_name_cache:
return apply_name_cache[node]
prof_str='' prof_str=''
if mode: if mode:
time = mode.apply_time.get((topo.index(node),node),0) time = mode.apply_time.get((topo.index(node),node),0)
...@@ -425,7 +432,12 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -425,7 +432,12 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
pf=0 pf=0
else: pf = time*100/mode.fct_call_time[fct] else: pf = time*100/mode.fct_call_time[fct]
prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf) prof_str=' (%.3fs,%.3f%%,%.3f%%)'%(time,pt,pf)
return str(node.op).replace(':','_')+' '+str(topo.index(node))+prof_str applystr = str(node.op).replace(':','_')
if (applystr in all_strings) or with_ids:
applystr = applystr+' id='+str(topo.index(node))+prof_str
all_strings.add(applystr)
apply_name_cache[node] = applystr
return applystr
# Update the inputs that have an update function # Update the inputs that have an update function
input_update={} input_update={}
...@@ -434,16 +446,18 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -434,16 +446,18 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
if i.update is not None: if i.update is not None:
input_update[outputs.pop()] = i input_update[outputs.pop()] = i
apply_shape='ellipse'
var_shape='box'
for node_idx,node in enumerate(topo): for node_idx,node in enumerate(topo):
astr=apply_name(node) astr=apply_name(node)
g.add_node(pd.Node(astr,shape='box')) g.add_node(pd.Node(astr,shape=apply_shape))
for id,var in enumerate(node.inputs): for id,var in enumerate(node.inputs):
varstr=var_name(var) varstr=var_name(var)
label='' label=''
if len(node.inputs)>1: if len(node.inputs)>1:
label=str(id) label=str(id)
if var.owner is None: if var.owner is None:
g.add_node(pd.Node(varstr,color='green')) g.add_node(pd.Node(varstr,color='green',shape=var_shape))
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr,astr, label=label))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(varstr,astr, label=label)) g.add_edge(pd.Edge(varstr,astr, label=label))
...@@ -460,10 +474,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -460,10 +474,10 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
label=str(id) label=str(id)
if out: if out:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='blue')) g.add_node(pd.Node(varstr,color='blue',shape=var_shape))
elif len(var.clients)==0: elif len(var.clients)==0:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
g.add_node(pd.Node(varstr,color='grey')) g.add_node(pd.Node(varstr,color='grey',shape=var_shape))
elif var.name or not compact: elif var.name or not compact:
g.add_edge(pd.Edge(astr, varstr, label=label)) g.add_edge(pd.Edge(astr, varstr, label=label))
# else: # else:
...@@ -495,9 +509,9 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -495,9 +509,9 @@ def pydot_var(vars, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
return var_str[var] return var_str[var]
if var.name is not None: if var.name is not None:
varstr = var.name varstr = 'name='+var.name
elif isinstance(var,gof.Constant): elif isinstance(var,gof.Constant):
dstr = str(var.data) dstr = 'val='+str(var.data)
if '\n' in dstr: if '\n' in dstr:
dstr = dstr[:dstr.index('\n')] dstr = dstr[:dstr.index('\n')]
if len(dstr) > 30: if len(dstr) > 30:
......
...@@ -932,6 +932,7 @@ class IntDiv(BinaryScalarOp): ...@@ -932,6 +932,7 @@ class IntDiv(BinaryScalarOp):
return [None] * len(inputs) return [None] * len(inputs)
int_div = IntDiv(upcast_out, name = 'int_div') int_div = IntDiv(upcast_out, name = 'int_div')
# `floor_div` is a second name bound to the very same op object as `int_div`
# (no new op is created here), so the two names are interchangeable.
floor_div = int_div
class Mod(BinaryScalarOp): class Mod(BinaryScalarOp):
def impl(self, x, y): def impl(self, x, y):
......
...@@ -887,6 +887,11 @@ class _tensor_py_operators: ...@@ -887,6 +887,11 @@ class _tensor_py_operators:
except Exception, e: except Exception, e:
return NotImplemented return NotImplemented
def __truediv__(self, other):
    """Python hook for ``self / other`` under true division; delegates to `true_div`."""
    return true_div(self, other)
def __floordiv__(self, other):
    """Python hook for ``self // other``; delegates to `floor_div`."""
    return floor_div(self, other)
def __rtruediv__(self, other):
    """Reflected true division, ``other / self``: `other` is the numerator."""
    return true_div(other, self)
def __rfloordiv__(self, other):
    """Reflected floor division, ``other // self``: `other` is the numerator."""
    return floor_div(other, self)
# ##### DON"T USE THESE BECAUSE INPLACE OPS SHOULD BE INSERTED BY OPTIMIZATION ONLY # ##### DON"T USE THESE BECAUSE INPLACE OPS SHOULD BE INSERTED BY OPTIMIZATION ONLY
# #ARITHMETIC - INPLACE # #ARITHMETIC - INPLACE
# def __iadd__(self,other): return _add_inplace(self,other) # def __iadd__(self,other): return _add_inplace(self,other)
...@@ -2066,6 +2071,11 @@ def true_div(a, b): ...@@ -2066,6 +2071,11 @@ def true_div(a, b):
"""elementwise [true] division (inverse of multiplication)""" """elementwise [true] division (inverse of multiplication)"""
# see decorator for function body # see decorator for function body
@_scal_elemwise
def floor_div(a, b):
    """elementwise floor division (the operation behind the `//` operator)"""
    # see decorator for function body
@_scal_elemwise @_scal_elemwise
def int_div(a, b): def int_div(a, b):
"""elementwise integer-division""" """elementwise integer-division"""
......
...@@ -889,6 +889,7 @@ def _gemm_from_node2(node): ...@@ -889,6 +889,7 @@ def _gemm_from_node2(node):
if len(lst) > 1: if len(lst) > 1:
lst = _factor_canonicalized(lst) lst = _factor_canonicalized(lst)
rval = _gemm_from_factored_list(lst) rval = _gemm_from_factored_list(lst)
#print "RVAL", rval
if rval: if rval:
assert rval[0].type == node.outputs[0].type, (rval[0].type, node.outputs[0].type) assert rval[0].type == node.outputs[0].type, (rval[0].type, node.outputs[0].type)
return rval return rval
...@@ -909,7 +910,6 @@ class GemmOptimizer(Optimizer): ...@@ -909,7 +910,6 @@ class GemmOptimizer(Optimizer):
did_something = False did_something = False
nodelist.reverse() nodelist.reverse()
for node in nodelist: for node in nodelist:
#new_outputs = _gemm_from_node(node)
try: try:
new_outputs = _gemm_from_node2(node) new_outputs = _gemm_from_node2(node)
except InconsistencyError, e: except InconsistencyError, e:
...@@ -1193,9 +1193,10 @@ blas_optdb.register('local_dot22_to_dot22scalar', ...@@ -1193,9 +1193,10 @@ blas_optdb.register('local_dot22_to_dot22scalar',
11, 'fast_run') 11, 'fast_run')
from opt import register_specialize from opt import register_specialize, register_canonicalize
#@register_specialize #@register_specialize
@local_optimizer([]) @local_optimizer([])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (T.sub, T.add):
debugprint(node) debugprint(node)
...@@ -338,3 +338,47 @@ register_local_1msigmoid = False ...@@ -338,3 +338,47 @@ register_local_1msigmoid = False
if register_local_1msigmoid: if register_local_1msigmoid:
opt.register_canonicalize(local_1msigmoid) opt.register_canonicalize(local_1msigmoid)
if 0:
    # This code is if'd out because it is not complete,
    # and it isn't obviously a good idea anyway.
    # The motivation here was to identify the last exp() node
    # in the SciPy2010 article, which was not optimized away at the time of publication,
    # so the example is actually not numerically stable, even though it should be.
    #
    # NOTE(review): the nesting below was reconstructed from statement order;
    # confirm it against the original file before ever enabling this code.
    @opt.register_stabilize
    @gof.local_optimizer([tensor.mul])
    def local_sigm_gest(node):
        # Debugging stub: prints a marker and the result of
        # sigm_canonicalize(node), then falls off the end (returns None).
        print "CANONICALIZE"
        print sigm_canonicalize(node)
    def sigm_canonicalize(node):
        # Expand `node` into a list of terms, recursing through add and mul.
        #
        # NOTE(review): the recursive calls pass elements of `node.inputs`,
        # yet this function reads `.op`, `.inputs` and `.outputs` directly
        # from its argument while reaching the terms' ops through `.owner` --
        # looks like a variable-vs-apply-node mix-up; confirm.

        # Local shorthands; `add` is assigned but never used below.
        add = tensor.add
        mul = tensor.mul
        div = tensor.true_div
        if node.op == tensor.add:
            # Sum: concatenate the term lists of all inputs.
            rval = []
            for i in node.inputs:
                rval += sigm_canonicalize(i)
            return rval
        if node.op == tensor.mul:
            # Product: pair every term of each further input with every term
            # accumulated so far, building one quotient per pair.
            rval = sigm_canonicalize(node.inputs[0])
            for i in node.inputs[1:]:
                old_rval = rval
                rval = []
                for t1 in sigm_canonicalize(i):
                    for t0 in old_rval:
                        # Every term is required to be a true_div result.
                        assert t1.owner.op == div
                        assert t0.owner.op == div
                        t0top, t0bot = t0.owner.inputs
                        t1top, t1bot = t1.owner.inputs
                        # NOTE(review): `t0top+t1top` adds the two numerators
                        # and then star-unpacks the sum into mul(); presumably
                        # lists of factors were intended -- verify.
                        rval.append(div(mul(*(t0top+t1top)), mul(*(t0bot+t1bot))))
                        if len(rval) > 100:
                            # This loop can be exponentially long.
                            # aborting
                            return []
            # NOTE(review): there is no `return rval` after the loops, so with
            # this nesting the mul branch yields None unless it aborts above.
        elif len(node.outputs)>1:
            # Multi-output nodes are not expanded.
            return []
        else:
            # Any other single-output node is a term by itself.
            return [node.outputs[0]]
...@@ -194,6 +194,11 @@ def local_dimshuffle_lift(node): ...@@ -194,6 +194,11 @@ def local_dimshuffle_lift(node):
register_canonicalize(local_dimshuffle_lift) register_canonicalize(local_dimshuffle_lift)
register_specialize(local_dimshuffle_lift) register_specialize(local_dimshuffle_lift)
@register_canonicalize
@local_optimizer([])
def local_dimshuffle_no_inplace_at_canonicalize(node):
    """Swap an in-place DimShuffle for an equivalent one built with inplace=False.

    Returns a one-element list holding the replacement output when `node.op`
    is a `T.DimShuffle` whose `inplace` flag is set; returns None for every
    other node.
    """
    if not isinstance(node.op, T.DimShuffle):
        return None
    if not node.op.inplace:
        return None
    # Same broadcast pattern and axis order as the original op, but not in-place.
    out_of_place_op = T.DimShuffle(node.op.input_broadcastable,
                                   node.op.new_order,
                                   inplace=False)
    return [out_of_place_op(node.inputs[0])]
##################################### #####################################
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论