Initial commit

This commit is contained in:
Laurent El Shafey 2024-12-10 08:56:11 -08:00
commit 9fdd561586
246 changed files with 58283 additions and 0 deletions

163
include/convnet.cuh Normal file
View file

@ -0,0 +1,163 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef CONVNET3
#define CONVNET3
#include <vector>
#include <string>
#include <set>
#include <map>
#include <helper_cuda.h>
#include <time.h>
#include <queue.h>
#include <thread.h>
#include <math.h>
#include <sync.h>
#include <quantizer.cuh>
#include <messages.cuh>
#include <pipedispenser.cuh>
#include "layer.cuh"
#include "data.cuh"
#include "worker.cuh"
#include "weights.cuh"
#include "hostmem.cuh"
class Worker;
class WorkResult;
class Layer;
class DataLayer;
class CostLayer;
class ConvNetGPU;
class ConvNet : public Thread {
protected:
std::map<std::string,Layer*> _layerMap;
std::vector<DataLayer*> _dataLayers;
std::vector<ConvNetGPU*> _convNetThreads; // List of convnet threads
DataProvider* _dp;
CPUData* _data;
ThreadSynchronizer* _sync;
PipeDispenser* _pd;
intv* _deviceIDs;
std::vector<intv*>* _deviceCPUs;
Queue<Worker*> _workerQueue;
Queue<WorkResult*> _resultQueue;
Queue<Message*> _msgQueue;
int _numFwdTerminal, _numBwdTerminal;
int _weightUpdateFreq, _numBwdMiniPasses;
// For gradient checking
int _numFailures;
int _numTests;
// Training progress (between 0 and 1).
// Used to determine learning rate based on LearningRateSchedule.
double _trainingProgress;
double _baseErr;
void waitForTerminals(int numMsgs, MESSAGES msg);
void sendMessage(MESSAGES msg, bool sync);
void findBwdTerminal(Layer& l, std::set<std::string>& visited, std::set<std::string> &terminal);
void* run();
public:
ConvNet(PyObject* layerParams, intv& deviceIDs, std::vector<intv*>& deviceCPUs, int minibatchSize, int weightUpdateFreq);
Queue<Message*>& getMessageQueue();
Queue<Worker*>& getWorkerQueue();
Queue<WorkResult*>& getResultQueue();
DataProvider& getDataProvider();
Layer& operator[](string& name);
Layer& getLayer(string& name);
void copyToCPU();
void copyToGPU();
void updateWeights();
void reset();
void bprop(PASS_TYPE passType);
void fprop(PASS_TYPE passType);
void fprop(int miniIdx, PASS_TYPE passType);
void fprop(CPUData& data, PASS_TYPE passType);
void setTrainingProgress(double progress);
double getTrainingProgress() const;
bool checkGradient(const std::string& name, float eps, Weights& weights);
void checkGradients();
Cost& getCost();
Cost& getCost(Cost& cost);
double getCostValue();
int getDeviceID(int gpuIdx);
intv& getDeviceIDs();
ThreadSynchronizer& getSync();
void syncWithChildren();
int getWeightUpdateFreq();
int getNumBwdMiniPasses();
int getMinibatchSize();
PipeDispenser& getPipeDispenser();
};
class ConvNetGPU : public Thread {
protected:
std::map<std::string,Layer*> _layerMap;
std::vector<CostLayer*> _costs;
ConvNet* _convNet;
int _deviceID;
Queue<Message*> _msgQueue;
void initCuda();
virtual void initLayer(PyObject* paramsDict);
void* run();
void copyToCPU();
void copyToGPU();
void updateWeights();
void reset();
public:
ConvNetGPU(PyObject* layerList, int deviceID, intv& deviceCPUs, ConvNet* convNet);
std::map<std::string, Layer*>& getLayerMap();
void bprop(PASS_TYPE passType);
void fprop(PASS_TYPE passType);
void fprop(int miniIdx, PASS_TYPE passType);
int getDeviceID();
ConvNet& getConvNet();
void enqueueMessage(Message* msg);
Queue<Message*>& getMessageQueue();
std::vector<CostLayer*>& getCostLayers();
Cost& getCost(int numCases);
Layer& operator[](string& name);
Layer& getLayer(string& name);
};
#endif /* CONVNET */

66
include/cost.cuh Normal file
View file

@ -0,0 +1,66 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef COST_CUH
#define COST_CUH
#include <vector>
#include <map>
#include <helper_cuda.h>
#include "layer.cuh"
#include "util.cuh"
class CostLayer;
/*
* Wrapper for dictionary mapping cost name to vector of returned values.
*/
class Cost {
private:
int _numCases;
CostMap _costMap;
CostCoeffMap _costCoeffMap;
public:
Cost(int numCases);
Cost(int numCases, std::vector<CostLayer*>& costs);
doublev& operator [](const std::string s);
CostMap& getCostMap();
CostCoeffMap& getCostCoeffMap();
int getNumCases();
/*
* Returns sum of first values returned by all the costs, weighted by the cost coefficients.
*/
double getValue();
Cost& operator += (Cost& er);
Cost& operator |= (Cost& er);
Cost& operator /= (const double v);
virtual ~Cost();
};
#endif /* COST_CUH */

31
include/cpuCNN.cuh Normal file
View file

@ -0,0 +1,31 @@
/*
* File: cpuFuncs.h
* Author: Alex Krizhevsky
*
* Created on September 10, 2012, 5:05 PM
*/
#ifndef CPUFUNCS_H
#define CPUFUNCS_H
#include <helper_cuda.h>
#include <softmaxtree.cuh>
/*
* weights: (numNodes, numFeatures)
* nodes: numNodesAtDepth-length array of ushort2
* where x coordinate gives node idx and y coordinate gives parent idx
* targets: (numNodes, numFeatures)
*
*/
void cpuSoftmaxTreeFwd(float* weights, float* targets, const int numFeatures, SoftmaxTree& tree);
/*
* grads: (numNodes, numFeatures)
*
*/
void cpuSoftmaxTreeBwd(float* grads, const int numFeatures, SoftmaxTree& tree);
void cpuSoftmaxTreeUpdateWeights(float* weights, float* weightsInc, float* weightsGrad,
const int numFeatures, float eps, const float mom, float wc, SoftmaxTree& tree);
#endif /* CPUFUNCS_H */

111
include/data.cuh Normal file
View file

@ -0,0 +1,111 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef DATA_CUH
#define DATA_CUH
#include <vector>
#include <algorithm>
#include "util.cuh"
class Data {
protected:
MatrixV* _data;
void assertDimensions() {
assert(_data->size() > 0);
for (int i = 1; i < _data->size(); i++) {
assert(_data->at(i-1)->getNumCols() == _data->at(i)->getNumCols());
assert(_data->at(i-1)->isTrans() == _data->at(i)->isTrans());
}
assert(_data->at(0)->getNumCols() > 0);
}
public:
typedef typename MatrixV::iterator T_iter;
// Cases in columns, but array may be transposed
// (so in memory they can really be in rows -- in which case the array is transposed
// during the copy to GPU).
Data(PyObject* pyData) {
_data = getMatrixV(pyData);
assertDimensions();
}
Data(MatrixV* data) : _data(data) {
assertDimensions();
}
~Data() {
for (T_iter it = _data->begin(); it != _data->end(); ++it) {
delete *it;
}
delete _data;
}
Matrix& operator [](int idx) const {
return *_data->at(idx);
}
int getSize() const {
return _data->size();
}
MatrixV& getData() const {
return *_data;
}
Matrix& getData(int i) const {
return *_data->at(i);
}
bool isTrans() const {
return _data->at(0)->isTrans();
}
int getNumCases() const {
return _data->at(0)->getNumCols();
}
};
typedef Data CPUData;
class DataProvider {
protected:
CPUData* _hData;
NVMatrixV _data;
int _minibatchSize;
public:
DataProvider(int minibatchSize);
void setData(CPUData&);
void clearData();
CPUData& getMinibatch(int idx);
CPUData& getDataSlice(int startCase, int endCase);
int getNumMinibatches();
int getMinibatchSize();
int getNumCases();
int getNumCasesInMinibatch(int idx);
};
#endif /* DATA_CUH */

51
include/hostmem.cuh Normal file
View file

@ -0,0 +1,51 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef HOSTMEM_CUH
#define HOSTMEM_CUH
#include <helper_cuda.h>
#include <cuda.h>
#include <cuda_runtime.h>
/*
* A utility class for transferring untyped memory from CPU to GPU and vice versa.
*/
class PinnedHostMem {
protected:
uint _numBytes;
void* _data;
public:
PinnedHostMem();
~PinnedHostMem();
void resize(uint bytes);
void copyFrom(void* src, uint bytes);
void copyTo(void* dst);
void* getData();
};
#endif /* HOSTMEM_CUH */

654
include/layer.cuh Normal file
View file

@ -0,0 +1,654 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef LAYER_CUH
#define LAYER_CUH
#include <algorithm>
#include <string>
#include <vector>
#include <map>
#include <assert.h>
#include <nvmatrix.cuh>
#include <multisoftmax.h>
#include <helper_timer.h>
#include "convnet.cuh"
#include "cost.cuh"
#include "weights.cuh"
#include "neuron.cuh"
#include "data.cuh"
#include "layer_kernels.cuh"
#include "hostmem.cuh"
#include "softmaxtree.cuh"
#include "pipedispenser.cuh"
class Cost;
class ConvNet;
class ConvNetGPU;
class CostLayer;
class DataLayer;
//class Message;
//class FpropMessage;
// The input matrix here is the squared norm.
// This replaces the squared norm with:
// 1 if it is below the threshold given by norm2
// norm/sqrt(a) otherwise -- i.e. the desired norm (not squared)
class WeightConstraintOperator {
private:
float _norm, _norm2;
public:
WeightConstraintOperator(float norm) : _norm(norm), _norm2(norm*norm) {
}
__device__ inline float operator()(const float a) const {
return a > _norm2 ? __fdividef(_norm, sqrtf(a)) : 1.0f;
}
};
class WeightContrastNormOperator {
private:
float _min, _max, _scale;
public:
WeightContrastNormOperator(float min, float max, float scale) : _min(min), _max(max), _scale(scale) {
}
__device__ inline float operator()(float a) const {
a = sqrtf(a) * _scale;
return a < _min ? __fdividef(_min, a) : a > _max ? __fdividef(_max, a) : 1.0f;
}
};
/*
* Abstract layer.
*/
class Layer {
protected:
ConvNetGPU* _convNetGPU;
std::vector<Layer*> _prev, _next;
int _rcvdFInputs;
std::map<int, int> _rcvdBInputs;
int _rcvdBInputMsgs;
int _numOutputs;
NVMatrixV _inputs;
std::map<int, NVMatrix*> _outputs;
std::map<int, NVMatrix*> _actsGrad; // Layer activity gradients
bool _gradConsumer, _foundGradConsumers, _trans;
bool _conserveMem;
bool _bwdTerminal;
int _numGradProducersNext;
int _actsTarget, _actsGradTarget;
std::string _name, _type;
int _deviceID;
intv _nextDeviceIDs;
HostNVMatrix _hostMemFwd, _hostMemBwd;
Quantizer* _fwdQuantizer, *_bwdQuantizer;
virtual void fpropNext(PASS_TYPE passType);
virtual void truncBwdActs();
virtual void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType) = 0;
virtual void bpropCommon(NVMatrix& v, PASS_TYPE passType) {
// Do nothing by default
}
virtual void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType) {
assert(!isGradProducer()); // Only do nothing if not grad producer
}
void shuffle(intv& v);
public:
static bool _saveActsGrad, _saveActs;
Layer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool trans);
virtual void fprop(PASS_TYPE passType);
void fprop(NVMatrix& v, PASS_TYPE passType);
virtual void fprop(NVMatrixV& v, PASS_TYPE passType);
virtual void bprop(PASS_TYPE passType);
virtual void bprop(NVMatrix& v, PASS_TYPE passType);
virtual void reset();
int getNumCases(NVMatrix& v);
int incRcvdBInputs(int deviceID);
int getRcvdFInputs();
int getRcvdBInputs(int deviceID);
int incRcvdBInputMsgs();
bool isGradConsumer();
bool hasGradProducerNext(std::string& layerName);
// Does this layer produce a gradient for any layer?
virtual bool isGradProducer();
// Does this layer produce a gradient for layer of given name?
virtual bool isGradProducer(std::string& layerName);
std::string& getName();
std::string& getType();
void addNext(Layer* l);
void addPrev(Layer* l);
std::vector<Layer*>& getPrev();
std::vector<Layer*>& getNext();
virtual NVMatrix& getActs();
virtual NVMatrix& getActs(int deviceID);
virtual NVMatrix& getActsGrad(int deviceID);
virtual NVMatrix& getActsGrad();
virtual void postInit();
int getDeviceID();
ConvNetGPU& getConvNetGPU();
ConvNet& getConvNet();
PipeDispenser& getPipeDispenser();
void setBwdTerminal(bool t);
// Do nothing if this layer has no weights
virtual bool updateWeights() {
return false;
}
virtual void checkGradients() {
}
virtual void copyToCPU() {
}
virtual void copyToGPU() {
}
};
class NeuronLayer : public Layer {
protected:
Neuron* _neuron;
string _neuronType;
virtual void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
virtual void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
NeuronLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
std::string& getNeuronType();
};
class WeightLayer : public Layer {
protected:
WeightList _weights;
Weights *_biases;
float _wStep, _bStep;
bool _gradComputed;
void bpropCommon(NVMatrix& v, PASS_TYPE passType);
virtual void bpropBiases(NVMatrix& v, PASS_TYPE passType) = 0;
virtual void bpropWeights(NVMatrix& v, int inpIdx, PASS_TYPE passType) = 0;
virtual void constrainWeights() = 0;
public:
WeightLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool trans, bool useGrad, bool initWeights);
virtual bool updateWeights();
virtual void copyToCPU();
virtual void copyToGPU();
virtual void checkGradients();
Weights& getWeights(int idx);
};
class FCLayer : public WeightLayer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropBiases(NVMatrix& v, PASS_TYPE passType);
void bpropWeights(NVMatrix& v, int inpIdx, PASS_TYPE passType);
virtual void constrainWeights();
public:
FCLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool useGrad, bool initWeights);
FCLayer();
};
class TreeFCLayer : public FCLayer {
protected:
TreeWeights* _treeWeights;
static void makeTree(PyObject* pyTree, SoftmaxNode& rootNode);
void constrainWeights();
public:
TreeFCLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void checkGradients();
};
class SoftmaxLayer : public Layer {
protected:
bool _doLogregGrad;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
SoftmaxLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
void setDoLogregGrad(bool b);
};
class ConcatenationLayer : public Layer {
protected:
intv* _copyOffsets;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
ConcatenationLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
void setDoLogregGrad(bool b);
};
class EltwiseSumLayer : public Layer {
protected:
floatv* _coeffs;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
EltwiseSumLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class EltwiseMaxLayer : public Layer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
EltwiseMaxLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class DataLayer : public Layer {
protected:
bool _useBuffer;
int _dataIdx;
int _bufferMinibatchIdx;
std::map<int, NVMatrix*> _outputs2; // Buffer for copying data during computation
CPUData* _bufferData;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void postInit();
void copyData(CPUData& data, bool other);
void fpropNext(PASS_TYPE passType);
public:
DataLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
NVMatrix& getActs(int deviceID);
NVMatrix& getActs(int deviceID, bool other);
bool isGradProducer();
void fprop(PASS_TYPE passType);
void fprop(NVMatrixV& data, PASS_TYPE passType);
void setBuffer(CPUData& data, int minibatchIdx);
void startFprop(CPUData& data, PASS_TYPE passType);
void startFpropFromBuffer(PASS_TYPE passType);
int getBufferMinibatchIdx();
CPUData* getBufferData();
};
class LocalLayer : public WeightLayer {
protected:
struct FilterConns {
int* hFilterConns;
int* dFilterConns;
};
vector<FilterConns>* _filterConns;
intv* _padding, *_stride, *_filterSize, *_channels, *_imgSize, *_groups;
intv* _imgPixels, *_filterPixels, *_filterChannels, *_overSample, *_randSparse;
int _modulesX, _modules, _numFilters;
void copyToGPU();
public:
LocalLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool useGrad);
};
class ConvLayer : public LocalLayer {
protected:
int _partialSum;
bool _sharedBiases;
floatv* _weightContrastNormMin, *_weightContrastNormMax;
NVMatrix _weightGradTmp, _actGradTmp;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropBiases(NVMatrix& v, PASS_TYPE passType);
void bpropWeights(NVMatrix& v, int inpIdx, PASS_TYPE passType);
void truncBwdActs();
void constrainWeights();
public:
ConvLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class LocalUnsharedLayer : public LocalLayer {
protected:
NVMatrix _sexMask;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropBiases(NVMatrix& v, PASS_TYPE passType);
void bpropWeights(NVMatrix& v, int inpIdx, PASS_TYPE passType);
void constrainWeights();
public:
LocalUnsharedLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class PoolLayer : public Layer {
protected:
int _channels, _sizeX, _start, _stride, _outputsX;
int _imgSize;
string _pool;
public:
PoolLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool trans);
static PoolLayer& makePoolLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class AvgPoolLayer : public PoolLayer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
AvgPoolLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class MaxPoolLayer : public PoolLayer {
protected:
bool _abs;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
MaxPoolLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool abs);
};
class RandomPoolLayer : public PoolLayer {
protected:
bool _doMax;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
RandomPoolLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class RandomScaleLayer : public Layer {
protected:
int _channels, _imgSize, _tgtSize, _minScaledSize;
float _maxScale; // should be >= 1
NVMatrix _rescaledActs;
std::vector<double> _scaleProbs;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
RandomScaleLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class NailbedLayer : public Layer {
protected:
int _channels, _start, _stride, _outputsX;
int _imgSize;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
NailbedLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class GaussianBlurLayer : public Layer {
protected:
int _channels;
Matrix* _hFilter;
NVMatrix _filter;
NVMatrix _actGradsTmp;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void copyToGPU();
GaussianBlurLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class HorizontalReflectionLayer : public Layer {
protected:
int _channels, _imgSize;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
HorizontalReflectionLayer(ConvNetGPU* convNet, PyObject* paramsDict);
};
class ResizeLayer : public Layer {
protected:
int _channels;
float _scale;
int _imgSize, _tgtSize;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
ResizeLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class HiddenSexLayer : public Layer {
protected:
bool _enable;
float _keep;
NVMatrix _sexMask;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void truncBwdActs();
HiddenSexLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class RGBToYUVLayer : public Layer {
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
RGBToYUVLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class RGBToLABLayer : public Layer {
protected:
bool _center;
public:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
RGBToLABLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class ResponseNormLayer : public Layer {
protected:
int _channels, _size;
float _scale, _pow;
NVMatrix _denoms;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void truncBwdActs();
public:
ResponseNormLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class CrossMapResponseNormLayer : public ResponseNormLayer {
protected:
bool _blocked;
float _minDiv;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
CrossMapResponseNormLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class ContrastNormLayer : public ResponseNormLayer {
protected:
int _imgSize;
NVMatrix _meanDiffs;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
void truncBwdActs();
public:
ContrastNormLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class CostLayer : public Layer {
protected:
float _coeff;
doublev _costv;
public:
CostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict, bool trans);
void bprop(NVMatrix& v, PASS_TYPE passType);
// void bprop(PASS_TYPE passType); // Pure idiocy... it won't compile without this useless definition.
void fprop(PASS_TYPE passType);
virtual doublev& getCost();
float getCoeff();
bool isGradProducer();
void setSendTerminalMessages(bool send);
static CostLayer& makeCostLayer(ConvNetGPU* convNetGPU, string& type, PyObject* paramsDict);
};
/*
* Input 0: labels
* Input 1: softmax outputs
*/
class CrossEntCostLayer : public CostLayer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
CrossEntCostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
/*
* Input 0: labels
* Input 1: softmax outputs
*/
class LogregCostLayer : public CostLayer {
protected:
NVMatrix _correctProbs, _topkProbs;
NVMatrix _probsAccum;
int _numAccumed;
int _topk;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
LogregCostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
NVMatrix& getProbsAccum();
};
/*
* Input 0: labels
* Input 1: logistic outputs
*/
class CrossEnt2CostLayer : public CostLayer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
CrossEnt2CostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
class CrossEntOperator {
public:
__device__ inline float operator()(const float t, const float y) const {
return t * safelog(y) + (1.0f - t) * safelog(1.0f - y);
}
};
// Only for use with non-logistic units
class CrossEntGradientOperator {
private:
float _coeff;
public:
CrossEntGradientOperator(float coeff) : _coeff(coeff) {
}
__device__ inline float operator()(const float t, const float y) const {
return _coeff * (__fdividef(t, y) + __fdividef(1.0f - t, 1.0f - y));
}
};
};
/*
* Input 0: labels
* Input 1: logistic outputs
*/
class RobustFlickrCost : public CostLayer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
RobustFlickrCost(ConvNetGPU* convNetGPU, PyObject* paramsDict);
class RobustFlickrCostOperator {
public:
__device__ inline float operator()(const float t, const float y) const {
const float d = (y-t) * (y-t);
return __logf(1 + d);// - (t * safelog(y));
}
};
// Only for use with non-logistic units
class RobustFlickrCostGradientOperator {
private:
float _coeff;
public:
RobustFlickrCostGradientOperator(float coeff) : _coeff(coeff) {
}
__device__ inline float operator()(const float t, const float y) const {
const float d = y - t;
return -_coeff * (__fdividef(2.0f * d, 1.0f + d*d) /*- __fdividef(t, y)*/);
}
};
};
class SumOfSquaresCostLayer : public CostLayer {
protected:
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
SumOfSquaresCostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
/*
* Input 0: labels
* Input 1: energies
*/
class MultiSoftmaxCostLayer : public CostLayer {
protected:
NVMatrix _probsT;
Matrix _cpuProbs, _cpuLabels, _energies_T_CPU;
std::vector<Matrix*> B;
int _setSize, _numOut, _threads;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
MultiSoftmaxCostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
void computeCost(bool useEnergies);
};
/*
* input 0: gates
* input 1: what to sum and square
*/
class GatedSumOfSquaresCostLayer : public CostLayer {
protected:
NVMatrix _ungated;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
GatedSumOfSquaresCostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
class TICACostLayer : public CostLayer {
protected:
int _sizeX, _channels;
void fpropActs(int inpIdx, float scaleTargets, PASS_TYPE passType);
void bpropActs(NVMatrix& v, int inpIdx, float scaleTargets, PASS_TYPE passType);
public:
TICACostLayer(ConvNetGPU* convNetGPU, PyObject* paramsDict);
};
#endif /* LAYER_CUH */

65
include/layer_kernels.cuh Normal file
View file

@ -0,0 +1,65 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef LAYER_KERNELS_CUH
#define LAYER_KERNELS_CUH
#include <vector>
#include <helper_cuda.h>
#include <nvmatrix.cuh>
#define LOGREG_GRAD_THREADS_X 32
#define LOGREG_GRAD_THREADS_Y 4
#define LOGREG_ERR_THREADS_X 128
#define LOGREG_ERR_THREADS_Y 1
__device__ inline float safelog(const float x) {
return x > 0.0f ? __logf(x) : -50.0f;
}
void computeCrossEntCost(NVMatrix& labels, NVMatrix& probs, NVMatrix& labelLogProbs_out, NVMatrix& correctProbs_out);
void computeCrossEntGrad(NVMatrix& labels, NVMatrix& probs, NVMatrix& target, bool add, float coeff);
void computeSoftmaxGrad(NVMatrix& acts, NVMatrix& actsGrad, NVMatrix& target, bool add);
void computeLogregCost(NVMatrix& labels, NVMatrix& probs, NVMatrix& labelLogProbs_out, NVMatrix& correctProbs_out);
void computeLogregGrad(NVMatrix& labels, NVMatrix& probs, NVMatrix& target, bool add, float coeff);
// Numerical stability optimization: this routine combines computeLogregGrad with computeSoftmaxGrad
// to avoi dividing and then multiplying by quantities that may be near zero.
void computeCrossEntSoftmaxGrad(NVMatrix& labels, NVMatrix& probs, NVMatrix& target, bool add, float coeff);
void computeLogregSoftmaxGrad(NVMatrix& labels, NVMatrix& probs, NVMatrix& target, bool add, float coeff);
void computeEltwiseMaxGrad(NVMatrix& actGrad, NVMatrix& input, NVMatrix& output, NVMatrix& target, bool add);
void MSMBackward(NVMatrix& energies, NVMatrix& bLattice, int setSize);
void MultiSoftmaxCPU(Matrix& elts, Matrix& B, Matrix& probs, int size, int fixed);
void MultiSoftmaxCPU_T(Matrix& elts, Matrix& B, Matrix& probs, Matrix& fixed, int size);
void computeMultiSoftmaxCost(NVMatrix& labels, NVMatrix& probs, NVMatrix& energies, NVMatrix& labelLogProbs_out,
NVMatrix& correctProbs_out, NVMatrix& top5Probs_out, int setSize, bool useEnergies);
#endif /* LAYER_KERNELS_CUH */

77
include/lr.cuh Normal file
View file

@ -0,0 +1,77 @@
#ifndef LR_CUH
#define LR_CUH
#include <string>
#include <vector>
#include <iostream>
#include <helper_cuda.h>
#include <assert.h>
#include <nvmatrix.cuh>
#include <matrix.h>
#include <util.cuh>
#include <Python.h>
/*
* The maximum learning rate is _baseRate.
* The minimum learning rate is _baseRate / _tgtFactor.
*
* These classes define annealing schedules that interpolate between these
* two extrema.
*/
class LearningRateSchedule {
protected:
double _baseRate, _noiseStdev, _randnSpare;
bool _haveRandnSpare;
virtual double _getRate(double progress);
double randn();
double rand() const;
double abs(double x) const;
public:
LearningRateSchedule(double base);
LearningRateSchedule(double base, double noiseStdev);
double getRate(double progress);
double getBaseRate() const;
virtual ~LearningRateSchedule();
static LearningRateSchedule& make(PyObject* lrsDict, double base);
};
class LinearLRS : public LearningRateSchedule {
protected:
double _finalRate;
public:
LinearLRS(double base, double tgtFactor, double noiseStdev);
virtual double _getRate(double progress);
};
class ExpLRS : public LearningRateSchedule {
protected:
double _pow;
public:
ExpLRS(double baseRate, double tgtFactor, double noiseStdev);
virtual double _getRate(double progress);
};
class TanhLRS : public LearningRateSchedule {
protected:
double _alpha, _beta;
public:
TanhLRS(double baseRate, double tgtFactor, double noiseStdev);
virtual double _getRate(double progress);
};
class DiscreteExpLRS : public LearningRateSchedule {
protected:
std::vector<double> _rates;
public:
DiscreteExpLRS(double baseRate, double tgtFactor, double noiseStdev, int numSteps);
virtual double _getRate(double progress);
};
class JumpyDiscreteExpLRS : public DiscreteExpLRS {
public:
JumpyDiscreteExpLRS(double baseRate, double tgtFactor, double noiseStdev, int numSteps);
virtual double _getRate(double progress);
};
#endif /* LR_CUH */

133
include/messages.cuh Normal file
View file

@ -0,0 +1,133 @@
/*
* messages.cuh
*
* Created on: 2013-02-25
* Author: spoon
*/
#ifndef MESSAGES_CUH_
#define MESSAGES_CUH_
#include <string>
enum MESSAGES { FPROP_TERMINAL,
BPROP_TERMINAL,
BPROP_READY,
FPROP_READY,
SYNC,
COPY_TO_CPU,
COPY_TO_GPU,
UPDATE_WEIGHTS,
RESET,
COST_COMPUTED,
BPROP_START,
// COPY,
// DEQUANTIZE,
RUNME};
class Message {
protected:
MESSAGES _messageType;
public:
MESSAGES getMessageType() {
return _messageType;
}
Message(MESSAGES messageType) : _messageType(messageType) {
}
virtual ~Message() {
}
};
/*
* A message that performs some simple function in its run method.
*/
class RunMeMessage : public Message {
public:
RunMeMessage() : Message(RUNME) {
}
virtual void run() = 0;
virtual ~RunMeMessage() {
}
};
class CopyMessage : public RunMeMessage {
protected:
NVMatrix* _src, *_tgt;
public:
CopyMessage(NVMatrix* src, NVMatrix* tgt) : _src(src), _tgt(tgt), RunMeMessage() {
}
void run() {
_src->copy(*_tgt);
}
~CopyMessage() {
assert(_src->isView());
delete _src;
}
};
class DequantizeMessage : public RunMeMessage {
protected:
Quantizer* _q;
NVMatrix *_tgt;
public:
DequantizeMessage(Quantizer* q, NVMatrix* tgt) : _q(q), _tgt(tgt), RunMeMessage() {
}
void run() {
_q->dequantize(*_tgt);
}
~DequantizeMessage() {
}
};
class PropMessage : public Message {
protected:
std::string _fromLayer, _toLayer;
PASS_TYPE _passType;
public:
std::string& getFromLayer() {
return _fromLayer;
}
std::string& getToLayer() {
return _toLayer;
}
PASS_TYPE getPassType() {
return _passType;
}
PropMessage(std::string fromLayer, std::string toLayer, PASS_TYPE passType, MESSAGES msgType)
: _fromLayer(fromLayer), _toLayer(toLayer), _passType(passType), Message(msgType) {
}
};
class FpropMessage : public PropMessage {
public:
FpropMessage(std::string fromLayer, std::string toLayer, PASS_TYPE passType)
: PropMessage(fromLayer, toLayer, passType, FPROP_READY) {
}
};
class BpropMessage : public PropMessage {
public:
BpropMessage(std::string fromLayer, std::string toLayer, PASS_TYPE passType)
: PropMessage(fromLayer, toLayer, passType, BPROP_READY) {
}
};
class BpropStartMessage : public Message {
protected:
PASS_TYPE _passType;
public:
PASS_TYPE getPassType() {
return _passType;
}
BpropStartMessage(PASS_TYPE passType)
: _passType(passType), Message(BPROP_START) {
}
};
#endif /* MESSAGES_CUH_ */

38
include/multisoftmax.h Normal file
View file

@ -0,0 +1,38 @@
/*
* File: multisoftmax.h
* Author: Alex Krizhevsky
*
* Created on May 9, 2012, 5:36 PM
*/
#ifndef MULTISOFTMAX_H
#define MULTISOFTMAX_H
#include <algorithm>
#include <thread.h>
#include <matrix.h>
#include <vector>
#ifndef DIVUP
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
#endif
#define EXP exp
#define LOG log
#define INF 1e35f
class MultiSoftmaxWorker : public Thread {
protected:
Matrix* _elts, *_B, *_probs, *_fixed;
int _size;
bool _nofix;
void* run();
public:
MultiSoftmaxWorker(Matrix* elts, Matrix* B, Matrix* probs, Matrix* _fixed, int size, bool nofix);
virtual ~MultiSoftmaxWorker();
};
void MultiSoftmaxCPU_T_parallel(Matrix& elts, std::vector<Matrix*>& B, Matrix& probs, Matrix& fixed, int size, bool nofix);
#endif /* MULTISOFTMAX_H */

529
include/neuron.cuh Normal file
View file

@ -0,0 +1,529 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef NEURONS_CUH
#define NEURONS_CUH
#include <assert.h>
#include <string>
#include <nvmatrix.cuh>
#include <helper_cuda.h>
template <class GradientOp>
class AddGradientBinaryOperator {
GradientOp _op;
public:
AddGradientBinaryOperator(GradientOp op) : _op(op) {
}
__device__ inline float operator()(const float unitActGrad, const float unitAct, const float target) const {
return _op(unitActGrad, unitAct) + target;
}
};
template <class GradientOp>
class AddGradientOperator {
GradientOp _op;
public:
AddGradientOperator(GradientOp op) : _op(op) {
}
__device__ inline float operator()(const float unitActGrad, const float target) const {
return target + _op(unitActGrad);
}
};
/* =======================
* Neuron
* -----------------------
*
* f(x) = x
* =======================
*/
class Neuron {
protected:
bool _activated;
// Inputs and outputs potentially point to the same matrix, depending on the neuron
NVMatrix* _inputs, *_outputs;
virtual void _activate() {
if (_inputs != _outputs) {
_inputs->copy(*_outputs);
}
}
virtual void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
if (&target != &actsGrad) {
actsGrad.copy(target);
}
}
virtual void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
if (&target != &actsGrad) {
target.add(actsGrad);
}
}
public:
Neuron() : _activated(false), _inputs(NULL), _outputs(NULL) {
}
virtual void activate(NVMatrix& inputs, NVMatrix& outputs) {
_activated = true;
_inputs = &inputs;
_outputs = &outputs;
_activate();
}
virtual void computeInputGrad(NVMatrix& actsGrad, NVMatrix& target, bool add) {
assert(_activated);
if (!add) {
target.resize(actsGrad);
_computeInputGrad(actsGrad, target);
} else {
_addInputGrad(actsGrad, target);
}
}
static Neuron& makeNeuron(PyObject* neuronDict);
};
/* =======================
* LogisticNeuron
* -----------------------
*
* f(x) = 1 / (1 + e^-x)
* =======================
*/
class LogisticNeuron : public Neuron {
protected:
void _activate() {
_inputs->apply(NVMatrixOps::Logistic(), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(LogisticGradientOperator(), *_outputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<LogisticGradientOperator>(LogisticGradientOperator()), *_outputs, target, target);
}
public:
class LogisticGradientOperator {
public:
__device__ inline float operator()(float unitActGrad, float unitAct) const {
return unitActGrad * unitAct * (1.0f - unitAct);
}
};
LogisticNeuron() : Neuron() {
}
};
/* =======================
* ReluNeuron
* -----------------------
*
* f(x) = max(0, x)
* =======================
*/
class ReluNeuron : public Neuron {
protected:
virtual void _activate() {
_inputs->apply(ReluOperator(), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(ReluGradientOperator(), *_outputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<ReluGradientOperator>(ReluGradientOperator()), *_outputs, target, target);
}
public:
class ReluOperator {
public:
__device__ inline float operator()(float x) const {
return x < 0.0f ? 0.0f : x;
}
};
class ReluGradientOperator {
public:
__device__ inline float operator()(float unitActGrad, float unitAct) const {
return unitActGrad * (unitAct > 0.0f);
}
};
ReluNeuron() : Neuron() {
}
};
/* =======================
* NoisyReluNeuron
* -----------------------
*
* f(x) = max(0, max(0, x) + gaussian noise with variance equal to max(0, x))
* =======================
*/
class NoisyReluNeuron : public ReluNeuron {
protected:
void _activate() {
ReluNeuron::_activate();
_outputs->addGaussianNoise(*_outputs, false);
_outputs->apply(ReluOperator());
}
public:
NoisyReluNeuron() : ReluNeuron() {
}
};
/* =======================
* BoundedReluNeuron
* -----------------------
*
* f(x) = min(a, max(0, x))
* =======================
*/
class BoundedReluNeuron : public Neuron {
protected:
float _a;
void _activate() {
_inputs->apply(BoundedReluOperator(_a), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(BoundedReluGradientOperator(_a), *_outputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<BoundedReluGradientOperator>(BoundedReluGradientOperator(_a)), *_outputs, target, target);
}
public:
class BoundedReluOperator {
private:
float _a;
public:
BoundedReluOperator(float a) : _a(a) {
}
__device__ inline float operator()(float x) const {
return x < 0.0f ? 0.0f : x > _a ? _a : x;
}
};
class BoundedReluGradientOperator {
private:
float _a;
public:
BoundedReluGradientOperator(float a) : _a(a) {
}
__device__ inline float operator()(float unitActGrad, float unitAct) const {
return unitActGrad * (unitAct > 0.0f) * (unitAct < _a);
}
};
BoundedReluNeuron(float a) : Neuron(), _a(a) {
}
};
/* =======================
* AbsNeuron
* -----------------------
*
* f(x) = abs(x)
* =======================
*/
class AbsNeuron : public Neuron {
protected:
void _activate() {
assert(_inputs != _outputs);
_inputs->apply(NVMatrixOps::Abs(), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(AbsGradientOperator(), *_inputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<AbsGradientOperator>(AbsGradientOperator()), *_inputs, target, target);
}
public:
class AbsGradientOperator {
public:
__device__ inline float operator()(float unitActGrad, float unitInput) const {
return unitActGrad * (unitInput > 0.0f ? 1.0f : -1.0f);
}
};
AbsNeuron() : Neuron() {
}
};
/* =======================
* TanhNeuron
* -----------------------
*
* f(x) = a*tanh(b*x)
* =======================
*/
class TanhNeuron : public Neuron {
protected:
float _a, _b;
void _activate() {
_inputs->apply(TanhOperator(_a, _b), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(TanhGradientOperator(_a, _b), *_outputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<TanhGradientOperator>(TanhGradientOperator(_a, _b)), *_outputs, target, target);
}
public:
class TanhOperator {
private:
float _a, _n2b;
public:
TanhOperator(float a, float b) : _a(a), _n2b(-2*b) {
}
virtual __device__ inline float operator()(float x) const {
return _a * (__fdividef(2.0f, 1.0f + __expf(x * _n2b)) - 1.0f);
}
};
class TanhGradientOperator {
private:
float _b, _a;
public:
TanhGradientOperator(float a, float b) : _b(b), _a(a) {
}
__device__ inline float operator()(float unitActGrad, float unitAct) const {
// const float t = (1.0f - __fdividef(unitAct, _a)) / 2.0f;
// return unitActGrad * _n4ab * (t * (t - 1.0f));
return unitActGrad * _b * (_a - __fdividef(unitAct * unitAct, _a));
}
};
TanhNeuron(float a, float b) : Neuron(), _a(a), _b(b) {
}
};
/* =======================
* DoubleReluNeuron
* -----------------------
*
* f(x) = x - a*tanh(x/a)
* =======================
*/
class DoubleReluNeuron : public Neuron {
protected:
float _a;
void _activate() {
assert(_inputs != _outputs);
_inputs->apply(DoubleReluOperator(_a), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(DoubleReluGradientOperator(_a), *_inputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<DoubleReluGradientOperator>(DoubleReluGradientOperator(_a)), *_inputs, target, target);
}
public:
class DoubleReluOperator {
private:
float _a, _n2a;
public:
DoubleReluOperator(float a) : _a(a), _n2a(-2.0f / a) {
}
virtual __device__ inline float operator()(float x) const {
return x - _a * (__fdividef(2.0f, 1.0f + __expf(_n2a * x)) - 1.0f);
}
};
class DoubleReluGradientOperator {
private:
float _n2a;
public:
DoubleReluGradientOperator(float a) : _n2a(-2.0f / a) {
}
__device__ inline float operator()(float unitActGrad, float unitInput) const {
const float tanh = __fdividef(2.0f, 1.0f + __expf(_n2a * unitInput)) - 1.0f;
return unitActGrad * (tanh*tanh);
}
};
DoubleReluNeuron(float a) : Neuron(), _a(a) {
}
};
/* =======================
* SoftReluNeuron
* -----------------------
*
* f(x) = log(1 + e^x)
* =======================
*/
class SoftReluNeuron : public Neuron {
protected:
void _activate() {
assert(_inputs != _outputs);
_inputs->apply(SoftReluOperator(), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(SoftReluGradientOperator(), *_inputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<SoftReluGradientOperator>(SoftReluGradientOperator()), *_inputs, target, target);
}
public:
class SoftReluOperator {
public:
__device__ inline float operator()(float x) const {
// This piece-wise implementation has better numerical stability than
// simply computing log(1 + e^x).
return x > 4.0f ? x : __logf(1.0f + __expf(x));
}
};
class SoftReluGradientOperator {
public:
__device__ inline float operator()(float unitActGrad, float unitInput) const {
if (unitInput > 4.0f) {
return unitActGrad;
}
const float f = __expf(unitInput);
return unitActGrad * __fdividef(f, 1.0f + f);
}
};
SoftReluNeuron() : Neuron() {
}
};
/* =======================
* SquareNeuron
* -----------------------
*
* f(x) = x^2
* =======================
*/
class SquareNeuron : public Neuron {
protected:
void _activate() {
assert(_inputs != _outputs);
_inputs->apply(NVMatrixOps::Square(), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(SquareGradientOperator(), *_inputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<SquareGradientOperator>(SquareGradientOperator()), *_inputs, target, target);
}
public:
class SquareGradientOperator {
public:
__device__ inline float operator()(float unitActGrad, float unitInput) const {
return unitActGrad * 2.0f * unitInput;
}
};
SquareNeuron() : Neuron() {
}
};
/* =======================
* SqrtNeuron
* -----------------------
*
* f(x) = sqrt(x)
* =======================
*/
class SqrtNeuron : public Neuron {
protected:
void _activate() {
_inputs->apply(NVMatrixOps::Sqrt(), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(SqrtGradientOperator(), *_outputs, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyTernary(AddGradientBinaryOperator<SqrtGradientOperator>(SqrtGradientOperator()), *_outputs, target, target);
}
public:
class SqrtGradientOperator {
public:
__device__ inline float operator()(float unitActGrad, float unitAct) const {
return __fdividef(unitActGrad, 2.0f * unitAct);
}
};
SqrtNeuron() : Neuron() {
}
};
/* =======================
* LinearNeuron
* -----------------------
*
* f(x) = a*x + b
* =======================
*/
class LinearNeuron : public Neuron {
protected:
float _a, _b;
void _activate() {
_inputs->apply(LinearOperator(_a, _b), *_outputs);
}
void _computeInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.scale(_a, target);
}
void _addInputGrad(NVMatrix& actsGrad, NVMatrix& target) {
actsGrad.applyBinary(AddGradientOperator<NVMatrixOps::MultByScalar>(NVMatrixOps::MultByScalar(_a)), target, target);
}
public:
class LinearOperator {
protected:
float _a, _b;
public:
__device__ inline float operator()(float x) const {
return _a * x + _b;
}
LinearOperator(float a, float b) : _a(a), _b(b) {
}
};
LinearNeuron(float a, float b) : Neuron(), _a(a), _b(b) {
}
};
#endif /* NEURONS_CUH */

139
include/pipedispenser.cuh Normal file
View file

@ -0,0 +1,139 @@
/*
* pipedispenser.cuh
*
* Created on: 2013-03-01
* Author: spoon
*/
#ifndef PIPEDISPENSER_CUH_
#define PIPEDISPENSER_CUH_
#include <pthread.h>
#include <set>
#include <algorithm>
#include <iterator>
#include <util.cuh>
class PipeDispenser {
protected:
int _numPipes;
seti _pipes;
pthread_mutex_t *_mutex;
void lock() {
pthread_mutex_lock(_mutex);
}
void unlock() {
pthread_mutex_unlock(_mutex);
}
public:
PipeDispenser(const seti& pipes) {
_pipes.insert(pipes.begin(), pipes.end());
_mutex = (pthread_mutex_t*)(malloc(sizeof (pthread_mutex_t)));
pthread_mutex_init(_mutex, NULL);
}
virtual ~PipeDispenser() {
pthread_mutex_destroy(_mutex);
free(_mutex);
}
virtual int getPipe(const seti& interested) = 0;
int getPipe(int interested) {
seti tmp;
tmp.insert(interested);
return getPipe(tmp);
}
virtual void freePipe(int pipe) = 0;
};
/*
* This one blocks until there is a free pipe to return.
*/
class PipeDispenserBlocking : public PipeDispenser {
protected:
pthread_cond_t *_cv;
void wait() {
pthread_cond_wait(_cv, _mutex);
}
void broadcast() {
pthread_cond_broadcast(_cv);
}
int getAvailablePipes(const seti& interested, intv& available) {
available.clear();
std::set_intersection(_pipes.begin(), _pipes.end(), interested.begin(), interested.end(), std::back_inserter(available));
return available.size();
}
public:
PipeDispenserBlocking(const seti& pipes) : PipeDispenser(pipes) {
_cv = (pthread_cond_t*)(malloc(sizeof (pthread_cond_t)));
pthread_cond_init(_cv, NULL);
}
~PipeDispenserBlocking() {
pthread_cond_destroy(_cv);
free(_cv);
}
int getPipe(const seti& interested) {
lock();
intv avail;
while (getAvailablePipes(interested, avail) == 0) {
wait();
}
int pipe = avail[0];
_pipes.erase(pipe);
unlock();
return pipe;
}
void freePipe(int pipe) {
lock();
_pipes.insert(pipe);
broadcast();
unlock();
}
};
/*
* This one returns the least-occupied pipe.
*/
class PipeDispenserNonBlocking : public PipeDispenser {
protected:
std::map<int,int> _pipeUsers;
public:
PipeDispenserNonBlocking(const seti& pipes) : PipeDispenser(pipes) {
for (seti::iterator it = pipes.begin(); it != pipes.end(); ++it) {
_pipeUsers[*it] = 0;
}
}
int getPipe(const seti& interested) {
lock();
int pipe = -1, users = 1 << 30;
for (seti::iterator it = _pipes.begin(); it != _pipes.end(); ++it) {
if (interested.count(*it) > 0 && _pipeUsers[*it] < users) {
pipe = *it;
users = _pipeUsers[*it];
}
}
if (pipe >= 0) {
_pipeUsers[pipe]++;
}
unlock();
return pipe;
}
void freePipe(int pipe) {
lock();
_pipeUsers[pipe]--;
unlock();
}
};
#endif /* PIPEDISPENSER_CUH_ */

43
include/pyconvnet.cuh Normal file
View file

@ -0,0 +1,43 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef PYCONVNET3_CUH
#define PYCONVNET3_CUH
#define _QUOTEME(x) #x
#define QUOTEME(x) _QUOTEME(x)
extern "C" void INITNAME();
PyObject* initModel(PyObject *self, PyObject *args);
PyObject* startBatch(PyObject *self, PyObject *args);
PyObject* finishBatch(PyObject *self, PyObject *args);
PyObject* checkGradients(PyObject *self, PyObject *args);
PyObject* syncWithHost(PyObject *self, PyObject *args);
PyObject* startMultiviewTest(PyObject *self, PyObject *args);
PyObject* startFeatureWriter(PyObject *self, PyObject *args);
PyObject* startDataGrad(PyObject *self, PyObject *args);
#endif

43
include/quantizer.cuh Normal file
View file

@ -0,0 +1,43 @@
/*
* quantizer.cuh
*
* Created on: 2013-02-15
* Author: spoon
*/
#ifndef QUANTIZER_CUH_
#define QUANTIZER_CUH_
#include <Python.h>
#include <util.cuh>
#include <string>
#include <nvmatrix.cuh>
#include <conv_util.cuh>
class Quantizer {
protected:
NVMatrix* _quantized;
int _numRows, _numCols;
bool _trans;
virtual void _quantize(NVMatrix& src, NVMatrix& tgt);
virtual void _dequantize(NVMatrix& tgt, float scaleTarget, float scaleOutput);
public:
Quantizer();
virtual ~Quantizer();
void quantize(NVMatrix& src, NVMatrix& tgt);
void dequantize(NVMatrix& tgt);
void dequantize(NVMatrix& tgt, float scaleTarget, float scaleOutput);
static Quantizer& make(PyObject* qDict);
};
class HalfQuantizer : public Quantizer {
protected:
void _quantize(NVMatrix& src, NVMatrix& tgt);
void _dequantize(NVMatrix& tgt, float scaleTarget, float scaleOutput);
public:
HalfQuantizer();
};
#endif /* QUANTIZER_CUH_ */

144
include/softmaxtree.cuh Normal file
View file

@ -0,0 +1,144 @@
/*
* File: softmaxtree.h
* Author: Alex Krizhevsky
*
* Created on September 9, 2012, 5:50 PM
*/
#ifndef SOFTMAXTREE_H
#define SOFTMAXTREE_H
#include <helper_cuda.h>
#include <string>
#include <map>
#include <vector>
#include <algorithm>
#include <assert.h>
#include <nvmatrix.cuh>
#include <matrix.h>
class SoftmaxNode;
class SoftmaxTree;
typedef std::vector<SoftmaxNode*> SoftmaxNodeV;
class SoftmaxNode {
friend class SoftmaxTree;
protected:
SoftmaxNodeV _children;
SoftmaxNode* _parent;
int _depth, _height, _size;
int _label;
/*
* Computes height for entire subtree rooted at this node and populates
* given height->nodes map.
*/
int setDistances(std::map<int, SoftmaxNodeV*>& nodeHeights,
std::map<int, SoftmaxNodeV*>& nodeDepths);
void setNodeCounts(int &nodes, int& leaves);
/*
* Compute the number of leaves in this subtree, which is a good estimate
* of the number of training cases it represents.
*/
int setSizes(ushort* nodeSizes);
public:
SoftmaxNode(SoftmaxNode* parent, int label);
~SoftmaxNode();
SoftmaxNode& addChild(int label);
int getDepth() const;
int getHeight() const;
int getLabel() const;
int getSize() const;
SoftmaxNode* getParent(); // Might be null, so must be pointer
SoftmaxNodeV& getChildren();
};
/*
* numLabels: the number of leaves in the tree (normally 1000)
* numNodes: the total number of nodes in the tree
*/
class SoftmaxTree {
friend class SoftmaxNode;
protected:
SoftmaxNode* _root;
std::map<int, SoftmaxNodeV*> _nodeHeights, _nodeDepths;
/*
* Map from depth --> ushort2[]
* where each ushort2 gives the index and parent index
* of a node at the given depth.
*/
std::map<int, ushort2*> _nodeFwdMeta;
/*
* Map from height --> ushort2[]
* where each ushort2 gives the index and number of children
* of a node at the given height.
*/
std::map<int, ushort2*> _nodeBwdMeta;
/*
* Map from height --> ushort[][]
* where each ushort[] gives children of a given node at a given height.
*/
std::map<int, ushort**> _nodeChildMeta;
/*
* An array of length numNodes with index i storing the number
* of leaves in subtree rooted at node with label i.
*/
ushort* _nodeSizes;
int _numNodes, _numLeaves;
void setDistances();
void setNodeCounts();
void setNodeSizes();
void setFwdMeta();
void setBwdMeta();
void preprocess(NVMatrix& inp);
void postprocess(NVMatrix& inp);
public:
SoftmaxTree(int rootLabel);
~SoftmaxTree();
void finalize();
SoftmaxNode& getRoot();
SoftmaxNodeV& getNodesAtHeight(int height);
SoftmaxNodeV& getNodesAtDepth(int depth);
int getHeight() const;
int getDepth() const;
int getNumLeaves() const;
int getNumNodes() const;
/*
* offsets: (numNodes, numFeatures)
* targets: (numNodes, numFeatures)
*/
void makeWeights(NVMatrix& offsets, NVMatrix& targets);
/*
* grads: (numNodes, numFeatures)
*
* The idea is that grads contains gradients for the leaves
* (i.e. the first numLabels rows), so this routine will
* distribute them up the tree.
*/
void distributeGradients(NVMatrix& grads);
/*
* inc := mom * inc - wc * epsW * weight + epsW * grad
* weight := weight + inc
*
* weights: (numNodes, numFeatures)
* incs: (numNodes, numFeatures)
* grads: (numNodes , numFeatures)
*/
void updateWeights(NVMatrix& weights, NVMatrix& incs, NVMatrix& grads, float epsWBase, float mom, float wcBase);
};
#endif /* SOFTMAXTREE_H */

113
include/util.cuh Normal file
View file

@ -0,0 +1,113 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef UTIL_H
#define UTIL_H
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <vector>
#include <map>
#include <set>
#include <string>
#include <sstream>
#include <string>
#include <Python.h>
#include <nvmatrix.cuh>
#include <matrix.h>
/*
* The types of passes that the convnet supports. Used in the fprop and bprop functions in
* ConvNet class. Most of the layers ignore the pass type, but some make use of it.
*/
//enum PASS_TYPE {PASS_TRAIN,
// PASS_TEST,
// PASS_GC,
// PASS_MULTIVIEW_TEST,
// PASS_MULTIVIEW_TEST_START,
// PASS_MULTIVIEW_TEST_END,
// PASS_FEATURE_GEN};
#define PASS_TYPE uint
#define PASS_TRAIN 0x1
#define PASS_TEST 0x2
#define PASS_GC 0x4
#define PASS_MULTIVIEW_TEST (PASS_TEST | 0x8)
#define PASS_MULTIVIEW_TEST_START (PASS_MULTIVIEW_TEST | 0x10)
#define PASS_MULTIVIEW_TEST_END (PASS_MULTIVIEW_TEST | 0x20)
#define PASS_FEATURE_GEN 0x40
#define HAS_FLAG(f, x) (((x) & (f)) == (f))
#define IS_MULTIVIEW_TEST(x) HAS_FLAG(PASS_MULTIVIEW_TEST, x)
#define IS_MULTIVIEW_TEST_START(x) HAS_FLAG(PASS_MULTIVIEW_TEST_START, x)
#define IS_MULTIVIEW_TEST_END(x) HAS_FLAG(PASS_MULTIVIEW_TEST_END, x)
// For gradient checking
#define GC_SUPPRESS_PASSES false
#define GC_REL_ERR_THRESH 0.02
/*
* Generates a random floating point number in the range 0-1.
*/
#define randf ((float)rand() / RAND_MAX)
typedef std::vector<Matrix*> MatrixV;
typedef std::vector<NVMatrix*> NVMatrixV;
typedef std::map<std::string,std::vector<double>*> CostMap;
typedef std::map<std::string,double> CostCoeffMap;
typedef std::vector<double> doublev;
typedef std::vector<float> floatv;
typedef std::vector<int> intv;
typedef std::vector<std::string> stringv;
typedef std::set<int> seti;
stringv* getStringV(PyObject* pyList);
floatv* getFloatV(PyObject* pyList);
intv* getIntV(PyObject* pyList);
MatrixV* getMatrixV(PyObject* pyList);
MatrixV* getMatrixV(PyObject* pyList, int len);
int* getIntA(PyObject* pyList);
int pyDictGetInt(PyObject* dict, const char* key);
intv* pyDictGetIntV(PyObject* dict, const char* key);
std::string pyDictGetString(PyObject* dict, const char* key);
float pyDictGetFloat(PyObject* dict, const char* key);
floatv* pyDictGetFloatV(PyObject* dict, const char* key);
Matrix* pyDictGetMatrix(PyObject* dict, const char* key);
MatrixV* pyDictGetMatrixV(PyObject* dict, const char* key);
int* pyDictGetIntA(PyObject* dict, const char* key);
stringv* pyDictGetStringV(PyObject* dict, const char* key);
template<typename T>
std::string tostr(T n) {
std::ostringstream result;
result << n;
return result.str();
}
#endif /* UTIL_H */

150
include/weights.cuh Normal file
View file

@ -0,0 +1,150 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef WEIGHTS_CUH
#define WEIGHTS_CUH
#include <string>
#include <vector>
#include <iostream>
#include <helper_cuda.h>
#include <assert.h>
#include <nvmatrix.cuh>
#include <matrix.h>
#include "util.cuh"
#include "softmaxtree.cuh"
#include <lr.cuh>
using namespace std;
class Weights {
protected:
Matrix* _hWeights, *_hWeightsInc;
NVMatrix* _weights, *_weightsInc, *_weightsGrad;
NVMatrix* _weightsGradAvg, *_weightsGrad2Avg;
LearningRateSchedule* _lrs;
float _wc, _mom, _wball, _superEps;
bool _onGPU, _useGrad, _cleanup;
int _numUpdates;
// Non-NULL if these weights are really shared from some other layer
Weights* _srcWeights;
public:
class Grad2AvgOperator {
private:
float _mom;
public:
Grad2AvgOperator(float mom) : _mom(mom) {
}
__device__ inline float operator()(const float G2, const float g) const {
return _mom * G2 + (1.0f - _mom) * g * g;
}
};
NVMatrix& operator*() const;
Weights(Weights& srcWeights, LearningRateSchedule& lrs);
Weights(Matrix& hWeights, Matrix& hWeightsInc, LearningRateSchedule& lrs, float wc, float wball, float mom, float superEps, bool useGrad, bool cleanup=true);
virtual ~Weights();
virtual NVMatrix& getW() const;
virtual NVMatrix& getInc() const;
virtual NVMatrix& getGrad() const;
virtual Matrix& getCPUW() const;
virtual Matrix& getCPUWInc() const;
virtual LearningRateSchedule& getLearningRateSchedule() const;
virtual int getNumRows() const;
virtual int getNumCols() const;
virtual void copyToCPU();
// This function is assumed to be called in the order in which the layers
// were defined
virtual void copyToGPU();
virtual void update(float progress);
int incNumUpdates();
// Returns the number of times a gradient has been computed for this
// weight matrix during the current pass (interval between two calls of update())
// through the net. This number will only be greater than 1 if this weight matrix
// is *shared* by multiple layers in the net.
int getNumUpdates() const;
float getEps(float progress) const;
float getMom() const;
float getWC() const;
float getWBall() const;
bool isUseGrad() const;
bool isOwner() const;
float getSuperEps() const;
};
class TreeWeights : public Weights {
protected:
NVMatrix _effWeights;
NVMatrix* _leafWeights, *_leafGrad, *_leafInc;
SoftmaxTree* _tree;
public:
void copyToGPU();
void update(float progress);
NVMatrix& getW() const;
NVMatrix& getInc() const;
NVMatrix& getGrad() const;
NVMatrix& getAllW() const;
NVMatrix& getAllInc() const;
NVMatrix& getAllGrad() const;
int getNumRows() const;
void makeWeights();
void distributeGradients();
TreeWeights(SoftmaxTree& tree, Matrix& hWeights, Matrix& hWeightsInc, LearningRateSchedule& lrs, float wcBase, float mom);
};
class DummyWeights : public Weights {
public:
DummyWeights(Matrix& hWeights, Matrix& hWeightsInc, NVMatrix& weights, NVMatrix& incs, NVMatrix& grads);
};
class WeightList {
private:
std::vector<Weights*> _weightList;
public:
Weights& operator[](const int idx) const;
~WeightList();
WeightList();
void addWeights(Weights& w);
void update(float progress);
void copyToCPU();
void copyToGPU();
int getSize() const;
};
#endif /* WEIGHTS_CUH */

122
include/worker.cuh Normal file
View file

@ -0,0 +1,122 @@
/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef WORKER_CUH
#define WORKER_CUH
#include "convnet.cuh"
#include "cost.cuh"
#include "data.cuh"
class ConvNet;
class Cost;
class WorkResult {
public:
enum RESULTS {BATCH_DONE, SYNC_DONE};
protected:
WorkResult::RESULTS _resultType;
Cost* _results;
public:
WorkResult(WorkResult::RESULTS resultType, Cost& results);
WorkResult(WorkResult::RESULTS resultType);
virtual ~WorkResult();
Cost& getResults() const;
WorkResult::RESULTS getResultType() const;
};
class Worker {
protected:
ConvNet* _convNet;
public:
Worker(ConvNet& convNet);
virtual void run() = 0;
};
class DataWorker : public Worker {
protected:
CPUData* _data;
DataProvider* _dp;
public:
DataWorker(ConvNet& convNet, CPUData& data);
virtual ~DataWorker();
};
class TrainingWorker : public DataWorker {
protected:
bool _test;
double _progress;
public:
TrainingWorker(ConvNet& convNet, CPUData& data, double progress, bool test);
void run();
};
class SyncWorker : public Worker {
public:
SyncWorker(ConvNet& convNet);
void run();
};
class GradCheckWorker : public DataWorker {
public:
GradCheckWorker(ConvNet& convNet, CPUData& data);
void run();
};
class MultiviewTestWorker : public DataWorker {
protected:
int _numViews;
Matrix* _cpuProbs;
std::string _logregName;
public:
MultiviewTestWorker(ConvNet& convNet, CPUData& data, int numViews, Matrix& cpuProbs, const char* softmaxName);
MultiviewTestWorker(ConvNet& convNet, CPUData& data, int numViews);
~MultiviewTestWorker();
virtual void run();
};
class FeatureWorker : public DataWorker {
protected:
MatrixV *_ftrs;
stringv *_layerNames;
public:
FeatureWorker(ConvNet& convNet, CPUData& data, MatrixV& ftrs, stringv& layerNames);
~FeatureWorker();
void run();
};
class DataGradWorker : public DataWorker {
protected:
Matrix* _dataGrads;
int _dataLayerIdx, _softmaxLayerIdx;
public:
DataGradWorker(ConvNet& convNet, CPUData& data, Matrix& dataGrads, int dataLayerIdx, int softmaxLayerIdx);
~DataGradWorker();
void run();
};
#endif /* WORKER_CUH */