提交 c78b586e authored 作者: Frederic Bastien's avatar Frederic Bastien

import theano/tensor/nnet/conv by default. fix import problem to allow this.

上级 08ff29bf
from nnet import *
#from conv import conv2d, ConvOp # causes circular import
from conv import conv2d, ConvOp
from Conv3D import *
from ConvGrad3D import *
from ConvTransp3D import *
......
......@@ -15,8 +15,8 @@ import logging
import numpy
import theano
import theano.tensor as tensor
from theano import gof, Op, tensor, config
from theano.tensor import get_constant_value, blas, as_tensor_variable
from theano import Op, config
from theano.gof.apply_shape import Apply
from theano.gof.python25 import any
......@@ -29,7 +29,6 @@ try:
except ImportError:
pass
_logger=logging.getLogger("theano.signal.conv")
def _debug(*msg):
_logger.debug(' '.join([ str(x) for x in msg]))
......@@ -79,14 +78,14 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape = list(image_shape)
for i in range(len(image_shape)):
if image_shape[i] is not None:
image_shape[i] = tensor.get_constant_value(tensor.as_tensor_variable(image_shape[i]))
image_shape[i] = get_constant_value(as_tensor_variable(image_shape[i]))
assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i])
if filter_shape is not None:
filter_shape = list(filter_shape)
for i in range(len(filter_shape)):
if filter_shape[i] is not None:
filter_shape[i] = tensor.get_constant_value(tensor.as_tensor_variable(filter_shape[i]))
filter_shape[i] = get_constant_value(as_tensor_variable(filter_shape[i]))
assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i])
......@@ -537,8 +536,8 @@ class ConvOp(Op):
kerns - 4 dim: nkern x stackidx x rows x cols
"""
outdim = kerns.ndim
_inputs = tensor.as_tensor_variable(inputs)
_kerns = tensor.as_tensor_variable(kerns)
_inputs = as_tensor_variable(inputs)
_kerns = as_tensor_variable(kerns)
# TODO: lift this restriction by upcasting either inputs or kerns
if _inputs.ndim != 4:
raise TypeError('make_node requires 4D tensor of inputs')
......@@ -551,7 +550,7 @@ class ConvOp(Op):
bcastable23 = [self.outshp[0]==1, self.outshp[1]==1]
else:
bcastable23 = [False, False]
output = tensor.tensor(dtype=_inputs.type.dtype,
output = theano.tensor.tensor(dtype=_inputs.type.dtype,
broadcastable=[_inputs.broadcastable[0],
_kerns.broadcastable[0]]+bcastable23);
......@@ -857,7 +856,7 @@ class ConvOp(Op):
#define VALID 0
#define MOD %
using namespace std;
""" + tensor.blas.blas_header_text()
""" + blas.blas_header_text()
def use_blas(self):
""" Return True if we will generate code that use gemm.
......@@ -874,22 +873,22 @@ using namespace std;
def c_libraries(self):
if self.use_blas():
return tensor.blas.ldflags()
return blas.ldflags()
return []
def c_compile_args(self):
if self.use_blas():
return tensor.blas.ldflags(libs=False, flags=True)
return blas.ldflags(libs=False, flags=True)
return []
def c_lib_dirs(self):
if self.use_blas():
return tensor.blas.ldflags(libs=False, libs_dir=True)
return blas.ldflags(libs=False, libs_dir=True)
return []
def c_header_dirs(self):
if self.use_blas():
return tensor.blas.ldflags(libs=False, include_dir=True)
return blas.ldflags(libs=False, include_dir=True)
return []
def c_code(self, node, name, (img2d, filtersflipped), (z, ), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论