提交 ee650637 authored 作者: Sina Honari's avatar Sina Honari

adding documentation to addbroadcast() unbroadcast() and patternbroadcast()

上级 e9fee711
...@@ -3265,9 +3265,26 @@ class Split(Op): ...@@ -3265,9 +3265,26 @@ class Split(Op):
def addbroadcast(x, *axes): def addbroadcast(x, *axes):
""" """
Make the input broadcastable in the specified axes. Make the input broadcastable in the specified axes.
For example, addbroadcast(x, 0) will make the first dimension of
x broadcastable. When performing the function, if the length of
x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph especially during We apply the opt here not to pollute the graph especially during
the gpu optimization the gpu optimization
Parameters:
------------
x : tensor_like
Input theano tensor.
axis : an int or an iterable object such as list or tuple
of int values
The dimension along which the tensor x should be broadcastable.
if the length of x along these dimensions is not 1,
a ValueError will be raised.
returns:
----------
a theano tensor, which is broadcastable along the specified dimensions.
""" """
rval = Rebroadcast(*[(axis, True) for axis in axes])(x) rval = Rebroadcast(*[(axis, True) for axis in axes])(x)
return theano.tensor.opt.apply_rebroadcast_opt(rval) return theano.tensor.opt.apply_rebroadcast_opt(rval)
...@@ -3276,9 +3293,26 @@ def addbroadcast(x, *axes): ...@@ -3276,9 +3293,26 @@ def addbroadcast(x, *axes):
def unbroadcast(x, *axes): def unbroadcast(x, *axes):
""" """
Make the input impossible to broadcast in the specified axes. Make the input impossible to broadcast in the specified axes.
For example, addbroadcast(x, 0) will make the first dimension
of x broadcastable. When performing the function, if the length
of x along that dimension is not 1, a ValueError will be raised.
We apply the opt here not to pollute the graph especially during We apply the opt here not to pollute the graph especially during
the gpu optimization the gpu optimization
Parameters:
------------
x : tensor_like
Input theano tensor.
axis : an int or an iterable object such as list or tuple
of int values
The dimension along which the tensor x should be unbroadcastable.
if the length of x along these dimensions is not 1,
a ValueError will be raised.
returns:
----------
a theano tensor, which is unbroadcastable along the specified dimensions.
""" """
rval = Rebroadcast(*[(axis, False) for axis in axes])(x) rval = Rebroadcast(*[(axis, False) for axis in axes])(x)
return theano.tensor.opt.apply_rebroadcast_opt(rval) return theano.tensor.opt.apply_rebroadcast_opt(rval)
...@@ -3287,9 +3321,28 @@ def unbroadcast(x, *axes): ...@@ -3287,9 +3321,28 @@ def unbroadcast(x, *axes):
def patternbroadcast(x, broadcastable): def patternbroadcast(x, broadcastable):
""" """
Make the input adopt a specific broadcasting pattern. Make the input adopt a specific broadcasting pattern.
broadcastable must be iterable. For example,
patternbroadcast(x, (True, False)) will make the first
dimension of x broadcastable and the second dimension
not broadcastable, so x will now be a row.
We apply the opt here not to pollute the graph especially during the gpu We apply the opt here not to pollute the graph especially during the gpu
optimization. optimization.
Parameters:
------------
x : tensor_like
Input theano tensor.
broadcastable : an iterable object such as list or tuple
of bool values
a set of boolean values indicating whether a dimension
should be broadcastable or not.
if the length of x along these dimensions is not 1,
a ValueError will be raised.
returns:
----------
a theano tensor, which is unbroadcastable along the specified dimensions.
""" """
rval = Rebroadcast(*[(i, broadcastable[i]) rval = Rebroadcast(*[(i, broadcastable[i])
for i in xrange(len(broadcastable))])(x) for i in xrange(len(broadcastable))])(x)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论