提交 7b25ea4b authored 作者: James Bergstra's avatar James Bergstra

Moved get_constant_value to tensor.basic to use it in Alloc.make_node and/or infer_shape

上级 6fa0a8ae
...@@ -304,6 +304,41 @@ def _allclose(a, b): ...@@ -304,6 +304,41 @@ def _allclose(a, b):
rtol = float64_rtol rtol = float64_rtol
return numpy.allclose(a,b, atol=atol, rtol=rtol) return numpy.allclose(a,b, atol=atol, rtol=rtol)
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts,
this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
:note: There may be another function similar to this one in the code, but I'm not sure where it
is.
"""
if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
# it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on
try:
complex(v.data) #works for all numeric scalars
return v.data
except:
raise TypeError(v)
if v.owner:
if isinstance(v.owner.op, Alloc):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, DimShuffle):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, Rebroadcast):
return get_constant_value(v.owner.inputs[0])
if v.owner.op == fill:
shape, val = v.owner.inputs
# fill(a,b) fills the shape of 'a' filled with 'b'
return get_constant_value(val)
raise TypeError(v)
class TensorType(Type): class TensorType(Type):
"""Symbolic `Type` representing a numpy.ndarray value.""" """Symbolic `Type` representing a numpy.ndarray value."""
...@@ -980,6 +1015,9 @@ class _tensor_py_operators: ...@@ -980,6 +1015,9 @@ class _tensor_py_operators:
__array_priority__ = 1000 __array_priority__ = 1000
def get_constant_value(self):
return get_constant_value(self)
class TensorVariable(Variable, _tensor_py_operators): class TensorVariable(Variable, _tensor_py_operators):
"""Subclass to add the tensor operators to the basic `Variable` class.""" """Subclass to add the tensor operators to the basic `Variable` class."""
TensorType.Variable = TensorVariable TensorType.Variable = TensorVariable
...@@ -1754,7 +1792,11 @@ class Alloc(gof.Op): ...@@ -1754,7 +1792,11 @@ class Alloc(gof.Op):
if s.type.dtype[:3] not in ('int', 'uin'): if s.type.dtype[:3] not in ('int', 'uin'):
raise TypeError('Shape arguments must be integers', s) raise TypeError('Shape arguments must be integers', s)
# if s is constant 1, then we're broadcastable in that dim # if s is constant 1, then we're broadcastable in that dim
bcast.append(isinstance(s, TensorConstant) and (s.data == 1)) try:
const_shp = get_constant_value(s)
except TypeError:
const_shp = None
bcast.append(numpy.all(1 == const_shp))
otype = TensorType(dtype=v.dtype, broadcastable=bcast) otype = TensorType(dtype=v.dtype, broadcastable=bcast)
return gof.Apply(self, [v]+sh, [otype()]) return gof.Apply(self, [v]+sh, [otype()])
......
...@@ -25,6 +25,7 @@ from theano import compile #to register the optimizer built by this file ...@@ -25,6 +25,7 @@ from theano import compile #to register the optimizer built by this file
from theano.gof.python25 import any, all from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer from theano.gof.opt import Optimizer
from theano.gof import toolbox, DestroyHandler from theano.gof import toolbox, DestroyHandler
from basic import get_constant_value
# Utilities # Utilities
...@@ -63,40 +64,6 @@ def encompasses_broadcastable(b1, b2): ...@@ -63,40 +64,6 @@ def encompasses_broadcastable(b1, b2):
def merge_broadcastables(broadcastables): def merge_broadcastables(broadcastables):
return [all(bcast) for bcast in zip(*broadcastables)] return [all(bcast) for bcast in zip(*broadcastables)]
def get_constant_value(v):
"""return the constant scalar(0-D) value underlying variable `v`
If v is the output of dimshuffles, fills, allocs, rebroadcasts,
this function digs through them.
If `v` is not some view of constant data, then raise a TypeError.
:note: There may be another function similar to this one in the code, but I'm not sure where it
is.
"""
if isinstance(v, Constant):
#TODO: consider checking for arrays of the form e.g. [1,1,1,1] where
# it is not a constant, but in some cases it *could* be replaced with one.
# Note that this would have an effect on the broadcasting of inputs and so on
try:
complex(v.data) #works for all numeric scalars
return v.data
except:
raise TypeError(v)
if v.owner:
if isinstance(v.owner.op, T.Alloc):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, T.DimShuffle):
return get_constant_value(v.owner.inputs[0])
if isinstance(v.owner.op, T.Rebroadcast):
return get_constant_value(v.owner.inputs[0])
if v.owner.op == T.fill:
shape, val = v.owner.inputs
# fill(a,b) fills the shape of 'a' filled with 'b'
return get_constant_value(val)
raise TypeError(v)
def scalarconsts_rest(inputs): def scalarconsts_rest(inputs):
"""Partition a list of variables into two kinds: """Partition a list of variables into two kinds:
scalar constants, and the rest.""" scalar constants, and the rest."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论