提交 a2fc1b7a authored 作者: sentient07's avatar sentient07

Handling tracks and fixed two errors

上级 25ffd68b
...@@ -1250,6 +1250,11 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1250,6 +1250,11 @@ class LocalOptGroup(LocalOptimizer):
for opt in optimizers) for opt in optimizers)
self.apply_all_opts = kwargs.get('apply_all_opts', False) self.apply_all_opts = kwargs.get('apply_all_opts', False)
self.track_map = OrderedDict()
for o in self.opts:
for c in o.tracks():
self.track_map.setdefault(c, []).append(o)
def __str__(self): def __str__(self):
...@@ -1283,8 +1288,8 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1283,8 +1288,8 @@ class LocalOptGroup(LocalOptimizer):
new_node = repl[0].owner new_node = repl[0].owner
apply_mult_opts(opt_list, new_node, True) apply_mult_opts(opt_list, new_node, True)
return repl return repl
opts = self.track_map.get(type(node.op), [])
return apply_mult_opts(self.opts, node, self.apply_all_opts) return apply_mult_opts(opts, node, self.apply_all_opts)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
......
...@@ -2354,7 +2354,7 @@ def zeros(shape, dtype=None): ...@@ -2354,7 +2354,7 @@ def zeros(shape, dtype=None):
shape = [shape] shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
if len(shape) == 0: if isinstance(shape, (list, tuple)) and len(shape) == 0:
return constant(0.0, dtype=dtype) return constant(0.0, dtype=dtype)
return alloc(numpy.array(0, dtype=dtype), *shape) return alloc(numpy.array(0, dtype=dtype), *shape)
...@@ -2367,7 +2367,7 @@ def ones(shape, dtype=None): ...@@ -2367,7 +2367,7 @@ def ones(shape, dtype=None):
shape = [shape] shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
if len(shape) == 0: if isinstance(shape, (list, tuple)) and len(shape) == 0:
return constant(1.0, dtype=dtype) return constant(1.0, dtype=dtype)
return alloc(numpy.array(1, dtype=dtype), *shape) return alloc(numpy.array(1, dtype=dtype), *shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论