提交 cdafffc7 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

upgraded shape op's connection_pattern to the new type

made elemwise ops handle DisconnectedType correctly
上级 bf0e90ff
...@@ -26,6 +26,8 @@ from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph ...@@ -26,6 +26,8 @@ from theano.gof import Op, utils, Variable, Constant, Type, Apply, FunctionGraph
from theano.gof.python25 import partial, all, any from theano.gof.python25 import partial, all, any
from theano.configparser import config from theano.configparser import config
from theano.gradient import DisconnectedType
builtin_complex = complex builtin_complex = complex
builtin_int = int builtin_int = int
builtin_float = float builtin_float = float
......
...@@ -2103,7 +2103,7 @@ class Shape(Op): ...@@ -2103,7 +2103,7 @@ class Shape(Op):
# the elements of the tensor variable do not participate # the elements of the tensor variable do not participate
# in the computation of the shape, so they are not really # in the computation of the shape, so they are not really
# part of the graph # part of the graph
return [False] return [[False]]
def grad(self, inp, grads): def grad(self, inp, grads):
# the grad returns the gradient with respect to the # the grad returns the gradient with respect to the
......
...@@ -14,6 +14,7 @@ from theano.scalar import Scalar ...@@ -14,6 +14,7 @@ from theano.scalar import Scalar
from theano.printing import min_informative_str, pprint from theano.printing import min_informative_str, pprint
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano.tensor.utils import hash_from_dict from theano.tensor.utils import hash_from_dict
from theano.gradient import DisconnectedType
config = theano.config config = theano.config
...@@ -680,6 +681,8 @@ class Elemwise(Op): ...@@ -680,6 +681,8 @@ class Elemwise(Op):
def transform(r): def transform(r):
# From a graph of ScalarOps, make a graph of Broadcast ops. # From a graph of ScalarOps, make a graph of Broadcast ops.
if isinstance(r.type, DisconnectedType):
return r
if r in scalar_inputs: if r in scalar_inputs:
return inputs[scalar_inputs.index(r)] return inputs[scalar_inputs.index(r)]
if r in scalar_ograds: if r in scalar_ograds:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论