四、Tensorflow的分布式训练
TensorFlow中的集群(cluster)指的是一系列能够针对图(Graph)进行分布式计算任务(task)。每个任务是同服务(server)相关联的。TensorFlow中的服务会包含一个用于创建session的主节点和至少一个用于图运算的工作节点,一个集群可以被拆分为一个活着多个作业(job),每个作业可以包含至少一个任务。
以下的例子是一个最简单的例子
1、服务端代码:
import tensorflow as tf \'\'\' 运行命令: python tensf_server_01 --job_name=ps --task_index=0 python tensf_server_01 --job_name=ps --task_index=0 python tensf_server_01 --job_name=work --task_index=0 python tensf_server_01 --job_name=work --task_index=1 python tensf_server_01 --job_name=work --task_index=2 \'\'\' #1、配置服务器相关信息 #因为tensorflow底层代码中,默认就是使用ps和work分别表示两类不同的工作节点 #ps:变量/张量的初始化,存储相关节点 #work:变量/张量的计算/运算的相关节点 ps_host = [\'127.0.0.1:33331\',\'127.0.0.1:33332\'] work_hosts = [\'127.0.0.1:33333\',\'127.0.0.1:33334\',\'127.0.0.1:33335\'] cluster = tf.train.ClusterSpec({\'ps\':ps_host,\'work\':work_hosts}) #2、定义一些运行参数(在运行该python文件的时候就可以制定这些参数了) tf.app.flags.DEFINE_string(\'job_name\',default_value=\'work\',docstring="One of \'ps\' or \'work\'") tf.app.flags.DEFINE_integer(\'task_index\',default_value=0,docstring="Index of task within the job") FLAGS = tf.app.flags.FLAGS #2、启动服务 #_下划线表示占位符 def main(_): print(FLAGS.job_name) server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index) server.join() if __name__ == \'__main__\': #底层默认会调用main方法 tf.app.run()
2、client端的代码:
import tensorflow as tf import numpy as np #1、构建图 #表示使用ps的job,task:0表示使用第一个配置,也就是127.0.0.1:33331 with tf.device(\'/job:ps/task:0\'): #构造函数 x = tf.constant(np.random.rand(100).astype(np.float32)) with tf.device(\'/job:ps/task:1\'): y = y = x * 0.2 +0.3 #2、运行 with tf.Session(target=\'grpc://127.0.0.1:33335\', config=tf.ConfigProto(log_device_placement=True)) as sess: sess.run(y)
- 上一篇 »Java 分布式通信的几种方式及其特点?
- 下一篇 »Java 远程调用与分布式通信的区别