Caffe源码学习笔记1:tools/caffe.cpp

Caffe源码学习笔记1:tools/caffe.cpp

caffe-master/Tools文件夹下提供了caffe框架的主要工具(经编译后为可执行文件,在build/tools/下)。

tools/caffe.cpp是caffe程序的入口(即main函数),一条标准的训练指令为:

./build/tools/caffe train --solver=models/bvlc_reference_caffenet/solver.prototxt

首先是caffe指令,可执行指令,train为caffe指令第一条参数,然后是指定solver文件。

我们对照着该标准指令一步一步来“解析”,caffe.cpp中main函数代码如下:

int main(int argc, char** argv) {

// gflags库,具体说明紧接代码(未找到其定义,估计在gflags库文件中定义)

FLAGS_alsologtostderr = 1;

// gflags库中为main函数设置usage信息:extern void SetUsageMessage(const std::string& usage);

gflags::SetUsageMessage("command line brew\n"

"usage: caffe \n\n"

"commands:\n"

" train train or finetune a model\n"

" test score a model\n"

" device_query show GPU diagnostic information\n"

" time benchmark model execution time");

// include/caffe/commom.hpp中声明的函数:Currently it initializes google flags and google logging.即初始化FLAGS.

caffe::GlobalInit(&argc, &argv);

// 判断参数,参数为2,继续执行action函数,否则输出usage信息。

if (argc == 2) {

#ifdef WITH_PYTHON_LAYER

try {

#endif

// GetBrewFunction函数返回函数指针,对于上面标准指令,则返回train函数指针,这里先不具体讲解。

return GetBrewFunction(caffe::string(argv[1]))();

#ifdef WITH_PYTHON_LAYER

} catch (bp::error_already_set) {

PyErr_Print();

return 1;

}

#endif

} else {

// glags中为main函数提供usage信息:

// extern void ShowUsageWithFlags(const char *argv0); // what --help does

// extern void ShowUsageWithFlagsRestrict(const char *argv0, const char *restrict);

// 其信息中会有“tools/caffe.cpp”中FLAG信息,如:-gpu,-weights,-solver,-snapshot,-model...

gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");

}

}

gflags是google的一个开源的处理命令行参数的库。在使用命令行参数的文件文件中(源文件或头文件),首先使用一下定义语句进行变量的定义。DEFINE_int32,DEFINE_int64,DEFINE_bool,DEFINE_double,DEFINE_string等,语法为:DEFINE_int32(name, default_value, "description")。接着你就可以使用FLAGS_name变量了,这些变量的值则是由命令行参数传递,无则为默认值,在其他代码文件中若想用该命令参数,可以用DECLARE_int32(name)声明(name为int32类型,也可以使用其他支持的类型)。在caffe.cpp中有很多FLAGS_name定义,如DEFINE_string(gpu,"","some description"),则命令行后-gpu 0,表示FLAGS_gpu=0,默认值为空。

在main函数中出现了GetBrewFunction函数,即在标准指令下,main函数将执行GetBrewFunction函数。首先看看caffe.cpp中一些重要代码:

// 这里定义函数指针类型BrewFunction

typedef int (*BrewFunction)();

// c++标准map容器,caffe执行的action name与对应函数的映射,容器类型名为BrewMap

typedef std::map BrewMap;

// 声明map容器变量g_brew_map

BrewMap g_brew_map;

// 宏定义,比如RegisterBrewFunction(train)时,相当于在容器g_brew_map中注册了train函数的函数指针和其对应的名字“train”,对于#和##的用法见下文。

#define RegisterBrewFunction(func) \

namespace { \

class __Registerer_##func { \

public: \

__Registerer_##func() { \

g_brew_map[#func] = &func; \

} \

}; \

__Registerer_##func g_registerer_##func; \

}

C++中#和##用法:在C/C++的宏中,”#”的功能是将其后面的宏参数进行字符串化操作(Stringfication),简单说就是在对它所引用的宏变量通过替换后在其左右各加上一个双引号。”##”被称为连接符(concatenator),用来将两个子串Token连接为一个Token。注意这里连接的对象是Token就行,而不一定是宏的变量。

凡是宏定义里有用’#’或’##’的地方宏参数是不会再展开。若要使’#’和’##’的宏参数被展开,可以加多一层中间转换宏。

在caffe.cpp中定义了一些BrewFunction类的函数,通过RegisterBrewFunction(function)注册进容器g_brew_map:

int device_query() :用来查询GPU信息

int train():训练神经网络

int test() :测试神经网络

int time():测试model执行时间

GetBrewFunction函数通过caffe命令后第一个参数在g_brew_map容器中查找对应函数指针并返回。代码如下:

static BrewFunction GetBrewFunction(const caffe::string& name) {

if (g_brew_map.count(name)) {

return g_brew_map[name];

} else {

LOG(ERROR) << "Available caffe actions:";

for (BrewMap::iterator it = g_brew_map.begin();

it != g_brew_map.end(); ++it) {

LOG(ERROR) << "\t" << it->first;

}

LOG(FATAL) << "Unknown action: " << name;

return NULL; // not reachable, just to suppress old compiler warnings.

}

}

//代码中LOG来源于google的glog库,控制程序的日志输出消息和测试消息(根据不同的lever输出消息)。

最后是执行相应的函数,如执行train函数,执行成功则返回0,main函数返回0.(caffe执行完毕)

最后看看caffe.cpp中train函数(也是caffe框架的关键),代码如下:

int train() {

// google的glog库,检查--solver、--snapshot和--weight并输出消息;必须有指定solver,snapshot和weight两者指定其一;

CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";

CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())

<< "Give a snapshot to resume training or weights to finetune "

"but not both.";

// 实例化SolverParameter类,该类保存solver参数和相应的方法(SoverParameter是由google protobuffer编译过来的类,具体声明可以见代码文件build/src/caffe/proto/caffe.pb.h);

caffe::SolverParameter solver_param;

// 将-solver指定solver.prototxt文件内容解析到solver_param中,该函数声明在include/caffe/util/upgrade_proto.hpp中,实现在src/caffe/util/upgrade_proto.cpp中;

caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

// 根据命令参数-gpu或者solver.prototxt提供的信息设置GPU;

if (FLAGS_gpu.size() == 0

&& solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {

if (solver_param.has_device_id()) {

FLAGS_gpu = "" +

boost::lexical_cast(solver_param.device_id());

} else { // Set default GPU if unspecified

// boost::lexical_cast(0)是将数值0转换为字符串\'“0”;

FLAGS_gpu = "" + boost::lexical_cast(0);

}

}

// 多GPU下,将GPU编号存入vector容器中(get_gpus()函数通过FLAGS_gpu获取);

vector gpus;

get_gpus(&gpus);

if (gpus.size() == 0) {

LOG(INFO) << "Use CPU.";

Caffe::set_mode(Caffe::CPU);

} else {

ostringstream s;

for (int i = 0; i < gpus.size(); ++i) {

s << (i ? ", " : "") << gpus[i];

}

LOG(INFO) << "Using GPUs " << s.str();

solver_param.set_device_id(gpus[0]);

Caffe::SetDevice(gpus[0]);

Caffe::set_mode(Caffe::GPU);

Caffe::set_solver_count(gpus.size());

}

// 处理snapshot, stop or none信号,其声明在include/caffe/util/signal_Handler.h中;

// GetRequestedAction在caffe.cpp中,将‘stop’,‘snapshot’,‘none’转换为标准信号,即解析;

caffe::SignalHandler signal_handler(

GetRequestedAction(FLAGS_sigint_effect),

GetRequestedAction(FLAGS_sighup_effect));

// 声明boost库中智能指针solver,指向caffe::Solver对象,该对象由CreateSolver创建,后续细讲;

shared_ptr >

solver(caffe::SolverRegistry::CreateSolver(solver_param));

// Solver对象中方法的使用

solver->SetActionFunction(signal_handler.GetActionFunction());

// 从snapshot或caffemodel中恢复train;

if (FLAGS_snapshot.size()) {

LOG(INFO) << "Resuming from " << FLAGS_snapshot;

solver->Restore(FLAGS_snapshot.c_str());

} else if (FLAGS_weights.size()) {

CopyLayers(solver.get(), FLAGS_weights);

}

if (gpus.size() > 1) {

// 这里是对于多GPU下的处理,我们暂时不去深究了;

caffe::P2PSync sync(solver, NULL, solver->param());

sync.run(gpus);

} else {

LOG(INFO) << "Starting Optimization";

// 初始化完成,开始优化网络(核心,重要);

solver->Solve();

}

LOG(INFO) << "Optimization Done.";

return 0;

}

这里,本代码解析基本完成,总结一下程序运行流程:

main()函数--->>GetBrewFunction函数--->>train函数--->>Solve()

接下来,CreateSolver函数和Solver类是需要弄清楚的。