提交 2e404de3 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added comments

上级 996396ef
...@@ -64,6 +64,16 @@ def constant(x): ...@@ -64,6 +64,16 @@ def constant(x):
class Scalar(Type): class Scalar(Type):
"""
Internal class, should not be used by clients
Primarily used by tensor.elemwise and tensor.reduce
Analogous to TensorType, but for zero-dimensional objects
Maps directly to C primitives
TODO: refactor to be named ScalarType for consistency with TensorType
"""
def __init__(self, dtype): def __init__(self, dtype):
if dtype == 'floatX': if dtype == 'floatX':
dtype = config.floatX dtype = config.floatX
......
...@@ -537,7 +537,10 @@ class Elemwise(Op): ...@@ -537,7 +537,10 @@ class Elemwise(Op):
def grad(self, inputs, ograds): def grad(self, inputs, ograds):
# Gradients (especially on the final costs) don't have to be symbolic # Gradients (especially on the final costs) don't have to be symbolic
# e.g., ograds will be [ 1. ] if your objective is c and the output
# of the current apply node is c
ograds = map(as_tensor_variable, ograds) ograds = map(as_tensor_variable, ograds)
scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs] scalar_inputs = [Scalar(dtype = t.type.dtype)() for t in inputs]
scalar_ograds = [Scalar(dtype = ograd.type.dtype)() for ograd in ograds] scalar_ograds = [Scalar(dtype = ograd.type.dtype)() for ograd in ograds]
scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds) scalar_igrads = self.scalar_op.grad(scalar_inputs, scalar_ograds)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论