提交 4c019fb9 authored 作者: gdesjardins's avatar gdesjardins

merge

...@@ -461,7 +461,13 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn ...@@ -461,7 +461,13 @@ def pydotprint(fct, outfile=os.path.join(config.compiledir,'theano.pydotprint.pn
var_shape='box' var_shape='box'
for node_idx,node in enumerate(topo): for node_idx,node in enumerate(topo):
astr=apply_name(node) astr=apply_name(node)
g.add_node(pd.Node(astr,shape=apply_shape))
if node.op.__class__.__name__ in ('GpuFromHost','HostFromGpu'):
# highlight CPU-GPU transfers to simplify optimization
g.add_node(pd.Node(astr,color='red',shape=apply_shape))
else:
g.add_node(pd.Node(astr,shape=apply_shape))
for id,var in enumerate(node.inputs): for id,var in enumerate(node.inputs):
varstr=var_name(var) varstr=var_name(var)
label='' label=''
......
...@@ -51,7 +51,8 @@ class CudaNdarrayType(Type): ...@@ -51,7 +51,8 @@ class CudaNdarrayType(Type):
def __init__(self, broadcastable, name=None, dtype=None): def __init__(self, broadcastable, name=None, dtype=None):
if dtype != None and dtype != 'float32': if dtype != None and dtype != 'float32':
raise TypeError(self.__class__.__name__+' only support dtype float32 for now. Tried to use dtype %s'%dtype) raise TypeError(self.__class__.__name__+' only support dtype float32 for now.'\
'Tried using dtype %s for variable %s' % (dtype, name))
self.broadcastable = tuple(broadcastable) self.broadcastable = tuple(broadcastable)
self.name = name self.name = name
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论