华为云AI开发平台ModelArts：Step2 准备脚本文件并上传至OBS中 - 云淘科技
准备本案例所需训练脚本 mindspore-verification.py 文件和 Ascend 的启动脚本文件(共5个)。
训练脚本文件的具体内容请参见下文“训练脚本 mindspore-verification.py 文件”。
Ascend 的启动脚本文件包括以下 5 个，具体脚本内容请参见下文“Ascend 的启动脚本文件”。
run_ascend.py
common.py
rank_table.py
manager.py
fmk.py
mindspore-verification.py 和 run_ascend.py 脚本文件会在创建训练作业时的“启动命令”参数中调用，具体请参见创建训练作业步骤中“启动命令”参数的说明。
run_ascend.py脚本运行时会调用common.py、rank_table.py、manager.py、fmk.py脚本。
上传训练脚本 mindspore-verification.py 文件至OBS桶的“obs://test-modelarts/ascend/demo-code/”文件夹下。
图1 训练脚本文件上传完成后的OBS列表
上传Ascend的启动脚本文件(共5个)至OBS桶的“obs://test-modelarts/ascend/demo-code/run_ascend/”文件夹下。
图2 Ascend的启动脚本文件上传完成后的OBS列表
训练脚本 mindspore-verification.py 文件
mindspore-verification.py 文件内容如下:
"""Smoke test for a MindSpore + Ascend training container.

Prints the Ascend-related environment variables that the launcher
(run_ascend.py / fmk.py) exports for each device process, then runs one
element-wise add on the Ascend device to confirm MindSpore can use it.
"""
import os

import numpy as np
from mindspore import Tensor
import mindspore.ops as ops
import mindspore.context as context

# Variables exported by the launcher. They are read with .get() so that a
# missing variable is printed as None instead of aborting the script before
# the device check below runs (run_ascend.py may unset RANK_TABLE_FILE).
ASCEND_ENV_NAMES = (
    'JOB_ID',
    'RANK_TABLE_FILE',
    'RANK_SIZE',
    'ASCEND_DEVICE_ID',
    'DEVICE_ID',
    'RANK_ID',
)

print('Ascend Envs')
print('------')
for env_name in ASCEND_ENV_NAMES:
    print(env_name + ': ', os.environ.get(env_name))
print('------')

context.set_context(device_target="Ascend")
x = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
y = Tensor(np.ones([1, 3, 3, 4]).astype(np.float32))
print(ops.add(x, y))
Ascend 的启动脚本文件
1. run_ascend.py
"""Launch one training process per Ascend device of the current server.

Usage: python run_ascend.py <training command and its arguments...>

The rank table is read from RANK_TABLE_FILE_V_1_0 when that variable is set
(template 1, ``server_list`` format), otherwise from the template-2
jobstart_hccl.json, which RankTableTemplate2 converts to template 1.
RANK_TABLE_FILE is exported for the children, the command is started once
per device of this server, and the script exits with the code returned by
FMKManager.monitor().
"""
import sys
import os

from common import RunAscendLog
from common import RankTableEnv

from rank_table import RankTable, RankTableTemplate1, RankTableTemplate2

from manager import FMKManager


if __name__ == '__main__':
    log = RunAscendLog.setup_run_ascend_logger()

    if len(sys.argv) <= 1:
        log.error('there are not enough args')
        sys.exit(1)

    # everything after this script's name is the training command; it is
    # handed to subprocess.Popen as an argv list
    train_command = sys.argv[1:]
    log.info('training command')
    log.info(train_command)

    if os.environ.get(RankTableEnv.RANK_TABLE_FILE_V1) is not None:
        # rank table already in template-1 (server_list) format
        rank_table_path = os.environ.get(RankTableEnv.RANK_TABLE_FILE_V1)
        RankTable.wait_for_available(rank_table_path)
        rank_table = RankTableTemplate1(rank_table_path)
    else:
        # template-2 (group_list) format; converted to template 1 on load
        rank_table_path_origin = RankTableEnv.get_rank_table_template2_file_path()
        RankTable.wait_for_available(rank_table_path_origin)
        rank_table = RankTableTemplate2(rank_table_path_origin)

    if rank_table.get_device_num() >= 1:
        log.info('set rank table %s env to %s' % (RankTableEnv.RANK_TABLE_FILE, rank_table.get_rank_table_path()))
        RankTableEnv.set_rank_table_env(rank_table.get_rank_table_path())
    else:
        log.info('device num < 1, unset rank table %s env' % RankTableEnv.RANK_TABLE_FILE)
        RankTableEnv.unset_rank_table_env()

    instance = rank_table.get_current_instance()
    if instance is None:
        log.error('current instance is not found in the rank table')
        sys.exit(1)

    server = rank_table.get_server(instance.server_id)
    if server is None:
        # get_server() has already logged which server id is missing
        sys.exit(1)
    current_instance = RankTable.convert_server_to_instance(server)

    fmk_manager = FMKManager(current_instance)
    fmk_manager.run(rank_table.get_device_num(), train_command)
    return_code = fmk_manager.monitor()

    fmk_manager.destroy()

    sys.exit(return_code)
2. common.py
import logging
import os

# name of the logger shared by run_ascend.py and its helper modules
logo = 'Training'


# Rank Table Constants
class RankTableEnv:
    """Environment-variable names and default paths of the HCCL rank table."""

    RANK_TABLE_FILE = 'RANK_TABLE_FILE'

    RANK_TABLE_FILE_V1 = 'RANK_TABLE_FILE_V_1_0'

    HCCL_CONNECT_TIMEOUT = 'HCCL_CONNECT_TIMEOUT'

    # jobstart_hccl.json is provided by the volcano controller of Cloud-Container-Engine(CCE)
    HCCL_JSON_FILE_NAME = 'jobstart_hccl.json'

    RANK_TABLE_FILE_DEFAULT_VALUE = '/user/config/%s' % HCCL_JSON_FILE_NAME

    @staticmethod
    def get_rank_table_template1_file_dir():
        """Return ``$MA_MOUNT_PATH/rank_table``.

        :raises KeyError: if MA_MOUNT_PATH is not set.
        """
        parent_dir = os.environ[ModelArts.MA_MOUNT_PATH_ENV]
        return os.path.join(parent_dir, 'rank_table')

    @staticmethod
    def get_rank_table_template2_file_path():
        """Return the path of the template-2 rank table (jobstart_hccl.json).

        When RANK_TABLE_FILE is set it is treated as the directory that holds
        the file; otherwise RANK_TABLE_FILE_DEFAULT_VALUE is returned.
        """
        rank_table_file_path = os.environ.get(RankTableEnv.RANK_TABLE_FILE)
        if rank_table_file_path is None:
            return RankTableEnv.RANK_TABLE_FILE_DEFAULT_VALUE

        return os.path.join(os.path.normpath(rank_table_file_path), RankTableEnv.HCCL_JSON_FILE_NAME)

    @staticmethod
    def set_rank_table_env(path):
        """Export RANK_TABLE_FILE=path in this process's environment."""
        os.environ[RankTableEnv.RANK_TABLE_FILE] = path

    @staticmethod
    def unset_rank_table_env():
        """Remove RANK_TABLE_FILE from the environment; a no-op if it is not set."""
        # pop() rather than del: del raises KeyError when the variable is absent
        os.environ.pop(RankTableEnv.RANK_TABLE_FILE, None)


class ModelArts:
    """ModelArts-related environment-variable names and default paths."""

    MA_MOUNT_PATH_ENV = 'MA_MOUNT_PATH'
    MA_CURRENT_INSTANCE_NAME_ENV = 'MA_CURRENT_INSTANCE_NAME'
    MA_VJ_NAME = 'MA_VJ_NAME'
    MA_CURRENT_HOST_IP = 'MA_CURRENT_HOST_IP'

    CACHE_DIR = '/cache'

    TMP_LOG_DIR = '/tmp/log/'

    FMK_WORKSPACE = 'workspace'

    @staticmethod
    def get_current_instance_name():
        """Return MA_CURRENT_INSTANCE_NAME.

        :raises KeyError: if the variable is not set.
        """
        return os.environ[ModelArts.MA_CURRENT_INSTANCE_NAME_ENV]

    @staticmethod
    def get_current_host_ip():
        """Return MA_CURRENT_HOST_IP, or None if it is not set."""
        return os.environ.get(ModelArts.MA_CURRENT_HOST_IP)

    @staticmethod
    def get_job_id():
        """Return MA_VJ_NAME with its first 'ma-job' replaced by 'modelarts-job'.

        :raises KeyError: if MA_VJ_NAME is not set.
        """
        ma_vj_name = os.environ[ModelArts.MA_VJ_NAME]
        return ma_vj_name.replace('ma-job', 'modelarts-job', 1)

    @staticmethod
    def get_parent_working_dir():
        """Return ``$MA_MOUNT_PATH/workspace`` if MA_MOUNT_PATH is set, else CACHE_DIR."""
        if ModelArts.MA_MOUNT_PATH_ENV in os.environ:
            return os.path.join(os.environ.get(ModelArts.MA_MOUNT_PATH_ENV), ModelArts.FMK_WORKSPACE)

        return ModelArts.CACHE_DIR


class RunAscendLog:
    """Access to the shared 'Training' logger."""

    @staticmethod
    def setup_run_ascend_logger():
        """Configure the shared logger (INFO level, stream handler) and return it."""
        name = logo
        logger = logging.getLogger(name)
        logger.setLevel(logging.INFO)
        # Only attach a handler once: a second call would otherwise add another
        # StreamHandler and every record would be printed twice.
        if not logger.handlers:
            formatter = logging.Formatter(fmt='[run ascend] %(asctime)s - %(levelname)s - %(message)s')
            handler = logging.StreamHandler()
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.propagate = False
        return logger

    @staticmethod
    def get_run_ascend_logger():
        """Return the shared logger (configured by setup_run_ascend_logger)."""
        return logging.getLogger(logo)
3. rank_table.py
import json
import time
import os

from common import ModelArts
from common import RunAscendLog
from common import RankTableEnv

log = RunAscendLog.get_run_ascend_logger()


class Device:
    """One device entry of a rank table."""

    def __init__(self, device_id, device_ip, rank_id):
        self.device_id = device_id
        self.device_ip = device_ip
        self.rank_id = rank_id


class Instance:
    """A pod/server together with the devices it owns."""

    def __init__(self, pod_name, server_id, devices):
        self.pod_name = pod_name
        self.server_id = server_id
        self.devices = self.parse_devices(devices)

    @staticmethod
    def parse_devices(devices):
        """Convert raw device dicts ('device_id', 'device_ip') into Device objects.

        The rank id is left empty; rank ids are taken from the template-1
        table instead (see RankTable.convert_server_to_instance).
        """
        if devices is None:
            return []
        return [Device(device['device_id'], device['device_ip'], '') for device in devices]

    def set_devices(self, devices):
        self.devices = devices


class Group:
    """A group of a template-2 rank table (group_list entry)."""

    def __init__(self, group_name, device_count, instance_count, instance_list):
        self.group_name = group_name
        self.device_count = int(device_count)
        self.instance_count = int(instance_count)
        self.instance_list = self.parse_instance_list(instance_list)

    @staticmethod
    def parse_instance_list(instance_list):
        return [Instance(instance['pod_name'], instance['server_id'], instance['devices'])
                for instance in instance_list]


class RankTable:
    """Base class of a rank table held in template-1 (server_list) form."""

    STATUS_FIELD = 'status'
    COMPLETED_STATUS = 'completed'

    def __init__(self):
        self.rank_table_path = ""
        self.rank_table = {}

    @staticmethod
    def read_from_file(file_path):
        with open(file_path) as json_file:
            return json.load(json_file)

    @staticmethod
    def wait_for_available(rank_table_file, period=1, timeout=None):
        """Block until the rank table file reports ``status == "completed"``.

        The file is re-read every ``period`` seconds. A file that does not
        exist yet or is not valid JSON yet (still being written) is retried
        rather than treated as fatal.

        :param timeout: maximum number of seconds to wait; None (the default)
            waits forever.
        :return: True once the table is complete, False if ``timeout`` elapsed.
        """
        log.info('Wait for Rank table file at %s ready' % rank_table_file)
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            try:
                with open(rank_table_file) as json_file:
                    data = json.load(json_file)
            except (FileNotFoundError, json.JSONDecodeError):
                data = None

            if data is not None and data.get(RankTable.STATUS_FIELD) == RankTable.COMPLETED_STATUS:
                log.info('Rank table file is ready for read')
                log.info(' ' + json.dumps(data, indent=4))
                return True

            if deadline is not None and time.monotonic() >= deadline:
                log.error('Rank table file at %s is not ready within %s seconds', rank_table_file, timeout)
                return False

            time.sleep(period)

    @staticmethod
    def convert_server_to_instance(server):
        """Build an Instance (with rank ids) from a template-1 server_list entry."""
        device_list = [
            Device(device_id=device['device_id'], device_ip=device['device_ip'], rank_id=device['rank_id'])
            for device in server['device']
        ]
        ins = Instance(pod_name='', server_id=server['server_id'], devices=[])
        ins.set_devices(device_list)
        return ins

    def get_rank_table_path(self):
        return self.rank_table_path

    def get_server(self, server_id):
        """Return the server_list entry with the given server_id, or None (logged)."""
        for server in self.rank_table['server_list']:
            if server['server_id'] == server_id:
                log.info('Current server')
                log.info(' ' + json.dumps(server, indent=4))
                return server

        log.error('server [%s] is not found' % server_id)
        return None


class RankTableTemplate2(RankTable):
    """Template-2 (group_list) rank table, converted to a template-1 file on load."""

    def __init__(self, rank_table_template2_path):
        super().__init__()
        # defaults so get_device_num()/get_current_instance() stay usable even
        # when the table is not completed and the parsing below is skipped
        self.group_count = 0
        self.group_list = []

        json_data = self.read_from_file(file_path=rank_table_template2_path)

        self.status = json_data[RankTableTemplate2.STATUS_FIELD]
        if self.status != RankTableTemplate2.COMPLETED_STATUS:
            return

        # sorted instance list by the index of instance
        # assert there is only one group
        json_data["group_list"][0]["instance_list"] = sorted(json_data["group_list"][0]["instance_list"],
                                                             key=RankTableTemplate2.get_index)

        self.group_count = int(json_data['group_count'])
        self.group_list = self.parse_group_list(json_data['group_list'])

        self.rank_table_path, self.rank_table = self.convert_template2_to_template1_format_file()

    @staticmethod
    def parse_group_list(group_list):
        return [Group(group['group_name'], group['device_count'], group['instance_count'], group['instance_list'])
                for group in group_list]

    @staticmethod
    def get_index(instance):
        """Return the integer after the last '-' of the pod name (sort key)."""
        # pod_name example: job94dc1dbf-job-bj4-yolov4-15
        pod_name = instance["pod_name"]
        return int(pod_name[pod_name.rfind("-") + 1:])

    def get_current_instance(self):
        """
        get instance by pod name
        specially, return the first instance when the pod name is None

        NOTE(review): ModelArts.get_current_instance_name() reads os.environ[...]
        and raises KeyError when the variable is unset, so the None branch is
        only reached if that helper returns None.
        :return: the matching Instance, or None
        """
        pod_name = ModelArts.get_current_instance_name()
        if pod_name is None:
            if self.group_list and self.group_list[0].instance_list:
                return self.group_list[0].instance_list[0]
            return None

        for group in self.group_list:
            for instance in group.instance_list:
                if instance.pod_name == pod_name:
                    return instance
        return None

    def convert_template2_to_template1_format_file(self):
        """Write the groups as a template-1 table and return ``(path, table_dict)``.

        Rank ids are assigned consecutively over all devices of all groups with
        a non-zero device count; the file is written to
        ``<template-1 dir>/jobstart_hccl.json``.
        """
        rank_table_template1_file = {
            'status': 'completed',
            'version': '1.0',
            'server_count': '0',
            'server_list': []
        }

        logic_index = 0
        server_map = {}
        # collect all devices in all groups
        for group in self.group_list:
            if group.device_count == 0:
                continue
            for instance in group.instance_list:
                server_devices = server_map.setdefault(instance.server_id, [])
                for device in instance.devices:
                    server_devices.append({
                        'device_id': device.device_id,
                        'device_ip': device.device_ip,
                        'rank_id': str(logic_index)
                    })
                    logic_index += 1

        for server_id, devices in server_map.items():
            rank_table_template1_file['server_list'].append({
                'server_id': server_id,
                'device': devices
            })
        rank_table_template1_file['server_count'] = str(len(server_map))

        log.info('Rank table file (Template1)')
        log.info(' ' + json.dumps(rank_table_template1_file, indent=4))

        file_dir = RankTableEnv.get_rank_table_template1_file_dir()
        os.makedirs(file_dir, exist_ok=True)

        path = os.path.join(file_dir, RankTableEnv.HCCL_JSON_FILE_NAME)
        with open(path, 'w') as f:
            f.write(json.dumps(rank_table_template1_file))
            log.info('Rank table file (Template1) is generated at %s', path)

        return path, rank_table_template1_file

    def get_device_num(self):
        """Return the sum of device_count over all groups."""
        return sum(group.device_count for group in self.group_list)


class RankTableTemplate1(RankTable):
    """Rank table that is already in template-1 (server_list) format."""

    def __init__(self, rank_table_template1_path):
        super().__init__()
        self.rank_table_path = rank_table_template1_path
        self.rank_table = self.read_from_file(file_path=rank_table_template1_path)

    def get_current_instance(self):
        """Return the Instance of the server this process runs on, or None (logged).

        A single server is used directly. With several servers, the one whose
        server_id equals MA_CURRENT_HOST_IP is used, or the first one when that
        variable is not set.
        """
        server_list = self.rank_table['server_list']
        current_server = None

        if len(server_list) == 1:
            current_server = server_list[0]
        elif len(server_list) > 1:
            host_ip = ModelArts.get_current_host_ip()
            if host_ip is None:
                current_server = server_list[0]
            else:
                current_server = next((server for server in server_list if server['server_id'] == host_ip), None)

        if current_server is None:
            log.error('server is not found')
            return None
        return self.convert_server_to_instance(current_server)

    def get_device_num(self):
        """Return the total number of devices over all servers."""
        return sum(len(server['device']) for server in self.rank_table['server_list'])
4. manager.py
import functools
import time
import os
import os.path
import signal

from common import RunAscendLog
from fmk import FMK

log = RunAscendLog.get_run_ascend_logger()


class FMKManager:
    """Start, monitor and tear down one FMK training process per device of an instance."""

    # max destroy time: ~20 (15 + 5)
    # ~ 15 (1 + 2 + 4 + 8)
    MAX_TEST_PROC_CNT = 4

    def __init__(self, instance):
        self.instance = instance
        self.fmk = []
        self.fmk_processes = []

        self.get_sigterm = False
        self.max_test_proc_cnt = FMKManager.MAX_TEST_PROC_CNT

    # break the monitor and destroy processes when get terminate signal
    def term_handle(func):
        """Decorator: while ``func`` runs, SIGTERM only sets ``self.get_sigterm``.

        The previous SIGTERM handler is restored when ``func`` returns or raises.
        """
        @functools.wraps(func)
        def handle_func(self, *args, **kwargs):
            # Close over ``self`` directly. Reading it from the interrupted frame's
            # locals picks up the wrong object whenever the signal arrives inside
            # a nested method call (e.g. Popen.poll, whose ``self`` is the Popen).
            def receive_term(signum, stack):
                log.info('Received terminate signal %d, try to destroyed all processes' % signum)
                self.get_sigterm = True

            origin_handle = signal.getsignal(signal.SIGTERM)
            signal.signal(signal.SIGTERM, receive_term)
            try:
                return func(self, *args, **kwargs)
            finally:
                # getsignal() returns None when the previous handler was not
                # installed from Python; None cannot be passed to signal.signal()
                if origin_handle is not None:
                    signal.signal(signal.SIGTERM, origin_handle)

        return handle_func

    def run(self, rank_size, command):
        """Start ``command`` once per device of the instance."""
        for index, device in enumerate(self.instance.devices):
            fmk_instance = FMK(index, device)
            self.fmk.append(fmk_instance)
            self.fmk_processes.append(fmk_instance.run(rank_size, command))

    @term_handle
    def monitor(self, period=1):
        """Poll the training processes every ``period`` seconds until they finish.

        :return: the first non-zero exit code observed; 0 when every process
            exited with 0, or when SIGTERM was received (see term_handle).
        """
        # busy waiting for all fmk processes exit by zero
        # or there is one process exit by non-zero
        fmk_cnt = len(self.fmk_processes)
        zero_ret_cnt = 0
        while zero_ret_cnt != fmk_cnt:
            zero_ret_cnt = 0
            for fmk, fmk_process in zip(self.fmk, self.fmk_processes):
                if fmk_process.poll() is None:
                    continue
                if fmk_process.returncode != 0:
                    log.error('proc-rank-%s-device-%s (pid: %d) has exited with non-zero code: %d'
                              % (fmk.rank_id, fmk.device_id, fmk_process.pid, fmk_process.returncode))
                    return fmk_process.returncode
                zero_ret_cnt += 1

            if self.get_sigterm:
                break
            time.sleep(period)

        return 0

    def destroy(self, base_period=1):
        """Send SIGTERM to all process groups, wait with back-off, then SIGKILL leftovers."""
        log.info('Begin destroy training processes')
        self.send_sigterm_to_fmk_process()
        self.wait_fmk_process_end(base_period)
        log.info('End destroy training processes')

    def send_sigterm_to_fmk_process(self):
        """Drop already-exited processes from the lists and SIGTERM every process group."""
        # send SIGTERM to fmk processes (and process group)
        for r_index in range(len(self.fmk_processes) - 1, -1, -1):
            fmk = self.fmk[r_index]
            fmk_process = self.fmk_processes[r_index]
            if fmk_process.poll() is not None:
                log.info('proc-rank-%s-device-%s (pid: %d) has exited before receiving the term signal',
                         fmk.rank_id, fmk.device_id, fmk_process.pid)
                del self.fmk_processes[r_index]
                del self.fmk[r_index]

            # NOTE(review): the group is signalled even when its leader has already
            # exited — presumably to reach remaining members of the group; a group
            # with no members left raises ProcessLookupError, which is ignored.
            try:
                os.killpg(fmk_process.pid, signal.SIGTERM)
            except ProcessLookupError:
                pass

    def wait_fmk_process_end(self, base_period):
        """Wait for the processes to exit, then SIGKILL the process groups still running.

        Polls at most ``max_test_proc_cnt`` times, sleeping base_period,
        2*base_period, 4*base_period, ... between polls, and removes exited
        processes from ``self.fmk`` / ``self.fmk_processes``.
        """
        test_cnt = 0
        period = base_period
        while len(self.fmk_processes) > 0 and test_cnt < self.max_test_proc_cnt:
            for r_index in range(len(self.fmk_processes) - 1, -1, -1):
                fmk = self.fmk[r_index]
                fmk_process = self.fmk_processes[r_index]
                if fmk_process.poll() is not None:
                    log.info('proc-rank-%s-device-%s (pid: %d) has exited',
                             fmk.rank_id, fmk.device_id, fmk_process.pid)
                    del self.fmk_processes[r_index]
                    del self.fmk[r_index]

            if not self.fmk_processes:
                break

            time.sleep(period)
            period *= 2
            test_cnt += 1

        if len(self.fmk_processes) > 0:
            for r_index in range(len(self.fmk_processes) - 1, -1, -1):
                fmk = self.fmk[r_index]
                fmk_process = self.fmk_processes[r_index]
                if fmk_process.poll() is None:
                    log.warning('proc-rank-%s-device-%s (pid: %d) has not exited within the max waiting time, '
                                'send kill signal', fmk.rank_id, fmk.device_id, fmk_process.pid)
                    try:
                        os.killpg(fmk_process.pid, signal.SIGKILL)
                    except ProcessLookupError:
                        # the group disappeared between poll() and killpg()
                        pass
5. fmk.py
import os
import subprocess
import pathlib
from contextlib import contextmanager

from common import RunAscendLog
from common import RankTableEnv
from common import ModelArts

log = RunAscendLog.get_run_ascend_logger()


class FMK:
    """One training process bound to a single device of the current server."""

    def __init__(self, index, device):
        self.job_id = ModelArts.get_job_id()
        self.rank_id = device.rank_id
        # the local device id is the device's position on this server (index),
        # not device.device_id from the rank table
        self.device_id = str(index)

    def gen_env_for_fmk(self, rank_size):
        """Return a copy of os.environ extended with the per-device variables."""
        current_envs = os.environ.copy()

        current_envs['JOB_ID'] = self.job_id

        current_envs['ASCEND_DEVICE_ID'] = self.device_id
        current_envs['DEVICE_ID'] = self.device_id

        current_envs['RANK_ID'] = self.rank_id
        current_envs['RANK_SIZE'] = str(rank_size)

        FMK.set_env_if_not_exist(current_envs, RankTableEnv.HCCL_CONNECT_TIMEOUT, str(1800))

        log_dir = FMK.get_log_dir()
        process_log_path = os.path.join(log_dir, self.job_id, 'ascend', 'process_log', 'rank_' + self.rank_id)
        FMK.set_env_if_not_exist(current_envs, 'ASCEND_PROCESS_LOG_PATH', process_log_path)
        pathlib.Path(current_envs['ASCEND_PROCESS_LOG_PATH']).mkdir(parents=True, exist_ok=True)

        return current_envs

    @contextmanager
    def switch_directory(self, directory):
        """Temporarily chdir into ``directory``; the previous cwd is always restored."""
        owd = os.getcwd()
        try:
            os.chdir(directory)
            yield directory
        finally:
            os.chdir(owd)

    def get_working_dir(self):
        """Return ``<parent working dir>/device<device_id>``."""
        fmk_workspace_prefix = ModelArts.get_parent_working_dir()
        return os.path.join(os.path.normpath(fmk_workspace_prefix), 'device%s' % self.device_id)

    @staticmethod
    def get_log_dir():
        """Return ``$MA_MOUNT_PATH/log`` if that directory exists, else TMP_LOG_DIR."""
        parent_path = os.getenv(ModelArts.MA_MOUNT_PATH_ENV)
        if parent_path:
            log_path = os.path.join(parent_path, 'log')
            if os.path.exists(log_path):
                return log_path

        return ModelArts.TMP_LOG_DIR

    @staticmethod
    def set_env_if_not_exist(envs, env_name, env_value):
        """Set ``envs[env_name] = env_value`` unless ``envs`` already defines it."""
        if env_name in envs:
            # log the value that is kept, not the default that is being ignored
            log.info('env already exists. env_name: %s, env_value: %s ', env_name, envs[env_name])
            return

        envs[env_name] = env_value

    def run(self, rank_size, command):
        """Start ``command`` for this device and return its Popen object.

        stdout and stderr of the training process are piped into ``tee``,
        which writes them to its inherited stdout and to
        ``<log dir>/<job_id>-proc-rank-<rank>-device-<device>.txt``.
        """
        envs = self.gen_env_for_fmk(rank_size)
        log.info('bootstrap proc-rank-%s-device-%s' % (self.rank_id, self.device_id))

        log_dir = FMK.get_log_dir()
        os.makedirs(log_dir, exist_ok=True)

        log_file = '%s-proc-rank-%s-device-%s.txt' % (self.job_id, self.rank_id, self.device_id)
        log_file_path = os.path.join(log_dir, log_file)

        working_dir = self.get_working_dir()
        os.makedirs(working_dir, exist_ok=True)

        with self.switch_directory(working_dir):
            # start_new_session=True runs setsid() in the child (same as the former
            # preexec_fn=os.setsid, but without running Python code after fork):
            # the child leads its own process group, which manager.py signals
            # through os.killpg(pid, ...)
            training_proc = subprocess.Popen(command, env=envs, start_new_session=True,
                                             stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

            log.info('proc-rank-%s-device-%s (pid: %d)', self.rank_id, self.device_id, training_proc.pid)

            subprocess.Popen(['tee', log_file_path], stdin=training_proc.stdout)
            # tee now holds the read end of the pipe; close the parent's copy so the
            # descriptor is not leaked and the training process gets SIGPIPE if tee exits
            training_proc.stdout.close()

            return training_proc
父主题: 示例:从 0 到 1 制作自定义镜像并用于训练(MindSpore+Ascend)
同意关联代理商云淘科技,购买华为云产品更优惠(QQ 78315851)
内容没看懂? 不太想学习?想快速解决? 有偿解决: 联系专家