提交 c78b586e authored 作者: Frederic Bastien's avatar Frederic Bastien

import theano/tensor/nnet/conv by default. fix import problem to allow this.

上级 08ff29bf
from nnet import * from nnet import *
#from conv import conv2d, ConvOp # causes circular import from conv import conv2d, ConvOp
from Conv3D import * from Conv3D import *
from ConvGrad3D import * from ConvGrad3D import *
from ConvTransp3D import * from ConvTransp3D import *
......
...@@ -15,8 +15,8 @@ import logging ...@@ -15,8 +15,8 @@ import logging
import numpy import numpy
import theano import theano
import theano.tensor as tensor from theano.tensor import get_constant_value, blas, as_tensor_variable
from theano import gof, Op, tensor, config from theano import Op, config
from theano.gof.apply_shape import Apply from theano.gof.apply_shape import Apply
from theano.gof.python25 import any from theano.gof.python25 import any
...@@ -29,7 +29,6 @@ try: ...@@ -29,7 +29,6 @@ try:
except ImportError: except ImportError:
pass pass
_logger=logging.getLogger("theano.signal.conv") _logger=logging.getLogger("theano.signal.conv")
def _debug(*msg): def _debug(*msg):
_logger.debug(' '.join([ str(x) for x in msg])) _logger.debug(' '.join([ str(x) for x in msg]))
...@@ -79,14 +78,14 @@ def conv2d(input, filters, image_shape=None, filter_shape=None, ...@@ -79,14 +78,14 @@ def conv2d(input, filters, image_shape=None, filter_shape=None,
image_shape = list(image_shape) image_shape = list(image_shape)
for i in range(len(image_shape)): for i in range(len(image_shape)):
if image_shape[i] is not None: if image_shape[i] is not None:
image_shape[i] = tensor.get_constant_value(tensor.as_tensor_variable(image_shape[i])) image_shape[i] = get_constant_value(as_tensor_variable(image_shape[i]))
assert str(image_shape[i].dtype).startswith('int') assert str(image_shape[i].dtype).startswith('int')
image_shape[i] = int(image_shape[i]) image_shape[i] = int(image_shape[i])
if filter_shape is not None: if filter_shape is not None:
filter_shape = list(filter_shape) filter_shape = list(filter_shape)
for i in range(len(filter_shape)): for i in range(len(filter_shape)):
if filter_shape[i] is not None: if filter_shape[i] is not None:
filter_shape[i] = tensor.get_constant_value(tensor.as_tensor_variable(filter_shape[i])) filter_shape[i] = get_constant_value(as_tensor_variable(filter_shape[i]))
assert str(filter_shape[i].dtype).startswith('int') assert str(filter_shape[i].dtype).startswith('int')
filter_shape[i] = int(filter_shape[i]) filter_shape[i] = int(filter_shape[i])
...@@ -537,8 +536,8 @@ class ConvOp(Op): ...@@ -537,8 +536,8 @@ class ConvOp(Op):
kerns - 4 dim: nkern x stackidx x rows x cols kerns - 4 dim: nkern x stackidx x rows x cols
""" """
outdim = kerns.ndim outdim = kerns.ndim
_inputs = tensor.as_tensor_variable(inputs) _inputs = as_tensor_variable(inputs)
_kerns = tensor.as_tensor_variable(kerns) _kerns = as_tensor_variable(kerns)
# TODO: lift this restriction by upcasting either inputs or kerns # TODO: lift this restriction by upcasting either inputs or kerns
if _inputs.ndim != 4: if _inputs.ndim != 4:
raise TypeError('make_node requires 4D tensor of inputs') raise TypeError('make_node requires 4D tensor of inputs')
...@@ -551,7 +550,7 @@ class ConvOp(Op): ...@@ -551,7 +550,7 @@ class ConvOp(Op):
bcastable23 = [self.outshp[0]==1, self.outshp[1]==1] bcastable23 = [self.outshp[0]==1, self.outshp[1]==1]
else: else:
bcastable23 = [False, False] bcastable23 = [False, False]
output = tensor.tensor(dtype=_inputs.type.dtype, output = theano.tensor.tensor(dtype=_inputs.type.dtype,
broadcastable=[_inputs.broadcastable[0], broadcastable=[_inputs.broadcastable[0],
_kerns.broadcastable[0]]+bcastable23); _kerns.broadcastable[0]]+bcastable23);
...@@ -857,7 +856,7 @@ class ConvOp(Op): ...@@ -857,7 +856,7 @@ class ConvOp(Op):
#define VALID 0 #define VALID 0
#define MOD % #define MOD %
using namespace std; using namespace std;
""" + tensor.blas.blas_header_text() """ + blas.blas_header_text()
def use_blas(self): def use_blas(self):
""" Return True if we will generate code that use gemm. """ Return True if we will generate code that use gemm.
...@@ -874,22 +873,22 @@ using namespace std; ...@@ -874,22 +873,22 @@ using namespace std;
def c_libraries(self): def c_libraries(self):
if self.use_blas(): if self.use_blas():
return tensor.blas.ldflags() return blas.ldflags()
return [] return []
def c_compile_args(self): def c_compile_args(self):
if self.use_blas(): if self.use_blas():
return tensor.blas.ldflags(libs=False, flags=True) return blas.ldflags(libs=False, flags=True)
return [] return []
def c_lib_dirs(self): def c_lib_dirs(self):
if self.use_blas(): if self.use_blas():
return tensor.blas.ldflags(libs=False, libs_dir=True) return blas.ldflags(libs=False, libs_dir=True)
return [] return []
def c_header_dirs(self): def c_header_dirs(self):
if self.use_blas(): if self.use_blas():
return tensor.blas.ldflags(libs=False, include_dir=True) return blas.ldflags(libs=False, include_dir=True)
return [] return []
def c_code(self, node, name, (img2d, filtersflipped), (z, ), sub): def c_code(self, node, name, (img2d, filtersflipped), (z, ), sub):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论