Purpose
TensorFlow's single-machine multi-GPU mode is easy to use by following the official tutorial, but if you want to deploy a distributed TensorFlow training job across multiple nodes of a cluster, there is no good official example. The naive approach is to designate ps nodes and worker nodes by hand and run the corresponding program on each node separately to get cooperative multi-machine training. For a cluster with many nodes this is very inconvenient. This article builds on the Slurm cluster resource manager to dispatch a distributed TensorFlow training job automatically.
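With Slurm, a single submission script can launch the same Python program on every allocated node, and each process then works out its own role from the Slurm environment variables. A minimal submission script might look like the sketch below (the script name `train_distributed.py` and the node count are placeholders, not from the original):

```shell
#!/bin/bash
#SBATCH --job-name=tf-dist
#SBATCH --nodes=5            # e.g. 3 ps nodes + 2 worker nodes
#SBATCH --ntasks-per-node=1  # one TensorFlow process per node

# srun starts one copy of the training script on every allocated node;
# each copy reads SLURM_JOB_NODELIST / SLURMD_NODENAME to pick its role.
srun python train_distributed.py
```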
Implementation
import os
import re
import sys

# Define a function that reads the cluster resources Slurm allocated to a submitted job.
# It takes two parameters: ps_number is the number of parameter-server nodes to use; all
# remaining nodes become worker nodes by default.
# A ps node could also run a worker, but to avoid port conflicts we do not do that here.
# port_number is the port used for inter-node communication in this job. If a ps node also
# launched a worker process, each worker process on that node would need its own port; for
# simplicity we assume num_nodes > 1 and never place a worker and a ps on the same node.
def tf_config_from_slurm(ps_number, port_number=2222):
    """
    Creates the configuration for a distributed TensorFlow session
    from the environment variables provided by the Slurm cluster
    management system.
    @param: ps_number number of parameter servers to run
    @param: port_number port number to be used for communication
    @return: a tuple (cluster, my_job_name, my_task_index), where
    cluster is a dict that can be passed to tf.train.ClusterSpec
    """
    nodelist = os.environ["SLURM_JOB_NODELIST"]
    nodename = os.environ["SLURMD_NODENAME"]
    nodelist = _expand_nodelist(nodelist)
    num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES"))
    if len(nodelist) != num_nodes:
        raise ValueError("Number of slurm nodes {} not equal to {}".format(len(nodelist), num_nodes))
    if nodename not in nodelist:
        raise ValueError("Nodename({}) not in nodelist({}). This should not happen!".format(nodename, nodelist))
    if ps_number >= num_nodes:
        raise ValueError("Number of ps nodes must be smaller than the number of nodes given by slurm!")
    ps_nodes = [node for i, node in enumerate(nodelist) if i < ps_number]
    worker_nodes = [node for i, node in enumerate(nodelist) if i >= ps_number]
    if nodename in ps_nodes:
        my_job_name = "ps"
        my_task_index = ps_nodes.index(nodename)
    else:
        my_job_name = "worker"
        my_task_index = worker_nodes.index(nodename)
    worker_sockets = [":".join([node, str(port_number)]) for node in worker_nodes]
    ps_sockets = [":".join([node, str(port_number)]) for node in ps_nodes]
    cluster = {"worker": worker_sockets, "ps": ps_sockets}
    return cluster, my_job_name, my_task_index
def _pad_zeros(iterable, length):
    return (str(t).rjust(length, '0') for t in iterable)

def _expand_ids(ids):
    ids = ids.split(',')
    result = []
    for id in ids:
        if '-' in id:
            begin, end = id.split('-')
            # Pad to the width of the first id so that "01-03" expands to 01, 02, 03.
            result.extend(_pad_zeros(range(int(begin), int(end) + 1), len(begin)))
        else:
            result.append(id)
    return result

def _expand_nodelist(nodelist):
    prefix, ids = re.findall(r"(.*)\[(.*)\]", nodelist)[0]
    ids = _expand_ids(ids)
    result = [prefix + str(id) for id in ids]
    return result

def _worker_task_id(nodelist, nodename):
    return nodelist.index(nodename)
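As a quick sanity check, the expansion logic above can be exercised standalone. The function below mirrors `_expand_nodelist`/`_expand_ids` in one self-contained piece; the hostname pattern `node[01-03,07]` is just an illustrative example:

```python
import re

def expand_nodelist(nodelist):
    # Split "node[01-03,07]" into prefix "node" and id string "01-03,07".
    prefix, ids = re.findall(r"(.*)\[(.*)\]", nodelist)[0]
    result = []
    for part in ids.split(','):
        if '-' in part:
            begin, end = part.split('-')
            # Keep the zero padding of the first id in the range.
            width = len(begin)
            result.extend(str(i).zfill(width) for i in range(int(begin), int(end) + 1))
        else:
            result.append(part)
    return [prefix + s for s in result]

print(expand_nodelist("node[01-03,07]"))  # ['node01', 'node02', 'node03', 'node07']
```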
Building the network model with TensorFlow
import tensorflow as tf

# Obtain the cluster resources allocated by Slurm and this node's job name,
# build the ClusterSpec, and start the server.
# Note that a ps node must keep serving in order to receive messages from the workers
# and apply parameter updates, so it calls server.join() instead of exiting directly.
cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=3)
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(server_or_cluster_def=cluster_spec,
                         job_name=my_job_name,
                         task_index=my_task_index)
if my_job_name == 'ps':
    server.join()
    sys.exit(0)
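On the worker side (which the snippet above stops short of), the usual TensorFlow 1.x pattern is to pin variables to the ps tasks with `tf.train.replica_device_setter` and train inside a `tf.train.MonitoredTrainingSession`. The sketch below only illustrates that pattern; the model and optimizer are placeholders, and it requires a running multi-node cluster, so it cannot be executed standalone:

```python
# Variables created in this scope land on the ps tasks; ops stay on this worker.
with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:%d" % my_task_index,
        cluster=cluster_spec)):
    global_step = tf.train.get_or_create_global_step()
    loss = ...  # build the actual model and loss here
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

# Worker 0 acts as the chief and handles initialization and checkpointing.
with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=(my_task_index == 0)) as sess:
    while not sess.should_stop():
        sess.run(train_op)
```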
More updates to follow as this is refined.