Submitting Distributed TensorFlow Jobs with Slurm

Goal

Single-machine multi-GPU TensorFlow is easy to set up by following the official tutorial, but for deploying a distributed training job across multiple cluster nodes there is no good official example. The naive approach is to designate ps and worker nodes by hand and launch the matching program on each node separately so that the machines cooperate on training one model. On a cluster with many nodes this is very inconvenient. This post uses the Slurm cluster resource manager to dispatch a distributed TensorFlow training job automatically.
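To see why this scales poorly, here is a rough sketch of the manual approach (host names and port are hypothetical): every node runs essentially the same script, but the operator must pass the right job name and task index to each node by hand.

import tensorflow as tf

cluster = tf.train.ClusterSpec({
    "ps": ["node01:2222"],
    "worker": ["node02:2222", "node03:2222"],
})
# On node01 you would launch with job_name="ps", task_index=0;
# on node02 with job_name="worker", task_index=0; and so on,
# once per node, by hand.
server = tf.train.Server(cluster, job_name="ps", task_index=0)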

Implementation

import os
import re
import sys

import tensorflow as tf

# Read the compute resources Slurm allocated to this job.
# ps_number is the number of parameter-server nodes to use; all remaining
# nodes become workers by default.
# A ps node could also run a worker, but we avoid this to prevent port conflicts.
# port_number is the port used for inter-node communication in this job. If a
# ps node also started a worker process, each process on that node would need
# its own port; for simplicity we assume num_nodes > 1 and never place a
# worker on the same node as a ps.

def tf_config_from_slurm(ps_number, port_number=2222):
    """
    Creates configuration for a distributed tensorflow session 
    from environment variables  provided by the Slurm cluster
    management system.

    @param: ps_number number of parameter servers to run
    @param: port_number port number to be used for communication
    @return: a tuple containing cluster with fields cluster_spec,
             task_name and task_id 
    """

    nodelist = os.environ["SLURM_JOB_NODELIST"]
    nodename = os.environ["SLURMD_NODENAME"]
    nodelist = _expand_nodelist(nodelist)
    num_nodes = int(os.getenv("SLURM_JOB_NUM_NODES"))

    if len(nodelist) != num_nodes:
        raise ValueError("Nodelist has {} nodes but SLURM_JOB_NUM_NODES is {}".format(len(nodelist), num_nodes))

    if nodename not in nodelist:
        raise ValueError("Nodename ({}) not in nodelist ({}). This should not happen!".format(nodename, nodelist))

    if ps_number > num_nodes:
        raise ValueError("Number of ps nodes is larger than the number of nodes allocated by slurm!")
    ps_nodes = [node for i, node in enumerate(nodelist) if i < ps_number]
    worker_nodes = [node for i, node in enumerate(nodelist) if i >= ps_number]

    if nodename in ps_nodes:
        my_job_name = "ps"
        my_task_index = ps_nodes.index(nodename)
    else:
        my_job_name = "worker"
        my_task_index = worker_nodes.index(nodename)

    worker_sockets = [":".join([node, str(port_number)]) for node in worker_nodes]
    ps_sockets = [":".join([node, str(port_number)]) for node in ps_nodes]
    cluster = {"worker": worker_sockets, "ps" : ps_sockets}

    return cluster, my_job_name, my_task_index

def _pad_zeros(iterable, length):
    return (str(t).rjust(length, '0') for t in iterable)

def _expand_ids(ids):
    ids = ids.split(',')
    result = []
    for id in ids:
        if '-' in id:
            # Keep the zero padding of the end token, so "01-10" expands
            # to "01", "02", ..., "10".
            begin, end = id.split('-')
            result.extend(_pad_zeros(range(int(begin), int(end) + 1), len(end)))
        else:
            result.append(id)
    return result

def _expand_nodelist(nodelist):
    # Handles the common "prefix[ids]" form, e.g. "node[01-04,07]".
    # A single-node job yields a bare hostname with no brackets.
    match = re.findall(r"(.*)\[(.*)\]", nodelist)
    if not match:
        return [nodelist]
    prefix, ids = match[0]
    ids = _expand_ids(ids)
    result = [prefix + str(id) for id in ids]
    return result
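
For example, a bracketed Slurm nodelist expands like this (node names are hypothetical):

print(_expand_nodelist("gpu[01-03,07]"))
# ['gpu01', 'gpu02', 'gpu03', 'gpu07']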

def _worker_task_id(nodelist, nodename):
    return nodelist.index(nodename)
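
As a quick sanity check, the Slurm environment can be faked and the result inspected (all values below are hypothetical):

os.environ["SLURM_JOB_NODELIST"] = "node[01-04]"
os.environ["SLURMD_NODENAME"] = "node01"
os.environ["SLURM_JOB_NUM_NODES"] = "4"

cluster, job_name, task_index = tf_config_from_slurm(ps_number=1)
print(cluster)
# {'worker': ['node02:2222', 'node03:2222', 'node04:2222'],
#  'ps': ['node01:2222']}
print(job_name, task_index)  # ps 0

In a real job submitted with sbatch (e.g. sbatch --nodes=4 wrapping srun python train.py), Slurm exports these variables automatically on every node, so the same script can pick its own role.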

Building the TensorFlow model

# Read the cluster resources allocated by Slurm along with this node's job
# name and task index, build the ClusterSpec, and start the server.
# Note that a ps node must keep listening for messages from the workers in
# order to apply parameter updates, so its server has to join() and must not
# simply exit.
cluster, my_job_name, my_task_index = tf_config_from_slurm(ps_number=3)
cluster_spec = tf.train.ClusterSpec(cluster)
server = tf.train.Server(server_or_cluster_def=cluster_spec,
                         job_name=my_job_name,
                         task_index=my_task_index)

if my_job_name == 'ps':
    server.join()
    sys.exit(0)
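
On the worker side the code stops here for now; a minimal sketch of how training could continue, using the standard TF 1.x replica_device_setter pattern (the linear model below is only a placeholder):

with tf.device(tf.train.replica_device_setter(
        worker_device="/job:worker/task:{}".format(my_task_index),
        cluster=cluster_spec)):
    # Placeholder model: variables land on the ps nodes, ops on this worker.
    x = tf.placeholder(tf.float32, shape=[None, 1])
    y = tf.placeholder(tf.float32, shape=[None, 1])
    w = tf.Variable(tf.zeros([1, 1]))
    b = tf.Variable(tf.zeros([1]))
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b - y))
    global_step = tf.train.get_or_create_global_step()
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=global_step)

# MonitoredTrainingSession handles session creation and variable
# initialization across the cluster; worker 0 acts as the chief.
with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=(my_task_index == 0)) as sess:
    # Feed real training batches here, e.g.:
    # sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
    pass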
To be updated as the implementation is completed.