{"id":15169601,"url":"https://github.com/hmunachi/nanodl","last_synced_at":"2025-04-05T01:05:20.824Z","repository":{"id":189959377,"uuid":"681653336","full_name":"HMUNACHI/nanodl","owner":"HMUNACHI","description":"A Jax-based library for designing and training transformer models from scratch.","archived":false,"fork":false,"pushed_at":"2024-08-28T21:24:22.000Z","size":46538,"stargazers_count":284,"open_issues_count":2,"forks_count":10,"subscribers_count":8,"default_branch":"main","last_synced_at":"2025-04-03T11:22:12.706Z","etag":null,"topics":["attention","attention-mechanism","deep-learning","distributed-training","flax","gpt","jax","llama","machine-learning","mistral","nlp","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/HMUNACHI.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":".github/FUNDING.yml","license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null},"funding":{"github":["HMUNACHI"]}},"created_at":"2023-08-22T13:22:24.000Z","updated_at":"2025-03-28T05:08:14.000Z","dependencies_parsed_at":null,"dependency_job_id":"940e4190-24be-4ea3-bd30-23ee97bb9ad5","html_url":"https://github.com/HMUNACHI/nanodl","commit_stats":{"total_commits":138,"total_committers":3,"mean_commits":46.0,"dds":"0.021739130434782594","last_synced_commit":"e52861a5b2c9bf76e4e79e0bf88a07420497579d"},"previous_names":["hmunachi/jax-models","hmunachi/nanodl"],"tags_count":8,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/HMUNACHI%2Fnanodl","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/Git
Hub/repositories/HMUNACHI%2Fnanodl/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/HMUNACHI%2Fnanodl/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/HMUNACHI%2Fnanodl/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/HMUNACHI","download_url":"https://codeload.github.com/HMUNACHI/nanodl/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247271520,"owners_count":20911587,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["attention","attention-mechanism","deep-learning","distributed-training","flax","gpt","jax","llama","machine-learning","mistral","nlp","transformer"],"created_at":"2024-09-27T07:04:05.092Z","updated_at":"2025-04-05T01:05:20.804Z","avatar_url":"https://github.com/HMUNACHI.png","language":"Python","funding_links":["https://github.com/sponsors/HMUNACHI"],"categories":[],"sub_categories":[],"readme":"\u003cp align=\"center\"\u003e\n  \u003cimg src=\"assets/logo.jpg\" alt=\"Alt text\"/\u003e\n\u003c/p\u003e\n\n# A Jax-based library for designing and training transformer models from scratch.\n\n![License](https://img.shields.io/github/license/hmunachi/nanodl?style=flat-square) [![Read the Docs](https://img.shields.io/readthedocs/nanodl?labelColor=blue\u0026color=white)](https://nanodl.readthedocs.io) [![Discord](https://img.shields.io/discord/1222217369816928286?style=social\u0026logo=discord\u0026label=Discord\u0026color=white)](https://discord.gg/3u9vumJEmz) 
[![LinkedIn](https://img.shields.io/badge/-LinkedIn-blue?style=flat-square\u0026logo=linkedin\u0026logoColor=white)](https://www.linkedin.com//company/80434055) [![Twitter](https://img.shields.io/twitter/follow/hmunachii?style=social)](https://twitter.com/hmunachii)\n\nAuthor: [Henry Ndubuaku](https://www.linkedin.com/in/henry-ndubuaku-7b6350b8/) (Discord \u0026 Docs badges are clickable)\n\nN/B: The code is implemented pedagogically at the expense of repetition. \nEach model is purposefully contained in a file without inter-file dependencies. \n\n## Overview\nDeveloping and training transformer-based models is typically resource-intensive and time-consuming, and AI/ML experts frequently need to build smaller-scale versions of these models for specific problems. Jax, a low-resource yet powerful framework, accelerates the development of neural networks and abstracts distributed training, but existing resources for transformer development in Jax are limited. NanoDL addresses this challenge with the following features:\n\n- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.\n- An extensive selection of models like Gemma, LlaMa3, Mistral, GPT3, GPT4 (inferred), T5, Whisper, ViT, Mixers, CLIP etc.\n- Data-parallel distributed trainers for training models on multiple GPUs or TPUs, without the need for manual training loops.\n- Dataloaders, making the process of data handling for Jax/Flax more straightforward and effective.\n- Layers not found in Flax/Jax, such as RoPE, GQA, MQA, and SWin attention, allowing for more flexible model development.\n- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc. 
\n- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc.\n- Each model is contained in a single file with no external dependencies, so the source code can also be easily used. \n- True random number generators in Jax that do not need verbose code (examples shown in the next sections).\n\nThere are experimental and/or unfinished features (like MAMBA, KAN, BitNet, GAT and RLHF) \nin the repo which are not yet available via the package, but can be copied from this repo.\nFeedback on any of our discussion, issue and pull request threads is welcome! \nPlease report any feature requests, issues, questions or concerns in the [Discord](https://discord.gg/3u9vumJEmz), \nor just let us know what you're working on!\n\n## Quick install\n\nYou will need Python 3.9 or later, and working [JAX](https://github.com/google/jax/blob/main/README.md),\n[FLAX](https://github.com/google/flax/blob/main/README.md)\nand [OPTAX](https://github.com/google-deepmind/optax/blob/main/README.md)\ninstallations (GPU support is needed to run training; without it, only model creation is supported).\nModels can be designed and tested on CPUs, but the trainers are all Distributed Data-Parallel \nand require 1 to N GPUs/TPUs. 
For CPU-only version of JAX:\n\n```\npip install --upgrade pip # To support manylinux2010 wheels.\npip install jax flax optax\n```\n\nThen, install nanodl from PyPi:\n\n```\npip install nanodl\n```\n\n## What does nanodl look like?\n\nWe provide various example usages of the nanodl API.\n\n```py\nimport jax\nimport nanodl\nimport jax.numpy as jnp\nfrom nanodl import ArrayDataset, DataLoader\nfrom nanodl import GPT4, GPTDataParallelTrainer\n\n# Preparing your dataset\nbatch_size = 8\nmax_length = 50\nvocab_size = 1000\n\n# Create random data\ndata = nanodl.uniform(\n    shape=(batch_size, max_length), \n    minval=0, maxval=vocab_size-1\n    ).astype(jnp.int32)\n\n# Shift to create next-token prediction dataset\ndummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]\n\n# Create dataset and dataloader\ndataset = ArrayDataset(dummy_inputs, dummy_targets)\ndataloader = DataLoader(\n    dataset, batch_size=batch_size, shuffle=True, drop_last=False\n    )\n\n# model parameters\nhyperparams = {\n    'num_layers': 1,\n    'hidden_dim': 256,\n    'num_heads': 2,\n    'feedforward_dim': 256,\n    'dropout': 0.1,\n    'vocab_size': vocab_size,\n    'embed_dim': 256,\n    'max_length': max_length,\n    'start_token': 0,\n    'end_token': 50,\n}\n\n# Inferred GPT4 model \nmodel = GPT4(**hyperparams)\n\ntrainer = GPTDataParallelTrainer(\n    model, dummy_inputs.shape, 'params.pkl'\n    )\n\ntrainer.train(\n    train_loader=dataloader, num_epochs=100, val_loader=dataloader\n    ) # use actual val data\n\n# Generating from a start token\nstart_tokens = jnp.array([[123, 456]])\n\n# Remember to load the trained parameters \nparams = trainer.load_params('params.pkl')\n\noutputs = model.apply(\n    {'params': params}, \n    start_tokens,\n    rngs={'dropout': nanodl.time_rng_key()}, \n    method=model.generate\n    )\n```\n\nVision example\n\n```py\nimport nanodl\nimport jax.numpy as jnp\nfrom nanodl import ArrayDataset, DataLoader\nfrom nanodl import DiffusionModel, 
DiffusionDataParallelTrainer\n\nimage_size = 32\nblock_depth = 2\nbatch_size = 8\nwidths = [32, 64, 128]\ninput_shape = (101, image_size, image_size, 3)\nimages = nanodl.normal(shape=input_shape)\n\n# Use your own images\ndataset = ArrayDataset(images) \ndataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False) \n\n# Create diffusion model\ndiffusion_model = DiffusionModel(image_size, widths, block_depth)\n\n# Training on your data\ntrainer = DiffusionDataParallelTrainer(diffusion_model, \n                                       input_shape=images.shape, \n                                       weights_filename='params.pkl', \n                                       learning_rate=1e-4)\n\ntrainer.train(dataloader, 10)\n\n# Generate some samples: Each model is a Flax.linen module\n# Use as you normally would\nparams = trainer.load_params('params.pkl')\ngenerated_images = diffusion_model.apply({'params': params}, \n                                         num_images=5, \n                                         diffusion_steps=5, \n                                         method=diffusion_model.generate)\n```\n\nAudio example\n\n```py\nimport jax\nimport jax.numpy as jnp\nfrom nanodl import ArrayDataset, DataLoader\nfrom nanodl import Whisper, WhisperDataParallelTrainer\n\n# Dummy data parameters\nbatch_size = 8\nmax_length = 50\nembed_dim = 256 \nvocab_size = 1000 \n\n# Generate data: replace with actual tokenised/quantised data\ndummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)\ndummy_inputs = jnp.ones((101, max_length, embed_dim))\n\ndataset = ArrayDataset(dummy_inputs, dummy_targets)\ndataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)\n\n# model parameters\nhyperparams = {\n    'num_layers': 1,\n    'hidden_dim': 256,\n    'num_heads': 2,\n    'feedforward_dim': 256,\n    'dropout': 0.1,\n    'vocab_size': 1000,\n    'embed_dim': embed_dim,\n    'max_length': max_length,\n    'start_token': 
0,\n    'end_token': 50,\n}\n\n# Initialize model\nmodel = Whisper(**hyperparams)\n\n# Training on your data\ntrainer = WhisperDataParallelTrainer(model, \n                                     dummy_inputs.shape, \n                                     dummy_targets.shape, \n                                     'params.pkl')\n\ntrainer.train(dataloader, 2, dataloader)\n\n# Sample inference\nparams = trainer.load_params('params.pkl')\n\n# for more than one sample, often use model.generate_batch\ntranscripts = model.apply({'params': params}, \n                          dummy_inputs[:1],\n                          method=model.generate)\n```\n\nReward Model example for RLHF\n\n```py\nimport nanodl\nimport jax.numpy as jnp\nfrom nanodl import ArrayDataset, DataLoader\nfrom nanodl import Mistral, RewardModel, RewardDataParallelTrainer\n\n# Generate dummy data\nbatch_size = 8\nmax_length = 10\n\n# Replace with actual tokenised data\ndummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)\ndummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)\n\n# Create dataset and dataloader\ndataset = ArrayDataset(dummy_chosen, dummy_rejected)\ndataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)\n\n # model parameters\nhyperparams = {\n    'num_layers': 1,\n    'hidden_dim': 256,\n    'num_heads': 2,\n    'feedforward_dim': 256,\n    'dropout': 0.1,\n    'vocab_size': 1000,\n    'embed_dim': 256,\n    'max_length': max_length,\n    'start_token': 0,\n    'end_token': 50,\n    'num_groups': 2,\n    'window_size': 5,\n    'shift_size': 2\n}\n\n# Initialize reward model from Mistral\nmodel = Mistral(**hyperparams)\nreward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)\n\n# Train the reward model\ntrainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')\ntrainer.train(dataloader, 5, dataloader)\nparams = trainer.load_params('reward_model_weights.pkl')\n\n# Call as you would a 
regular Flax model\nrewards = reward_model.apply({'params': params}, \n                    dummy_chosen, \n                    rngs={'dropout': nanodl.time_rng_key()})\n```\n\nPCA example\n\n```py\nimport nanodl\nfrom nanodl import PCA\n\n# Use actual data\ndata = nanodl.normal(shape=(1000, 10))\n\n# Initialise and train PCA model\npca = PCA(n_components=2)\npca.fit(data)\n\n# Get PCA transforms\ntransformed_data = pca.transform(data)\n\n# Get reverse transforms\noriginal_data = pca.inverse_transform(transformed_data)\n\n# Sample from the distribution\nX_sampled = pca.sample(n_samples=1000, key=None)\n```\n\nThis library is still in development; it works well, but rough edges are expected, and contributions are therefore highly encouraged! \n\n- Make your changes without changing the design patterns.\n- Write tests for your changes if necessary.\n- Install locally with `pip3 install -e .`.\n- Run tests with `python3 -m unittest discover -s tests`.\n- Then submit a pull request.\n\nContributions can be made in various forms:\n\n- Writing documentation.\n- Fixing bugs.\n- Implementing papers.\n- Writing high-coverage tests.\n- Optimizing existing code.\n- Experimenting and submitting real-world examples to the examples section.\n- Reporting bugs.\n- Responding to reported issues.\n\nJoin the [Discord Server](https://discord.gg/3u9vumJEmz) for more.\n\n## Sponsorships\n\nThe name \"NanoDL\" stands for Nano Deep Learning. Models are exploding in size, which keeps \nexperts and companies with limited resources from building flexible models without prohibitive costs.\nFollowing the success of Phi models, the long-term goal is to build and train nano versions of all available models,\nwhile ensuring they compete with the original models in performance, with the total \nnumber of parameters not exceeding 1B. 
Trained weights will be made available via this library.\nAny form of sponsorship or funding will help with training resources.\nYou can either sponsor via GitHub [here](https://github.com/sponsors/HMUNACHI) or reach out via ndubuakuhenry@gmail.com.\n\n## Citing nanodl\n\nTo cite this repository:\n\n```\n@software{nanodl2024github,\n  author = {Henry Ndubuaku},\n  title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.},\n  url = {http://github.com/hmunachi/nanodl},\n  year = {2024},\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fhmunachi%2Fnanodl","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fhmunachi%2Fnanodl","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fhmunachi%2Fnanodl/lists"}