提交 e205110f authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added "node" argument to ops that already had a connection pattern

上级 f08fe14e
...@@ -2097,7 +2097,7 @@ class Shape(Op): ...@@ -2097,7 +2097,7 @@ class Shape(Op):
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
return [[len(in_shapes[0])]] return [[len(in_shapes[0])]]
def connection_pattern(self): def connection_pattern(self, node):
# the grad returns the gradient with respect to the # the grad returns the gradient with respect to the
# elements of a tensor variable # elements of a tensor variable
# the elements of the tensor variable do not participate # the elements of the tensor variable do not participate
...@@ -2193,6 +2193,9 @@ class SpecifyShape(Op): ...@@ -2193,6 +2193,9 @@ class SpecifyShape(Op):
assert len(new_shape) == len(xshape) assert len(new_shape) == len(xshape)
return [new_shape] return [new_shape]
def connection_pattern(self, node):
    """Return the input/output gradient connectivity for SpecifyShape.

    The result is indexed as [input][output]. Only the first input
    (the tensor whose shape is being asserted) is connected to the
    single output; the shape argument carries no gradient, so it is
    reported as disconnected.
    """
    tensor_input_connectivity = [True]
    shape_input_connectivity = [False]
    return [tensor_input_connectivity, shape_input_connectivity]
def grad(self, inp, grads): def grad(self, inp, grads):
x, s = inp x, s = inp
gz, = grads gz, = grads
...@@ -3886,6 +3889,7 @@ class Subtensor(Op): ...@@ -3886,6 +3889,7 @@ class Subtensor(Op):
else: else:
return scal.as_scalar(a) return scal.as_scalar(a)
def make_node(self, x, *inputs): def make_node(self, x, *inputs):
x = as_tensor_variable(x) x = as_tensor_variable(x)
inputs = tuple(self.my_as_scalar(a) for a in inputs) inputs = tuple(self.my_as_scalar(a) for a in inputs)
...@@ -6461,6 +6465,7 @@ class Dot(Op): ...@@ -6461,6 +6465,7 @@ class Dot(Op):
raise raise
def grad(self, inp, grads): def grad(self, inp, grads):
x, y = inp x, y = inp
gz, = grads gz, = grads
if gz.type.ndim == 0: if gz.type.ndim == 0:
......
...@@ -329,7 +329,7 @@ def test_disconnected_nan(): ...@@ -329,7 +329,7 @@ def test_disconnected_nan():
return theano.Apply(self, inputs=[x], return theano.Apply(self, inputs=[x],
outputs = [ x.type(), theano.tensor.scalar() ]) outputs = [ x.type(), theano.tensor.scalar() ])
def connection_pattern(self): def connection_pattern(self, node):
return [[True, False]] return [[True, False]]
def grad(self, inputs, output_grads): def grad(self, inputs, output_grads):
...@@ -360,5 +360,7 @@ def test_disconnected_nan(): ...@@ -360,5 +360,7 @@ def test_disconnected_nan():
# connection_pattern functionality worked correctly # connection_pattern functionality worked correctly
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论