{"id":18264556,"url":"https://github.com/modeltc/rank_dataset","last_synced_at":"2025-04-04T21:30:37.829Z","repository":{"id":103144878,"uuid":"350985903","full_name":"ModelTC/rank_dataset","owner":"ModelTC","description":"PyTorch Dataset Rank Dataset ","archived":false,"fork":false,"pushed_at":"2021-03-24T08:11:03.000Z","size":1204,"stargazers_count":42,"open_issues_count":1,"forks_count":10,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-03-20T19:16:00.170Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/ModelTC.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-03-24T07:30:37.000Z","updated_at":"2025-02-25T15:43:42.000Z","dependencies_parsed_at":null,"dependency_job_id":"10463ba0-577f-49dc-8919-c09307d8f1e5","html_url":"https://github.com/ModelTC/rank_dataset","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ModelTC%2Frank_dataset","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ModelTC%2Frank_dataset/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ModelTC%2Frank_dataset/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ModelTC%2Frank_dataset/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/ModelTC","download_url":"https://codeload.github.com/ModelTC/rank_dataset/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247251976,"owners_count":20908602,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-05T11:15:05.723Z","updated_at":"2025-04-04T21:30:37.823Z","avatar_url":"https://github.com/ModelTC.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# PyTorch 大规模数据集加载 \n\n## [知乎链接](https://zhuanlan.zhihu.com/p/357809861)\n\n## 问题阐述\n对于数据量比较小的数据集，一般来说我们直接加载到内存里即可，不需要考虑内存是否够用的情况。对于大规模数据集（千万级别以上）我们普通的加载方式已经没法满足我们的需求，内存问题已经成为瓶颈之一，因此针对此我们需要作出一些针对性的优化。\n\n### PyTorch的数据集加载背景简单介绍\n\n一般来说我们只需要关注这以下几个部分，dataset 是我们实现的自己数据集的具体原型，继承自torch.dataset, 其getitem 函数负责根据index去拿出我们对应的meta信息，sampler 负责提供想要拿出的index顺序。详细的PyTorch 数据集加载可以阅读官网或者可以看这篇。[Dataset](https://zhuanlan.zhihu.com/p/337850513)\n\n![1.jpg](./images/1.jpg)\n\n### 具体实现\n\n* 背景介绍\n\n以下为一个最简单的dataset 实现，对于8卡任务因为是多进程，所以实际的进程数量为8, 也就是会有8 * metas 需要在内存里存放(实际考虑到dataloader 的worker 数量，这个实际占用量会更大)，当我们的metas信息比较大的时候，我们的内存就可能会出现溢出问题。\n\n* 普通样例\n\n```python\nclass BaseDataset(Dataset):\n    def __init__(self, meta_file):\n        super(BaseDataset, self).__init__()\n        self.metas = self.parse(meta_file)\n\n    def parse(self, meta_file):\n        metas = []\n        with open(meta_file) as f:\n            for line in f.readlines():\n                metas.append(line.strip())\n        return metas\n\n    def __getitem__(self, idx):\n        return self.metas[idx]\n\n```\n\n* meta_file 格式\n```shell\n#filename label (分类任务)\nimage1.jpg 1\nimage2.jpg 0\nimage3.jpg 3\n```\n\n* 训练流程\n\n训练数据的流程可以表示如下:\n\n```python\ndataset = BaseDataset(\"/path/to/meta\")\nsampler = DistributedSampler(datset)\ndataloader = DataLoader(\n            dataset=dataset,\n            batch_size=32,\n            shuffle=False,\n            num_workers=4,\n            sampler=sampler\n        )\nmodel = build_model()\nfor index, batch in enumerate(dataloader):\n    image, label = batch\n    output = model(image)\n    loss = criterion(output, label)\n    loss.backward()\n    \n```\n\n#### 解决方案一\n\n将metas 信息中心化，放到一个中心化的地方进行存储，只保留一份，这样可以存储非常大的metas。然后dataset 从中心化的地方去获取meta信息\n\n![2.jpg](./images/2.jpg)\n\n* example\n\n```python\nclass ServerDataset(BaseDataset):\n    def __init__(self, meta_file, server_ip, server_port, timeout=1000):\n        super(ServerDataset, self).__init__(meta_file)\n        self.server_ip = server_ip\n        self.server_port = server_port\n        self.timeout = timeout\n        self.meta_num = self.get_meta_num()\n\n    @retry(stop_max_delay=10, stop_max_attempt_number=10)\n    def get_meta_num(self):\n        meta_num = requests.get('http://{}:{}/get_len'.format(\n            self.server_ip, self.server_port), timeout=self.timeout).json()\n        return int(meta_num)\n\n    @retry(stop_max_delay=10, stop_max_attempt_number=10)\n    def get_meta(self, idx):\n        meta = requests.get('http://{}:{}/get/{}'.format(\n            self.server_ip, self.server_port, idx), timeout=self.timeout).json()\n        return meta\n```\n\n* 训练流程\n\n**启动server**\n\n```shell\npython server.py --meta_file=\"/path/to/meta\" --port=\"10080\"\n\n```\n\n**启动训练**\n\n```python\ndataset = ServerDataset(\"/path/to/meta\", server_ip=\"10.10.10.10\", server_port=\"10080\")\nsampler = DistributedSampler(datset)\ndataloader = DataLoader(\n            dataset=dataset,\n            batch_size=32,\n            shuffle=False,\n            num_workers=4,\n            sampler=sampler\n        )\nmodel = build_model()\nfor index, batch in enumerate(dataloader):\n    image, label = batch\n    output = model(image)\n    loss = criterion(output, label)\n    loss.backward()\n    \n```\n\n* 弊端\n\n这种做法对于qps 在1k以下还比较实用, 但是当我们的训练的总batchsize 特别大的时候这种做法会有明显的瓶颈问题，受限于中心化的读取上限问题，因此此方法具有一定的局限性。\n\n\n#### 解决方案二\n* 背景知识\n\n从原理出发，在分布式训练的过程中，其实每张卡实际使用的数据量为 len(metas) // world_size, 在一般的训练过程中我们为了访问方便，采用sampler 去划分不同的卡读取的index，每块卡还是会保留所有的meta信息，因此这样会导致前面的内存问题。\n\n* 具体方案\n\n我们的方案具体为 分rank + 切分数据集进一步的动态的去加载我们的数据集。如下图所示，在初始化的时候，每块卡只加载其对应的meta信息，这样总体的内存占用率可减少 world_size 倍。为了进一步的减少内存，我们还可以进一步将数据集进行切分，分成 mini_epoch 进行分组读取。两者配合使用，总体的内存减少量可达 world_size * mini_epoch 倍，基本上可以达到我们的需求。\n\n*实际的流程图*\n![3.jpg](./images/3.jpg)\n\n* 切分流程\n```python\n'''\n                     Metas 切分过程, mini_epoch = 2, world_size = 8\n\n    mini_epoch_idx = 0                            mini_epoch_idx = 1\n---- ---- ---- ---- ---- ---- ---- ---- | ---- ---- ---- ---- ---- ---- ---- ---- \nrk0  rk1  rk2  rk3  rk4  rk5  rk6  rk7  | rk0  rk1  rk2  rk3  rk4  rk5  rk6  rk7 \n\n每次只加载 len(metas) // (world_size * mini_epoch) 这样我们的内存占用就会可以人为的进行调整\n\n'''\n```\n* 注意\n\n对于普通的dataloader，随机性一般由sampler进行控制，我们这里由于已经分rank进行加载我们的meta 信息，因此每隔一个epoch我们需要重新分配一次我们每个 rank 的 meta 信息，为了保证随机性，在分配rank的meta信息时，我们就要引入随机性, 以下是从本地读取的样例。\n\n* 本地读取样例\n\n![4.jpg](./images/4.jpg)\n\n* 训练流程\n\n```python\nfor epoch_num in range(epoch_num):\n    reload_cfg = {\"mini_epoch\": 1, \"seed\": epoch_num, \"mini_epoch_idx\": 0, \"group\": 1}\n    dataset = RankDataset(\"/path/to/meta\", is_test=False, reload_cfg)\n    sampler = RandomSampler(datset)\n    dataloader = DataLoader(\n                dataset=dataset,\n                batch_size=32,\n                shuffle=False,\n                num_workers=4,\n                sampler=sampler\n            )\n```\n\n\n\n* server 读取样例\n\n本地读取常常会受限于文件系统的读取效率，在我们的文件系统读取速度比较差的时候整个加载会比较慢，因此提供一个中心化读取方案，适用于网络较快的情况。\n\n![5.jpg](./images/5.jpg)\n\n**启动server**\n\n```shell\npython server.py --meta_file=\"/path/to/meta\" --port=\"10080\"\n\n```\n\n**启动训练**\n\n```python\nfor epoch_num in range(epoch_num):\n    reload_cfg = {\"mini_epoch\": 1, \"seed\": epoch_num, \"mini_epoch_idx\": 0, \"group\": 1}\n    dataset = RankServerDataset(\"/path/to/meta\", server_ip=\"10.10.10.10\", server_port=\"10080\", is_test=False, reload_cfg)\n    sampler = RandomSampler(datset)\n    dataloader = DataLoader(\n                dataset=dataset,\n                batch_size=32,\n                shuffle=False,\n                num_workers=4,\n                sampler=sampler\n            )\n```\n\n\n**需要注意**\n当我们需要切分mini_epoch 的时候，每个mini_epoch 都需要进行重新构建dataloader\n\n* Sampler\n\n这是切分rank 之后的sampler，这里就不再需要区分rank了，meta 已经根据rank进行区分\n\n```python\nclass RandomSampler(Sampler):\n    r\"\"\"Samples elements randomly, without replacement.\n\n    Arguments:\n        data_source (Dataset): dataset to sample from\n    \"\"\"\n\n    def __init__(self, dataset):\n        self.dataset = dataset\n\n    def __iter__(self):\n        return iter(torch.randperm(len(self.dataset)).tolist())\n\n    def __len__(self):\n        return len(self.dataset)\n\n```\n\n**以下是普通的分布式的sampler**\n\n```python\nclass DistributedSampler(Sampler):\n    def __init__(self, dataset, world_size=None, rank=None):\n        if world_size is None:\n            world_size = get_world_size()\n        if rank is None:\n            rank = get_rank()\n\n        self.dataset = dataset\n        self.world_size = world_size\n        self.rank = rank\n        self.num_samples = int(\n            math.ceil(len(self.dataset) * 1.0 / self.world_size))\n        self.total_size = self.num_samples * self.world_size\n\n    def __iter__(self):\n        # deterministically shuffle based on epoch\n        g = torch.Generator()\n        g.manual_seed(self.epoch)\n        indices = list(torch.randperm(len(self.dataset), generator=g))\n\n        # add extra samples to make it evenly divisible\n        indices += indices[:(self.total_size - len(indices))]\n        assert len(indices) == self.total_size\n\n        # subsample\n        offset = self.num_samples * self.rank\n        indices = indices[offset:offset + self.num_samples]\n        assert len(indices) == self.num_samples\n\n        return iter(indices)\n\n    def __len__(self):\n        return self.num_samples\n```","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmodeltc%2Frank_dataset","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmodeltc%2Frank_dataset","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmodeltc%2Frank_dataset/lists"}