提交 c9300c5d authored 作者: sentient07's avatar sentient07

Fixed cycles in graph and added back the condition checking empty shape

上级 21aaf13c
...@@ -1256,7 +1256,6 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1256,7 +1256,6 @@ class LocalOptGroup(LocalOptimizer):
for c in o.tracks(): for c in o.tracks():
self.track_map.setdefault(c, []).append(o) self.track_map.setdefault(c, []).append(o)
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
('LocalOptGroup(%s)' % ('LocalOptGroup(%s)' %
...@@ -1273,6 +1272,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -1273,6 +1272,7 @@ class LocalOptGroup(LocalOptimizer):
def transform(self, node): def transform(self, node):
if len(self.opts) == 0: if len(self.opts) == 0:
return return
def apply_mult_opts(opt_list, node, multiple_opts=False): def apply_mult_opts(opt_list, node, multiple_opts=False):
repl = False repl = False
for opt in opt_list: for opt in opt_list:
......
...@@ -413,8 +413,10 @@ class LocalGroupDB(DB): ...@@ -413,8 +413,10 @@ class LocalGroupDB(DB):
ret = opt.LocalOptGroup(*opts, apply_all_opts=self.apply_all_opts) ret = opt.LocalOptGroup(*opts, apply_all_opts=self.apply_all_opts)
return ret return ret
class TopoDB(DB): class TopoDB(DB):
""" """
Generate a local optimizer of type LocalOptGroup instead Generate a local optimizer of type LocalOptGroup instead
of a global optimizer. of a global optimizer.
......
...@@ -2354,6 +2354,8 @@ def zeros(shape, dtype=None): ...@@ -2354,6 +2354,8 @@ def zeros(shape, dtype=None):
shape = [shape] shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
if isinstance(shape, (list, tuple)) and len(shape) == 0:
return constant(0.0, dtype=dtype)
return alloc(numpy.array(0, dtype=dtype), *shape) return alloc(numpy.array(0, dtype=dtype), *shape)
...@@ -2365,6 +2367,8 @@ def ones(shape, dtype=None): ...@@ -2365,6 +2367,8 @@ def ones(shape, dtype=None):
shape = [shape] shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
if isinstance(shape, (list, tuple)) and len(shape) == 0:
return constant(1.0, dtype=dtype)
return alloc(numpy.array(1, dtype=dtype), *shape) return alloc(numpy.array(1, dtype=dtype), *shape)
......
...@@ -22,7 +22,7 @@ from theano import gof ...@@ -22,7 +22,7 @@ from theano import gof
from theano.compat import izip from theano.compat import izip
from theano.gof import opt, InconsistencyError, TopoOptimizer, graph from theano.gof import opt, InconsistencyError, TopoOptimizer, graph
from theano.gof import Variable, Constant from theano.gof import Variable, Constant
from theano.gof.opt import copy_stack_trace, in2out, out2in, LocalOptGroup from theano.gof.opt import copy_stack_trace, in2out
from theano.gof.utils import MethodNotDefined from theano.gof.utils import MethodNotDefined
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.configparser import config from theano.configparser import config
...@@ -56,6 +56,7 @@ _logger = logging.getLogger('theano.tensor.opt') ...@@ -56,6 +56,7 @@ _logger = logging.getLogger('theano.tensor.opt')
# Utilities # Utilities
def _fill_chain(new_out, orig_inputs): def _fill_chain(new_out, orig_inputs):
for i in orig_inputs: for i in orig_inputs:
new_out = T.fill(i, new_out) new_out = T.fill(i, new_out)
...@@ -1901,7 +1902,7 @@ def local_subtensor_remove_broadcastable_index(node): ...@@ -1901,7 +1902,7 @@ def local_subtensor_remove_broadcastable_index(node):
@register_specialize @register_specialize
@register_canonicalize('fast_compile_gpu') @register_canonicalize('fast_compile_gpu')
#@register_useless # @register_useless
@gof.local_optimizer([Subtensor, AdvancedSubtensor1]) @gof.local_optimizer([Subtensor, AdvancedSubtensor1])
def local_subtensor_make_vector(node): def local_subtensor_make_vector(node):
""" """
...@@ -2035,7 +2036,7 @@ def local_useless_elemwise(node): ...@@ -2035,7 +2036,7 @@ def local_useless_elemwise(node):
return [node.inputs[0]] return [node.inputs[0]]
elif (node.op.scalar_op == theano.scalar.identity and elif (node.op.scalar_op == theano.scalar.identity and
len(node.inputs) == 1): len(node.inputs) == 1):
return [node.inputs[0]] return
elif (isinstance(node.op.scalar_op, scalar.AND) and elif (isinstance(node.op.scalar_op, scalar.AND) and
len(node.inputs) == 2): len(node.inputs) == 2):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论