AlexNet/include/worker.cuh
Laurent El Shafey 9fdd561586 Initial commit
2024-12-10 08:56:11 -08:00

122 lines
3.5 KiB
Text

/*
* Copyright (c) 2011, Alex Krizhevsky (akrizhevsky@gmail.com)
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification,
* are permitted provided that the following conditions are met:
*
* - Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* - Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
* EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#ifndef WORKER_CUH
#define WORKER_CUH
#include "convnet.cuh"
#include "cost.cuh"
#include "data.cuh"
class ConvNet;
class Cost;
class WorkResult {
public:
enum RESULTS {BATCH_DONE, SYNC_DONE};
protected:
WorkResult::RESULTS _resultType;
Cost* _results;
public:
WorkResult(WorkResult::RESULTS resultType, Cost& results);
WorkResult(WorkResult::RESULTS resultType);
virtual ~WorkResult();
Cost& getResults() const;
WorkResult::RESULTS getResultType() const;
};
class Worker {
protected:
ConvNet* _convNet;
public:
Worker(ConvNet& convNet);
virtual void run() = 0;
};
class DataWorker : public Worker {
protected:
CPUData* _data;
DataProvider* _dp;
public:
DataWorker(ConvNet& convNet, CPUData& data);
virtual ~DataWorker();
};
class TrainingWorker : public DataWorker {
protected:
bool _test;
double _progress;
public:
TrainingWorker(ConvNet& convNet, CPUData& data, double progress, bool test);
void run();
};
class SyncWorker : public Worker {
public:
SyncWorker(ConvNet& convNet);
void run();
};
class GradCheckWorker : public DataWorker {
public:
GradCheckWorker(ConvNet& convNet, CPUData& data);
void run();
};
class MultiviewTestWorker : public DataWorker {
protected:
int _numViews;
Matrix* _cpuProbs;
std::string _logregName;
public:
MultiviewTestWorker(ConvNet& convNet, CPUData& data, int numViews, Matrix& cpuProbs, const char* softmaxName);
MultiviewTestWorker(ConvNet& convNet, CPUData& data, int numViews);
~MultiviewTestWorker();
virtual void run();
};
class FeatureWorker : public DataWorker {
protected:
MatrixV *_ftrs;
stringv *_layerNames;
public:
FeatureWorker(ConvNet& convNet, CPUData& data, MatrixV& ftrs, stringv& layerNames);
~FeatureWorker();
void run();
};
class DataGradWorker : public DataWorker {
protected:
Matrix* _dataGrads;
int _dataLayerIdx, _softmaxLayerIdx;
public:
DataGradWorker(ConvNet& convNet, CPUData& data, Matrix& dataGrads, int dataLayerIdx, int softmaxLayerIdx);
~DataGradWorker();
void run();
};
#endif /* WORKER_CUH */