{"id":13687547,"url":"https://github.com/GFNOrg/torchgfn","last_synced_at":"2025-05-01T12:34:26.676Z","repository":{"id":62614192,"uuid":"508766972","full_name":"GFNOrg/torchgfn","owner":"GFNOrg","description":"A modular, easy to extend GFlowNet library","archived":false,"fork":false,"pushed_at":"2024-11-05T14:08:45.000Z","size":6592,"stargazers_count":231,"open_issues_count":51,"forks_count":29,"subscribers_count":10,"default_branch":"master","last_synced_at":"2024-11-05T15:24:42.095Z","etag":null,"topics":["amortized-sampling","gflownet","gflownets","gfn","library","pytorch"],"latest_commit_sha":null,"homepage":"https://torchgfn.readthedocs.io/en/latest/","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/GFNOrg.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-06-29T16:26:43.000Z","updated_at":"2024-11-05T14:50:14.000Z","dependencies_parsed_at":"2023-11-07T09:27:31.712Z","dependency_job_id":"5a840786-c7b2-4e18-9bd9-310ac3a8bb41","html_url":"https://github.com/GFNOrg/torchgfn","commit_stats":null,"previous_names":["saleml/torchgfn","gfnorg/torchgfn"],"tags_count":7,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/GFNOrg%2Ftorchgfn","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/GFNOrg%2Ftorchgfn/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/GFNOrg%2Ftorchgfn/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/GFNOrg%2Ftorchgfn/manifests"
,"owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/GFNOrg","download_url":"https://codeload.github.com/GFNOrg/torchgfn/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":224257776,"owners_count":17281772,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["amortized-sampling","gflownet","gflownets","gfn","library","pytorch"],"created_at":"2024-08-02T15:00:56.394Z","updated_at":"2025-05-01T12:34:26.666Z","avatar_url":"https://github.com/GFNOrg.png","language":"Jupyter Notebook","funding_links":[],"categories":["Jupyter Notebook"],"sub_categories":[],"readme":"\u003cp align=\"center\"\u003e\n    \u003ca\u003e\n\t    \u003cimg src='https://img.shields.io/badge/python-3.10%2B-blueviolet' alt='Python' /\u003e\n\t\u003c/a\u003e\n\t\u003ca href='https://torchgfn.readthedocs.io/en/latest/?badge=latest'\u003e\n    \t\u003cimg src='https://readthedocs.org/projects/torchgfn/badge/?version=latest' alt='Documentation Status' /\u003e\n\t\u003c/a\u003e\n    \u003ca\u003e\n\t    \u003cimg src='https://img.shields.io/badge/code%20style-black-black' /\u003e\n\t\u003c/a\u003e\n\u003c/p\u003e\n\n\u003c/p\u003e\n\u003cp align=\"center\"\u003e\n  \u003ca href=\"https://torchgfn.readthedocs.io/en/latest/\"\u003eDocumentation\u003c/a\u003e ~ \u003ca href=\"https://github.com/saleml/torchgfn\"\u003eCode\u003c/a\u003e ~ \u003ca href=\"https://arxiv.org/abs/2305.14594\"\u003ePaper\u003c/a\u003e\n\u003c/p\u003e\n\n# torchgfn: a Python package for GFlowNets\n\n\u003cp align=\"center\"\u003e Please cite \u003ca 
href=\"https://arxiv.org/abs/2305.14594\"\u003ethis paper\u003c/a\u003e if you are using the library for your research \u003c/p\u003e\n\n## Installing the package\n\nThe codebase requires python \u003e= 3.10. To install the latest stable version:\n\n```bash\npip install torchgfn\n```\n\nOptionally, to run scripts:\n\n```bash\npip install torchgfn[scripts]\n```\n\nTo install the cutting edge version (from the `main` branch):\n\n```bash\ngit clone https://github.com/GFNOrg/torchgfn.git\nconda create -n gfn python=3.10\nconda activate gfn\ncd torchgfn\npip install -e \".[all]\"\n```\n\n## Installing `oneccl` bindings for multinode training.\n\nYou can determine the version of `pytorch` installed using the command\n\n```\necho $(python -c $\"import torch; print(torch.__version__)\")\n```\n\nafter which you can install the closest matching version from [this table](https://github.com/intel/torch-ccl?tab=readme-ov-file#install-prebuilt-wheel) (otherwise, you must build from source). You can see the specific wheels [here](https://pytorch-extension.intel.com/release-whl/stable/cpu/us/oneccl-bind-pt/).\n\n```\npip install oneccl_bind_pt=={pytorch_version} -f https://pytorch-extension.intel.com/release-whl/stable/cpu/us/\n```\n\nfor example, if your pytorch version is `2.0.1+cu117`, you would run `python -m pip install oneccl_bind_pt==2.0.0+cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/`.\n\n\n***TODO: Rough instructions - to integrate into docs (just moving them here from email) -\n```\n# Create \u0026 activate conda env.\nconda create -n gfn python=3.10\nconda activate gfn\n\n# Install the package.\npip install .[scripts]  # Includes `tqdm`.\n\n# We will use torch-ccl library for multinode implementation. The latest torch-ccl is compatible with PyTorch 2.2.0. The above command installs the latest torch. So, we need to uninstall it and install latest torch. 
If you agree that we can make it the default version, I can update it in pyproject.toml.\npip uninstall torch -y\nconda install pytorch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 cpuonly -c pytorch\n\n# Install torch-ccl\ngit clone https://github.com/intel/torch-ccl.git torch-ccl \u0026\u0026 cd torch-ccl\ngit checkout tags/v2.2.0+cpu\ngit submodule sync\ngit submodule update --init --recursive\n\n# TODO: this didn't work for me -- I had to use a prebuilt wheel.\n# ONECCL_BINDINGS_FOR_PYTORCH_BACKEND=cpu python setup.py install\npython -m pip install oneccl_bind_pt --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/\n\n# Installation is complete now.\n\nYou can submit a job by modifying one of the slurm scripts and submitting. For example, ddp_gfn.small.8.slurm. Please note that you need to modify the conda env name in the slurm script to the name of your env. Also, change the paths and dimensions if needed. I submit the script using the following command:\n\nsbatch ddp_gfn.small.4.mila.slurm\n```\n\n\n## About this repo\n\nThis repo serves the purpose of fast prototyping [GFlowNet](https://arxiv.org/abs/2111.09266) (GFN) related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to calculate the GFN loss. It aims to accompany researchers and engineers in learning about GFlowNets, and in developing new algorithms.\n\nCurrently, the library is shipped with three environments: two discrete environments (Discrete Energy Based Model and Hyper Grid) and a continuous box environment. The library is designed to allow users to define their own environments. See [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md) for more details.\n\n### Scripts and notebooks\n\nExample scripts and notebooks for the three environments are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/examples). 
For the hyper grid and the box environments, the provided scripts are intended to reproduce published results.\n\n\n### Standalone example\n\nThis example, which shows how to use the library for a simple discrete environment, requires the [`tqdm`](https://github.com/tqdm/tqdm) package to run. Use `pip install tqdm` or install all extra requirements with `pip install .[scripts]` or `pip install torchgfn[scripts]`. In the first example, we will train a Trajectory Balance GFlowNet:\n\n```python\nimport torch\nfrom tqdm import tqdm\n\nfrom gfn.gflownet import TBGFlowNet\nfrom gfn.gym import HyperGrid  # We use the hyper grid environment\nfrom gfn.preprocessors import KHotPreprocessor\nfrom gfn.modules import DiscretePolicyEstimator\nfrom gfn.samplers import Sampler\nfrom gfn.utils.modules import MLP  # MLP is a simple multi-layer perceptron (MLP)\n\n# 1 - We define the environment.\nenv = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8\npreprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height)\n\n# 2 - We define the needed modules (neural networks).\nmodule_PF = MLP(\n    input_dim=preprocessor.output_dim,\n    output_dim=env.n_actions\n)  # Neural network for the forward policy, with as many outputs as there are actions\n\nmodule_PB = MLP(\n    input_dim=preprocessor.output_dim,\n    output_dim=env.n_actions - 1,\n    trunk=module_PF.trunk  # We share all the parameters of P_F and P_B, except for the last layer\n)\n\n# 3 - We define the estimators.\npf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=preprocessor)\npb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=preprocessor)\n\n# 4 - We define the GFlowNet.\ngfn = TBGFlowNet(logZ=0., pf=pf_estimator, pb=pb_estimator)  # We initialize logZ to 0\n\n# 5 - We define the sampler and the optimizer.\nsampler = Sampler(estimator=pf_estimator)  # We use an on-policy sampler, based on the forward policy\n\n# Different 
policy parameters can have their own LR.\n# Log Z gets dedicated learning rate (typically higher).\noptimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)\noptimizer.add_param_group({\"params\": gfn.logz_parameters(), \"lr\": 1e-1})\n\n# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration\nfor i in (pbar := tqdm(range(1000))):\n    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=True)  # The save_logprobs=True makes on-policy training faster\n    optimizer.zero_grad()\n    loss = gfn.loss(env, trajectories)\n    loss.backward()\n    optimizer.step()\n    if i % 25 == 0:\n        pbar.set_postfix({\"loss\": loss.item()})\n```\n\nand in this example, we instead train using Sub Trajectory Balance. You can see we simply assemble our GFlowNet from slightly different building blocks:\n\n```python\nimport torch\nfrom tqdm import tqdm\n\nfrom gfn.gflownet import SubTBGFlowNet\nfrom gfn.gym import HyperGrid  # We use the hyper grid environment\nfrom gfn.preprocessors import KHotPreprocessor\nfrom gfn.modules import DiscretePolicyEstimator, ScalarEstimator\nfrom gfn.samplers import Sampler\nfrom gfn.utils.modules import MLP  # MLP is a simple multi-layer perceptron (MLP)\n\n# 1 - We define the environment.\nenv = HyperGrid(ndim=4, height=8, R0=0.01)  # Grid of size 8x8x8x8\npreprocessor = KHotPreprocessor(ndim=env.ndim, height=env.height)\n\n# 2 - We define the needed modules (neural networks).\n# The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator\nmodule_PF = MLP(\n    input_dim=preprocessor.output_dim,\n    output_dim=env.n_actions\n)  # Neural network for the forward policy, with as many outputs as there are actions\n\nmodule_PB = MLP(\n    input_dim=preprocessor.output_dim,\n    output_dim=env.n_actions - 1,\n    trunk=module_PF.trunk  # We share all the parameters of P_F and P_B, except for the last layer\n)\nmodule_logF = MLP(\n    
input_dim=preprocessor.output_dim,\n    output_dim=1,  # Important for ScalarEstimators!\n)\n\n# 3 - We define the estimators.\npf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=preprocessor)\npb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=preprocessor)\nlogF_estimator = ScalarEstimator(module=module_logF, preprocessor=env.preprocessor)\n\n# 4 - We define the GFlowNet.\ngfn = SubTBGFlowNet(pf=pf_estimator, pb=pb_estimator, logF=logF_estimator, lamda=0.9)\n\n# 5 - We define the sampler and the optimizer.\nsampler = Sampler(estimator=pf_estimator)\n\n# Different policy parameters can have their own LR.\n# Log F gets dedicated learning rate (typically higher).\noptimizer = torch.optim.Adam(gfn.pf_pb_parameters(), lr=1e-3)\noptimizer.add_param_group({\"params\": gfn.logF_parameters(), \"lr\": 1e-2})\n\n# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration\nfor i in (pbar := tqdm(range(1000))):\n    # We are going to sample trajectories off policy, by tempering the distribution.\n    # We should not save the sampling logprobs, as we are not using them for training.\n    # We should save the estimator outputs to make training faster.\n    trajectories = sampler.sample_trajectories(env=env, n=16, save_logprobs=False, save_estimator_outputs=True, temperature=1.5)\n    optimizer.zero_grad()\n    loss = gfn.loss(env, trajectories)\n    loss.backward()\n    optimizer.step()\n    if i % 25 == 0:\n        pbar.set_postfix({\"loss\": loss.item()})\n\n```\n\n## Contributing\n\nBefore the first commit:\n\n```bash\npip install -e .[dev,scripts]\npre-commit install\npre-commit run --all-files\n```\n\nRun `pre-commit` after staging, and before committing. Make sure all the tests pass (By running `pytest`). Note that the `pytest` hook of `pre-commit` only runs the tests in the `testing/` folder. 
To run all the tests, which take longer, run `pytest` manually.\n\nThe codebase uses:\n- `black` formatter for code style\n- `flake8` for linting\n- `pyright` for static type checking\n\nThe pre-commit hooks ensure code quality and type safety across the project. The pyright configuration includes all project directories including tutorials/examples and testing.\n\nTo make the docs locally:\n\n```bash\ncd docs\nmake html\nopen build/html/index.html\n```\n\n## Details about the codebase\n\n### Defining an environment\n\nSee [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md)\n\n### States\n\nStates are the primitive building blocks for GFlowNet objects such as transitions and trajectories, on which losses operate.\n\nAn abstract `States` class is provided. But for each environment, a `States` subclass is needed. A `States` object\nis a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching. If a state is represented with a tensor of shape `(*state_shape)`, a batch of states is represented with a `States` object, with the attribute `tensor` of shape `(*batch_shape, *state_shape)`. Other\nrepresentations are possible (e.g. a state as a string, a `numpy` array, a graph, etc...), but these representations cannot be batched, unless the user specifies a function that transforms these raw states to tensors.\n\nThe `batch_shape` attribute is required to keep track of the batch dimension. A trajectory can be represented by a States object with `batch_shape = (n_states,)`. Multiple trajectories can be represented by a States object with `batch_shape = (n_states, n_trajectories)`.\n\nBecause multiple trajectories can have different lengths, batching requires appending a dummy tensor to trajectories that are shorter than the longest trajectory. The dummy state is the $s_f$ attribute of the environment (e.g. `[-1, ..., -1]`, or `[-inf, ..., -inf]`, etc...). 
It is never processed and is used only to pad the batch of states.\n\nFor discrete environments, the action set is represented with the set $\\{0, \\dots, n_{actions} - 1\\}$, where the $(n_{actions})$-th action always corresponds to the exit or terminate action, i.e. the action that results in a transition of the type $s \\rightarrow s_f$. Not all actions are possible at all states. For discrete environments, each `States` object is endowed with two extra attributes: `forward_masks` and `backward_masks`, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the `DiscreteStates` abstract subclass of `States`. The `forward_masks` tensor is of shape `(*batch_shape, n_{actions})`, and `backward_masks` is of shape `(*batch_shape, n_{actions} - 1)`. Each subclass of `DiscreteStates` needs to implement the `update_masks` function, which uses the environment's logic to define the two tensors.\n\n### Actions\nActions should be thought of as internal actions of an agent building a compositional object. They correspond to transitions $s \\rightarrow s'$. An abstract `Actions` class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise.\n\nSimilar to `States` objects, each action is a tensor of shape `(*batch_shape, *action_shape)`. For discrete environments, for instance, `action_shape = (1,)`, representing an integer between $0$ and $n_{actions} - 1$.\n\nAdditionally, each subclass needs to define two more class variable tensors:\n- `dummy_action`: A tensor that is padded to sequences of actions in the shorter trajectories of a batch of trajectories. It is `[-1]` for discrete environments.\n- `exit_action`: A tensor that corresponds to the termination action. 
It is `[n_{actions} - 1]` for discrete environments.\n\n### Containers\n\nContainers are collections of `States`, along with other information, such as reward values, or densities $p(s' \\mid s)$. Three containers are available:\n\n- [Transitions](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/transitions.py), representing a batch of transitions $s \\rightarrow s'$.\n- [Trajectories](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/trajectories.py), representing a batch of complete trajectories $\\tau = s_0 \\rightarrow s_1 \\rightarrow \\dots \\rightarrow s_n \\rightarrow s_f$.\n- [StatePairs](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/state_pairs.py), representing pairs of states with optional conditioning, particularly useful for flow matching algorithms.\n\nThese containers can either be instantiated using a `States` object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of the [ReplayBuffer](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/replay_buffer.py) class.\n\nThey inherit from the base `Container` [class](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/base.py), indicating some helpful methods.\n\nIn most cases, one needs to sample complete trajectories. From a batch of trajectories, various training samples can be generated:\n- Use `Trajectories.to_transitions()` and `Trajectories.to_states()` for edge-decomposable or state-decomposable losses\n- Use `Trajectories.to_state_pairs()` for flow matching losses\n- Use `GFlowNet.loss_from_trajectories()` as a convenience method that handles the conversion internally\n\nThese methods exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching.\n\n### Modules\n\nTraining GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. 
In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function.\n\n- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \\mid s)$ and $P_B(. \\mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then returns a Categorical distribution. The masking is done by setting the corresponding logit to $-\\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. `DiscretePolicyEstimator` with `is_backward=False` can be used to represent log-edge-flow estimators $\\log F(s \\rightarrow s')$.\n- `ScalarEstimator` is a simple module with required output dimension 1. It is useful to define log-state flows $\\log F(s)$.\n\nFor non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States` object), should return the batched parameters of a `torch.distributions.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, that implements at least the `sample` and `log_prob` functions. 
An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk, or on an arc-circle, such that the angle, and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and it is not expected that user-defined environments need this level of detail.\n\nIn general, (and perhaps obviously) the `to_probability_distribution` method is used to calculate a probability distribution from a policy. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. One accomplishes this using `policy_kwargs`, a `dict` of kwarg-value pairs which are used by the `Estimator` when calculating the new policy. In the discrete case, where common settings apply, one can see their use in `DiscretePolicyEstimator`'s `to_probability_distribution` method by passing a softmax `temperature`, `sf_bias` (a scalar to subtract from the exit action logit) or `epsilon` which allows for e-greedy style exploration. In the continuous case, it is not possible to foresee the methods used for off-policy exploration (as it depends on the details of the `to_probability_distribution` method, which is not generic for continuous GFNs), so this must be handled by the user, using custom `policy_kwargs`.\n\nIn all `GFNModule`s, note that the input of the `forward` function is a `States` object, meaning that the states first need to be transformed to tensors. However, `states.tensor` does not necessarily include the structure that a neural network can use to generalize. It is common in these scenarios to have a function that transforms these raw tensor states to ones where the structure is clearer, via a `Preprocessor` object, that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). 
The `forward` pass thus first calls the `preprocessor` attribute of the environment on `States`, before performing any transformation. The `preprocessor` is thus an attribute of the module. If it is not explicitly defined, it is set to the identity preprocessor.\n\nFor discrete environments, a `Tabular` module is provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. These modules are provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py).\n\n### Samplers\n\nA [Sampler](https://github.com/saleml/torchgfn/tree/master/src/gfn/samplers.py) object defines how actions are sampled (`sample_actions()`) at each state, and trajectories  (`sample_trajectories()`), which can sample a batch of trajectories starting from a given set of initial states or starting from $s_0$. It requires a `GFNModule` that implements the `to_probability_distribution` function. For simple off-policy sampling (e.g., epsilon-noisy or tempering), you can pass appropriate `policy_kwargs` to the `Sampler` object, which will be used by the `GFNModule`. If you need more complex off-policy sampling, you can subclass the `Sampler` object, and override the `sample_actions` and `sample_trajectories` methods.\n\nCurrently, the library provides two samplers:\n\n- Sampler\n- LocalSearchSampler (references: [EB-GFN](https://arxiv.org/abs/2202.01361), [LS-GFN](https://arxiv.org/abs/2310.02710))\n\n\n### Losses\n\nGFlowNets can be trained with different losses, each of which requires a different parametrization, which we call in this library a `GFlowNet`. A `GFlowNet` includes one or multiple `GFNModule`s, at least one of which implements a `to_probability_distribution` function. 
They also need to implement a `loss` function, that takes as input either states, transitions, or trajectories, depending on the loss.\n\nCurrently, the implemented losses are:\n\n- Flow Matching\n- Detailed Balance (and its modified variant).\n- Trajectory Balance\n- Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py).\n- Log Partition Variance loss. Introduced [here](https://arxiv.org/abs/2302.05446)\n\n### Extending GFlowNets\n\nTo define a new `GFlowNet`, the user needs to define a class which subclasses `GFlowNet` and implements the following methods:\n\n- `sample_trajectories`: Sample a specific number of complete trajectories.\n- `loss`: Compute the loss given the training objects.\n- `to_training_samples`: Convert trajectories to training samples.\n\nBased on the type of training samples returned by `to_training_samples`, the user should define the generic type `TrainingSampleType` when subclassing `GFlowNet`. For example, if the training sample is an instance of `Trajectories`, the `GFlowNet` class should be subclassed as `GFlowNet[Trajectories]`. Thus, the class definition should look like this:\n\n```python\nclass MyGFlowNet(GFlowNet[Trajectories]):\n    ...\n```\n\n**Example: Flow Matching GFlowNet**\n\nLet's consider the example of the `FMGFlowNet` class, which is a subclass of `GFlowNet` that implements the Flow Matching GFlowNet. 
The training samples are pairs of states managed by the `StatePairs` container:\n\n```python\nclass FMGFlowNet(GFlowNet[StatePairs[DiscreteStates]]):\n    ...\n\n    def to_training_samples(\n        self, trajectories: Trajectories\n    ) -\u003e StatePairs[DiscreteStates]:\n        \"\"\"Converts a batch of trajectories into a batch of training samples.\"\"\"\n        return trajectories.to_state_pairs()\n```\n\nThis means that the `loss` method of `FMGFlowNet` will receive a `StatePairs[DiscreteStates]` object as its training samples argument:\n\n```python\ndef loss(self, env: DiscreteEnv, states: StatePairs[DiscreteStates]) -\u003e torch.Tensor:\n    ...\n```\n\n**Adding New Training Sample Types**\n\nIf your GFlowNet returns a unique type of training samples, you'll need to expand the `TrainingSampleType` bound. This ensures type-safety and better code clarity.\n\n**Implementing Class Methods**\n\nAs mentioned earlier, your new GFlowNet must implement the following methods:\n\n- `sample_trajectories`: Sample a specific number of complete trajectories.\n- `loss`: Compute the loss given the training objects.\n- `to_training_samples`: Convert trajectories to training samples.\n\nThese methods are defined in `src/gfn/gflownet/base.py` and are abstract methods, so they must be implemented in your new GFlowNet. If your GFlowNet has unique functionality which should be represented as additional class methods, implement them as required. Remember to document new methods to ensure other developers understand their purposes and use-cases!\n\n**Testing**\n\nRemember to create unit tests for your new GFlowNet to ensure it works as intended and integrates seamlessly with other parts of the codebase. This ensures maintainability and reliability of the code!\n\n\n## Training Examples\n\nThe repository includes several example environments and training scripts. 
Below are three different implementations of training on the HyperGrid environment, which serve as good starting points for understanding GFlowNets:\n\n1. `tutorials/examples/train_hypergrid.py`: The main training script with full features:\n   - Multiple loss functions (FM, TB, DB, SubTB, ZVar, ModifiedDB)\n   - Weights \u0026 Biases integration for experiment tracking\n   - Support for replay buffers (including prioritized)\n   - Visualization capabilities for 2D environments:\n     * True probability distribution\n     * Learned probability distribution\n     * L1 distance evolution over training\n   - Various hyperparameter options\n   - Reproduces results from multiple papers (see script docstring)\n\n2. `tutorials/examples/train_hypergrid_simple.py`: A simplified version focused on core concepts:\n   - Uses only Trajectory Balance (TB) loss\n   - Minimal architecture with shared trunks\n   - No extra features (no replay buffer, no wandb)\n   - Great starting point for understanding GFlowNets\n\n3. 
`tutorials/examples/train_hypergrid_simple_ls.py`: Demonstrates advanced sampling strategies:\n   - Implements local search sampling\n   - Configurable local search parameters\n   - Optional Metropolis-Hastings acceptance criterion\n   - Shows how to extend basic GFlowNet training with sophisticated sampling\n\nOther environments available in the package include:\n- Discrete Energy Based Model: A simple environment for learning energy-based distributions\n- Box Environment: A continuous environment for sampling from distributions in bounded spaces\n- Custom environments can be added by following the environment creation guide in `tutorials/ENV.md`\n\n## Usage Examples\n\nTo train with Weights \u0026 Biases tracking:\n```bash\npython tutorials/examples/train_hypergrid.py --ndim 4 --height 8 --wandb_project your_project_name\n```\n\nTo train with visualization (2D environments only):\n```bash\npython tutorials/examples/train_hypergrid.py --ndim 2 --height 8 --plot\n```\n\nTo try the simple version with epsilon-greedy exploration:\n```bash\npython tutorials/examples/train_hypergrid_simple.py --ndim 2 --height 8 --epsilon 0.1\n```\n\nTo experiment with local search:\n```bash\npython tutorials/examples/train_hypergrid_simple_ls.py --ndim 2 --height 8 --n_local_search_loops 2 --back_ratio 0.5 --use_metropolis_hastings\n```\n\nFor more options and configurations, check the help of each script:\n```bash\npython tutorials/examples/train_hypergrid.py --help\npython tutorials/examples/train_hypergrid_simple.py --help\npython tutorials/examples/train_hypergrid_simple_ls.py --help\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FGFNOrg%2Ftorchgfn","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FGFNOrg%2Ftorchgfn","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FGFNOrg%2Ftorchgfn/lists"}