{"id":27267848,"url":"https://github.com/lamm-mit/Graph-Aware-Transformers","last_synced_at":"2025-04-11T10:02:28.091Z","repository":{"id":271031257,"uuid":"911611090","full_name":"lamm-mit/Graph-Aware-Transformers","owner":"lamm-mit","description":"Graph-Aware Attention for Adaptive Dynamics in Transformers","archived":false,"fork":false,"pushed_at":"2025-01-08T00:13:24.000Z","size":149,"stargazers_count":55,"open_issues_count":2,"forks_count":5,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-04-05T12:03:25.817Z","etag":null,"topics":["ai4science","attention-mechanism","graph","graph-aware","huggingface-transformers","language","llm-training","llms","materials-informatics","materials-science","materiomics"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lamm-mit.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2025-01-03T12:29:24.000Z","updated_at":"2025-03-18T02:49:34.000Z","dependencies_parsed_at":"2025-01-04T23:25:00.032Z","dependency_job_id":"f8a332e5-df9d-45ba-86e7-b3f2a918b2a9","html_url":"https://github.com/lamm-mit/Graph-Aware-Transformers","commit_stats":null,"previous_names":["lamm-mit/graph-aware-transformers"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lamm-mit%2FGraph-Aware-Transformers","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lamm-mit%2FGraph-Aware-Transformers/tags","releases_url":"https://repos.ecosyste.ms/ap
i/v1/hosts/GitHub/repositories/lamm-mit%2FGraph-Aware-Transformers/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lamm-mit%2FGraph-Aware-Transformers/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lamm-mit","download_url":"https://codeload.github.com/lamm-mit/Graph-Aware-Transformers/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248372368,"owners_count":21093134,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["ai4science","attention-mechanism","graph","graph-aware","huggingface-transformers","language","llm-training","llms","materials-informatics","materials-science","materiomics"],"created_at":"2025-04-11T10:01:31.141Z","updated_at":"2025-04-11T10:02:28.082Z","avatar_url":"https://github.com/lamm-mit.png","language":"Python","funding_links":[],"categories":["Representation Learning"],"sub_categories":[],"readme":"# Graph-Aware Isomorphic Attention for Adaptive Dynamics in Transformers\n\nWe present an approach to enhancing Transformer architectures by integrating graph-aware relational reasoning into their attention mechanisms. Building on the inherent connection between attention and graph theory, we reformulate the Transformer’s attention mechanism as a graph operation and propose Graph-Aware Isomorphic Attention. This method leverages advanced graph modeling strategies, including Graph Isomorphism Networks (GIN) and Principal Neighborhood Aggregation (PNA), to enrich the representation of relational structures. 
Our approach improves the model’s ability to capture complex dependencies and generalize across tasks, as evidenced by a reduced generalization gap and improved learning performance. \n\nAdditionally, we expand the concept of graph-aware attention to introduce Sparse GIN-Attention, a fine-tuning approach that employs sparse GINs. By interpreting attention matrices as sparse adjacency graphs, this technique enhances the adaptability of pre-trained foundation models with minimal computational overhead, endowing them with graph-aware capabilities. Across our experiments, graph-aware attention mechanisms outperform traditional attention in both training efficiency and validation performance. Furthermore, sparse GIN fine-tuning achieves improved training dynamics and better generalization than conventional methods such as LoRA. These insights not only bridge graph theory and Transformer architectures but also uncover latent graph-like structures within traditional attention mechanisms, offering a new lens through which Transformers can be understood and optimized. \n\nBy reformulating Transformers as hierarchical GIN models, we reveal their implicit capacity for graph-level relational reasoning. This perspective has significant implications for foundation model development, enabling the design of architectures that dynamically adapt to both local and global dependencies. Applications in bioinformatics, materials science, language modeling, and beyond could benefit from this synthesis of relational and sequential data modeling, setting the stage for interpretable and generalizable modeling strategies.\n\n![image](https://github.com/user-attachments/assets/02c9b587-73f0-4293-84f8-574bc2e9018c)\n\nFigure 1: Encoder-only transformer architecture (panel A), adapted here by replacing standard self-attention with a GNN-based self-attention mechanism. 
As another variant (panel B), suitable for fine-tuning a pre-trained model in a manner akin to LoRA, we introduce Sparse-GIN, which retains the adjacency matrix predicted by the pre-trained model but uses it to construct a sparse adjacency matrix.\n\n![image](https://github.com/user-attachments/assets/5c15d37d-c693-453d-822a-97a36d4c9b8b)\n\nFigure 2: Visualization of adjacency matrices and interpretation of corresponding causal graphs. Panel A: Visual representation of an adjacency matrix for one specific layer and one head, extracted from a pre-trained model. Panel B (left) shows a large-scale adjacency matrix, where interaction strengths are color-coded, with annotations highlighting specific points of interest. Panel B (right) displays the corresponding causal graph, illustrating directional relationships between nodes based on the adjacency matrix. These visualizations provide insights into the structural and causal relationships encoded in the adjacency matrices.\n\n## Installation\n\n#### Install PyTorch first\n\nIt is recommended to install PyTorch separately first (inside the conda environment created below) so that ```torch_scatter``` is installed correctly.\n\nFor platform-specific PyTorch installation options, see https://pytorch.org/get-started/locally/. For example:\n\n```bash\npip3 install torch torchvision torchaudio\n```\n\n#### Install directly from GitHub via pip\n\n```bash\nconda create -n xgpt_env python=3.11 -y\nconda activate xgpt_env\n\npip install git+https://github.com/lamm-mit/Graph-Aware-Transformers.git#egg=xgpt\n```\n\n#### Clone the repository and install as an editable library\n\n```bash\nconda create -n xgpt_env python=3.11 -y\nconda activate xgpt_env\n\ngit clone https://github.com/lamm-mit/Graph-Aware-Transformers.git\ncd Graph-Aware-Transformers\n\npip install -e .\n```\n\nAdditional details of the code and algorithms, including experimental features such as coarse-grained (CG) latent representations, can be found in [STRUCTURE.md](STRUCTURE.md). 
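
As a conceptual aid for the GIN-Attention idea in Figure 1 (panel A), the following is a minimal single-head sketch in NumPy. It is not the `xgpt` implementation. It assumes that the causal softmax attention matrix serves as a weighted, directed adjacency matrix `A`, and that the GIN-style update takes the form `MLP((1 + eps) * V + A @ V)` in place of the plain `A @ V` of standard attention. All function names here are hypothetical.

```python
import numpy as np


def causal_softmax_attention(q, k):
    # Scaled dot-product scores; row i holds token i's weights over tokens j.
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    # Causal mask: token i may only attend to tokens j <= i.
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable row-wise softmax; masked entries become exactly 0.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return weights / weights.sum(axis=-1, keepdims=True)


def mlp(x, w1, b1, w2, b2):
    # Two-layer MLP with ReLU, applied to every node (token) independently.
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2


def gin_attention(x, wq, wk, wv, mlp_params, eps=0.0):
    # Project hidden states to queries, keys and values.
    q, k, v = x @ wq, x @ wk, x @ wv
    # Interpret the attention matrix as a weighted, directed adjacency matrix A.
    adj = causal_softmax_attention(q, k)
    # Standard attention would return adj @ v. The GIN-style update adds a
    # (1 + eps)-weighted self term and passes the sum through an MLP.
    aggregated = (1.0 + eps) * v + adj @ v
    return mlp(aggregated, *mlp_params), adj


rng = np.random.default_rng(0)
n_tokens, d_model, d_hidden = 5, 8, 16
x = rng.normal(size=(n_tokens, d_model))
wq, wk, wv = (rng.normal(size=(d_model, d_model)) for _ in range(3))
mlp_params = (
    rng.normal(size=(d_model, d_hidden)), np.zeros(d_hidden),
    rng.normal(size=(d_hidden, d_model)), np.zeros(d_model),
)
out, adj = gin_attention(x, wq, wk, wv, mlp_params)
print(out.shape)  # (5, 8)
```

In a trainable layer, `eps` and the weight matrices would be learnable parameters. Options exposed by `GNNConfig`, such as sharpening, PNA aggregation, or `attention_GIN_MLP_multiplier`, are not modeled in this sketch.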
\n\n\n#### Import the library for use in Python\n```python\nfrom xgpt import *\n```\n\nDetailed examples of how to set up and train models follow below. \n\n## Create a GIN-Transformer Model from Scratch\n\nHere we show how to create a GIN-Transformer model from scratch. We use the ```meta-llama/Meta-Llama-3-8B-Instruct``` model as the source of the base model hyperparameters (but not its weights). \n\n- Step 1: Load the dataset (needed to train a custom tokenizer)\n- Step 2: Train the tokenizer\n- Step 3: Set up the GIN-Transformer model\n- Step 4: Train the model\n\n### Load dataset and train tokenizer\n\n#### Load dataset\n```python\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"lamm-mit/protein_secondary_structure_from_PDB\")\nmax_length=256\n\n# Keep only sequences shorter than max_length\ndataset = dataset.filter(lambda example: example['Sequence_length'] \u003c max_length)\ndataset = dataset['train']\nsplit_dataset = dataset.train_test_split(test_size=0.1, seed=42)\n\n# Access the new splits\ntrain_dataset = split_dataset['train']\ntest_dataset = split_dataset['test']\n\n# Verify the sizes of the new datasets\nprint(f\"Training set size: {len(train_dataset)}\")\nprint(f\"Test set size: {len(test_dataset)}\")\n\n# Format each example as one string, using the special tokens the tokenizer is trained with\ndef format_data(example):\n    example[\"text\"] = f\"\u003c|begin_of_text|\u003e\u003c|sequence|\u003e{example['Sequence']}\u003c|/sequence|\u003e\u003c|{example['Primary_SS_Type']}|\u003e\u003c|{example['Secondary_SS_Type']}|\u003e\u003c|eot_id|\u003e\"\n    return example\n\ntrain_dataset = train_dataset.map(format_data, remove_columns=train_dataset.column_names)\ntest_dataset = test_dataset.map(format_data, remove_columns=test_dataset.column_names)\n```\n\n#### Train tokenizer\n\n```python\nfrom xgpt import 
*\n\n# Train the tokenizer\ntexts = train_dataset['text']\ntokenizer = train_tokenizer_from_scratch(texts, vocab_size=20, special_tokens = [\n    \"\u003c|pad|\u003e\",\n    \"\u003c|eot_id|\u003e\", \n    \"\u003c|begin_of_text|\u003e\",\n    \"\u003c|unk|\u003e\",\n    \"\u003c|mask|\u003e\",\n    \"\u003c|sequence|\u003e\",\n    \"\u003c|/sequence|\u003e\",\n    # Single-letter amino acid codes\n    \"A\", \"R\", \"N\", \"D\", \"C\", \"E\", \"Q\", \"G\", \"H\", \"I\",\n    \"L\", \"K\", \"M\", \"F\", \"P\", \"S\", \"T\", \"W\", \"Y\", \"V\",\n    # Secondary-structure label tokens\n    \"\u003c|AH|\u003e\", \"\u003c|BS|\u003e\", \"\u003c|UNSTRUCTURED|\u003e\", \"\u003c|BEND|\u003e\", \"\u003c|PHIHELIX|\u003e\", \"\u003c|310HELIX|\u003e\", \"\u003c|BETABRIDGE|\u003e\", \"\u003c|T|\u003e\",\n]\n)\n\n# Save the trained tokenizer\ntokenizer.save_pretrained(\"./custom_tokenizer\")\n\n# Test with various scenarios\ntest_cases = [\n    \"\u003c|begin_of_text|\u003e\u003c|sequence|\u003eA A A I\u003c|/sequence|\u003e\",  # Space-separated residues\n    \"\u003c|begin_of_text|\u003e\u003c|sequence|\u003eAAAIAIIAJ\u003c|/sequence|\u003e\",  # No spaces; 'J' is not an amino acid token\n    \"Hello World!\",  # With punctuation\n    \"Test   Multiple   Spaces\",  # Multiple spaces\n    \"NoSpaces\",  # No spaces\n    \"123.456\",  # Numbers\n    \"user@email.com\",  # Special characters\n    \"Mixed12345Chars!@#\",  # Mixed content\n]\n\nprint(\"Testing tokenizer:\")\nfor test in test_cases:\n    encoded = tokenizer.encode(test, add_special_tokens=False)\n    decoded = tokenizer.decode(encoded)\n    print(f\"\\nOriginal: {repr(test)}\")\n    print(f\"Encoded : {encoded}\")\n    print(f\"Decoded : {repr(decoded)}\")\n    \n# Print vocabulary info\nprint(f\"\\nVocabulary size: {len(tokenizer)}\")\nprint(f\"Special tokens: {tokenizer.special_tokens_map}\")\n```\nYou can also push the tokenizer to the Hugging Face Hub:\n```python\ntokenizer.push_to_hub 
('lamm-mit/custom_GIN_Attention_tokenizer')\n```\n\n#### Create GIN model\n```python\n# Load the Graph-Aware Transformer library\nfrom xgpt import *\n\nimport torch\nfrom transformers import LlamaConfig, set_seed\nset_seed(42)\n\n# Load the pretrained Llama configuration on which the model is based (configuration only, no weights)\npretrained_model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'\n\n# Adjust these parameters as needed, e.g., number of heads, head_dim, number of layers\ntransformer_config = LlamaConfig.from_pretrained(pretrained_model_name)\ntransformer_config.num_attention_heads=8\ntransformer_config.num_key_value_heads=transformer_config.num_attention_heads\ntransformer_config.head_dim=70\ntransformer_config.hidden_size=transformer_config.head_dim*transformer_config.num_attention_heads \ntransformer_config.intermediate_size=512 #ALT: 4*transformer_config.hidden_size\ntransformer_config.num_hidden_layers=6\ntransformer_config.torch_dtype='bfloat16'\n# len(tokenizer) also counts added special tokens, which tokenizer.vocab_size may exclude\ntransformer_config.vocab_size=len(tokenizer)\ntransformer_config._attn_implementation='eager' \n\ngnn_config = GNNConfig(\n    num_layers=1,        \n    activation=\"relu\", #\"prelu\"\n    dropout=0.1,\n    lambda_GNN=1,\n    norm_to_hidden_states=False,\n    use_layer_norm=False,  \n    combined_norm=False,\n    rms_norm_eps=1e-5,\n    hidden_dim=transformer_config.hidden_size,\n    learnable_aggregate_activation ='softmax',\n    gnn_mode='none',\n    \n    ### Set type of GNN-Attention you want to create\n    #use_GNN_from_attention='LlamaAttentionPNA',\n    use_GNN_from_attention='LlamaAttentionGIN',    \n\n    attention_GIN_MLP_GIN_use_softmax=True,\n    attention_GIN_MLP_use_scoring_fnct=False, #standard attn\n    attention_GIN_MLP_multiplier = 0.5, #1, 2, 4, ...  
\n    \n    use_sharpening=True, sharpening_value_init='value', initial_sharpening_value=1.0,\n\n    attention_GIN_MLP_o_proj_at_end=False, \n\n    use_differential_attention = False,\n\n    ### Set transformer FF MLP type - here you can change the transformer FF type if needed\n    MLP_type='standard_MLP', #'linear_MLP' 'no_MLP' 'shallow_MLP'\n)\n\n# pretrained_model_name=None: only the configuration is used; weights are trained from scratch\nmodel_with_gnn = load_model_with_pretrained_transformer(gnn_config, transformer_config, \n                                torch_dtype='bfloat16',\n                                pretrained_model_name = None, attn_implementation='eager',# 'flash_attention_2' #'eager'\n                                )\n\n# Move to appropriate device (if necessary)\nmodel_with_gnn.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\ncount_trainable_parameters(model_with_gnn)\n```\n#### Train model\n\nOnce the training data is loaded and the model is created, train the model as follows: \n\n```python\nfrom torch.nn import CrossEntropyLoss\nfrom trl import SFTConfig, SFTTrainer\nfrom transformers import TrainingArguments, DataCollatorForSeq2Seq, TrainerCallback\n\nsample_steps    = 100\nmax_seq_length  = 300\n\n# Callback that periodically prints sample generations and logs per-layer trainable_scale values\nclass SampleGenerationCallback(TrainerCallback):\n    def __init__(self, model, tokenizer, prompts, max_tokens, temperature, sample_steps, test_dataset):\n        self.model = model\n        self.tokenizer = tokenizer\n        self.prompts = prompts\n        self.max_tokens = max_tokens\n        self.temperature = temperature\n        self.sample_steps = sample_steps\n        self.test_dataset = test_dataset\n        self.perplexity_scores = []\n        self.test_scores = []\n        self.trainable_scale_history = []\n        self.loss_fct = CrossEntropyLoss(reduction='none')  # reduction='none' yields per-token losses\n\n    def on_step_end(self, args, state, control,\n                    log_trainable_scale_values=True,\n                    **kwargs):\n        if state.global_step % self.sample_steps == 0:\n            print(f\"\\n[Sample Generation at Step 
{state.global_step}]\")\n            for item in self.prompts:\n                res = perform_inference(self.model, self.tokenizer, \n                                  prompt=item, \n                                  max_tokens=self.max_tokens, \n                                  temperature=self.temperature)[0]\n                print(\"QUESTION: \", item)\n                print(\"RESPONSE: \", res)\n                \n            try:\n                # Log trainable_scale values\n                if log_trainable_scale_values:\n                    layer_scales = []\n                    total_scale = 0\n                    num_layers = len(self.model.model.layers)\n    \n                    for layer_idx, layer in enumerate(self.model.model.layers):\n                        trainable_scale_value = layer.trainable_scale.item()\n                        layer_scales.append(trainable_scale_value)\n                        total_scale += trainable_scale_value\n    \n                    average_trainable_scale = total_scale / num_layers\n                    self.trainable_scale_history.append((state.global_step, layer_scales, average_trainable_scale))\n                    print(f\"Average trainable_scale at step {state.global_step}: {average_trainable_scale}\")\n\n            except Exception as e:\n                raise RuntimeError(f\"Failed to log trainable_scale values: {e}\") from e\n\nsample_generation_callback = SampleGenerationCallback(\n    model=model_with_gnn,\n    tokenizer=tokenizer,\n    # Truncated test examples serve as prompts for the model to complete\n    prompts=[\n             test_dataset['text'][0][:-60],\n             test_dataset['text'][10][:-60]\n             #...\n            ],\n\n    max_tokens=128,\n    temperature=0.1,\n    sample_steps=sample_steps,\n    test_dataset=test_dataset,    \n)\n\n# Training arguments\ntraining_args = SFTConfig(\n    output_dir=\"./results_output\",\n    eval_strategy=\"steps\",\n    eval_steps=sample_steps,\n    learning_rate=1e-4,\n    per_device_train_batch_size=8,\n    
per_device_eval_batch_size=4,\n    gradient_accumulation_steps=4,\n    num_train_epochs=9,\n    weight_decay=0.01,\n    logging_dir=\"./logs_output\",\n    lr_scheduler_type=\"constant\", #'cosine'\n    max_seq_length=max_seq_length,\n    logging_steps=sample_steps,\n    warmup_steps=250,\n    dataset_text_field=\"text\",\n    packing=False,\n    max_grad_norm=1,\n    report_to='none',\n    save_strategy='no', #'epoch',\n    do_eval=True,\n)\n\ntrainer = SFTTrainer(\n    model=model_with_gnn,\n    args=training_args,\n    train_dataset=train_dataset,\n    eval_dataset=test_dataset,\n    tokenizer=tokenizer,  # newer TRL versions name this argument processing_class\n    callbacks=[sample_generation_callback],\n)\n\n# Train\ntrainer.train()\n```\n\nYou can push the model and tokenizer to the Hugging Face Hub like so:\n```python\nmodel_with_gnn.push_to_hub('lamm-mit/GIN-Transformer-Model')\ntokenizer.push_to_hub('lamm-mit/GIN-Transformer-Model')\n```\n![image](https://github.com/user-attachments/assets/42c3c673-58e6-4595-b4f6-4e94641d7431)\n\nFigure 3: Construction of the GIN-Attention mechanism. The flowchart shows how the input embeddings (hidden states) in each transformer layer are used, via self-attention, to construct the attention matrix. The output is then processed further before aggregation and application of the GIN-MLP. \n\n## Create a Sparse-GIN Fine-Tuning Model\n\nHere we show how to fine-tune a pre-trained Transformer model using the Sparse-GIN fine-tuning method. We use the ```meta-llama/Llama-3.2-3B-Instruct``` model as the pre-trained model. 
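
The Sparse-GIN idea can be sketched in NumPy as follows: keep the frozen model's attention, turn it into a sparse adjacency matrix, run a GIN over that graph, and add the result back through a trainable scale. This is an illustration under assumptions, not the `xgpt` code. The order of operations (sum the heads, clamp at 1.0, then zero edges below `epsilon_threshold`) is our reading of the `GNNConfig` comments in this section. `continuous_transform_alpha` and `threshold` are not modeled, and the function names are hypothetical. The scale defaults to 0.0 on the assumption that `lambda_GNN_initial = 0.` starts the GIN branch switched off, so that the fine-tuned model initially reproduces the pretrained one.

```python
import numpy as np


def sparse_adjacency_from_attention(attn, epsilon_threshold=0.6,
                                    remove_self_connections=False):
    # attn: (num_heads, n, n) attention weights from the frozen pretrained model.
    # Sum the per-head matrices and clamp at 1.0 (cf. adj_construction_method='sum').
    adj = np.clip(attn.sum(axis=0), 0.0, 1.0)
    # Zero out weak edges to obtain a sparse graph (cf. zero_below_epsilon_threshold).
    adj = np.where(adj >= epsilon_threshold, adj, 0.0)
    if remove_self_connections:
        np.fill_diagonal(adj, 0.0)
    return adj


def sparse_gin_update(h, adj, w1, w2, eps=0.0, scale=0.0):
    # Only retained edges are visited: (target, source) index pairs plus weights.
    targets, sources = np.nonzero(adj)
    weights = adj[targets, sources]
    # Weighted neighbor sum: agg[i] = sum_j adj[i, j] * h[j] over nonzero edges.
    agg = np.zeros_like(h)
    np.add.at(agg, targets, weights[:, None] * h[sources])
    # GIN update followed by a two-layer ReLU MLP.
    gin_out = np.maximum(((1.0 + eps) * h + agg) @ w1, 0.0) @ w2
    # Add the GIN output to the original hidden states via a (trainable) scale.
    return h + scale * gin_out


rng = np.random.default_rng(0)
num_heads, n, d = 4, 6, 8
# Random causal attention maps standing in for a pretrained model's weights.
logits = rng.normal(size=(num_heads, n, n))
future = np.triu(np.ones((n, n), dtype=bool), k=1)
logits = np.where(future, -np.inf, logits)
attn = np.exp(logits - logits.max(axis=-1, keepdims=True))
attn = attn / attn.sum(axis=-1, keepdims=True)

adj = sparse_adjacency_from_attention(attn)
h = rng.normal(size=(n, d))
w1, w2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))
out_at_init = sparse_gin_update(h, adj, w1, w2, scale=0.0)
out = sparse_gin_update(h, adj, w1, w2, scale=0.5)
print(int((adj > 0).sum()), 'of', n * n, 'edges kept')
```

The aggregation visits only the retained edges, so its cost grows with the number of kept edges rather than with the square of the sequence length. A practical implementation would typically use a scatter or sparse-matrix kernel on the GPU instead of `np.add.at`.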
\n\n- Step 1: Load the dataset\n- Step 2: Create a Sparse-GIN on top of the pre-trained Llama model\n- Step 3: Train the model\n\n#### Load dataset\n\n```python\nfrom datasets import load_dataset\n\ndataset = load_dataset(\"mlabonne/orca-math-word-problems-80k\")\ndataset = dataset['train']\n\nsplit_dataset = dataset.train_test_split(test_size=0.1, seed=42)\n\n# Access the new splits\ntrain_dataset = split_dataset['train']\ntest_dataset = split_dataset['test']\n\n# Format each question/answer pair with a simple User/Assistant template\n# (tokenizer.apply_chat_template is an alternative)\ndef format_data(example):\n    example[\"text\"] = f\"### User: {example['question']}\u003c|eot_id|\u003e### Assistant: {example['answer']}\u003c|eot_id|\u003e\"\n    return example\n\ncolumns_to_remove = [col for col in train_dataset.column_names if col != \"text\"]\n\ntrain_dataset = train_dataset.map(format_data, remove_columns=columns_to_remove)\ntest_dataset = test_dataset.map(format_data, remove_columns=columns_to_remove)\n\n# Verify the sizes of the new datasets\nprint(f\"Training set size: {len(train_dataset)}\")\nprint(f\"Test set size: {len(test_dataset)}\")\n```\n\n#### Create Sparse-GIN model on top of pre-trained LLM\n\n```python\n# Load the Graph-Aware Transformer library\nfrom xgpt import *\n\nimport torch\nfrom transformers import AutoTokenizer, LlamaConfig\n\npretrained_model_name = \"meta-llama/Llama-3.2-3B-Instruct\"\n\ntransformer_config = LlamaConfig.from_pretrained(pretrained_model_name)\n\ntokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)\n\n# Pad on the right with the reserved Llama 3.2 fine-tuning pad token\ntokenizer.padding_side='right'\ntokenizer.pad_token = \"\u003c|finetune_right_pad_id|\u003e\"  # alternative: tokenizer.eos_token\nprint(tokenizer.pad_token, tokenizer.pad_token_id)\n\n# Define Sparse-GIN Configuration  \n \ngnn_config = GNNConfig(\n    num_layers=1,      
  \n    activation=\"prelu\", #\"relu\" \n    dropout=0.1,\n    lambda_GNN_initial = 0.,\n    lambda_GNN=0.5,\n    norm_to_hidden_states=False,\n    use_layer_norm=True, \n    combined_norm=False,\n    rms_norm_eps=1e-5,\n    hidden_dim=155,\n\n    ### GIN type/approach\n    gnn_type='causal_gin', \n    gnn_mode='single', #one GIN, not separate per head\n    GIN_use_MLP = True, \n    GIN_hidden_dim_multiplier = 1, # multiplier for the MLP hidden dimension in the GIN\n\n    ### Parameters for adjacency processing\n    adj_construction_method='sum', #sum all per-head adj matrices, clamped at 1.0    \n    continuous_transform_alpha = 10.0, threshold = 0.1,   \n    epsilon_threshold = 0.6, zero_below_epsilon_threshold = True, # All edges below epsilon_threshold are set to zero\n    remove_self_connections = False, \n    GIN_use_norm = False, \n    GIN_edge_weight_scaling = True, # Scale graph edges based on adjacency matrix derived from attention weights\n\n    gnn_residual = False, \n    \n    plot_for_debugging=False,\n\n    gnn_logic='before_MLP', #'after_MLP' 'parallel_GNN_MLP',\n)\n\ntransformer_config._attn_implementation='eager' \n\nmodel_with_gnn = load_model_with_pretrained_transformer(gnn_config, transformer_config, \n                               pretrained_model_name = pretrained_model_name,\n                               attn_implementation='eager',\n                               )\n# Parameter count before freezing\ncount_trainable_parameters(model_with_gnn)\n\n# Move to appropriate device (if necessary)\nmodel_with_gnn.to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# Freeze all parameters except those whose names match these keywords\nfreeze_except_select(model_with_gnn, unfreeze_keywords=['gnn', \n                                                        'trainable_scale',\n                                                        'gnn_norm',\n                                                        'combined_norm'\n                                                       ], \n                       
                                 verbose=True)\n\n# Parameter count after freezing\ncount_trainable_parameters(model_with_gnn)\n```\n\n#### Training\n```python\nfrom trl import SFTConfig, SFTTrainer\nfrom transformers import TrainingArguments, DataCollatorForSeq2Seq, TrainerCallback\n\nsample_steps    = 1000\nmax_seq_length  = 1024\n\n# Training arguments\ntraining_args = SFTConfig(\n    output_dir=\"./output_dir/\",\n    eval_strategy=\"epoch\", #\"steps\",\n    eval_steps=sample_steps,  # only used when eval_strategy=\"steps\"\n    learning_rate=2e-4, \n    per_device_train_batch_size=1,\n    per_device_eval_batch_size=2,\n    gradient_accumulation_steps=4,\n    num_train_epochs=3,\n    weight_decay=0.01,\n    logging_dir=\"./logging_dir/\",\n    lr_scheduler_type=\"constant\", #\"cosine\"\n    max_seq_length=max_seq_length,\n    logging_steps=sample_steps,\n    warmup_steps=50,\n    dataset_text_field=\"text\",\n    packing=False,\n    max_grad_norm=0.5,\n    report_to='none',\n    save_strategy='no',\n    do_eval=True,\n)\n\ntrainer = SFTTrainer(\n    model=model_with_gnn,\n    args=training_args,\n    train_dataset=train_dataset,\n    eval_dataset=test_dataset,\n    tokenizer=tokenizer,  # newer TRL versions name this argument processing_class\n    #callbacks=[sample_generation_callback],\n)\n\n# Train\ntrainer.train()\n```\n![image](https://github.com/user-attachments/assets/ba3ed75f-949f-4638-bcd8-f2474e6e3df4)\n\nFigure 4: Trainable scale parameter across all layers of the Sparse-GIN model, plotted over all training epochs. The trainable scale parameter sets the relative weight of the sparse GIN contribution that is added to the original signal. The plot shows how the scale parameter evolves with both the layer index and the epoch fraction. Early in training, higher layers exhibit larger scaling values, indicating greater reliance on sparse GIN adjustments. As training progresses, the scaling values stabilize, suggesting convergence in the relative importance of the sparse GIN contributions across layers. 
The color gradient reflects the magnitude of the scale parameter, with warmer colors (red) indicating higher values and cooler colors (blue) indicating lower values. This visualization provides insights into the adaptive behavior of the trainable scale parameter over the course of training.\n\n#### Model weights and other datasets\n\nTrained models are available from the LAMM Hugging Face organization:\n\n1) ORCA math word problems: [lamm-mit/Llama-3.2-3B-Instruct-Sparse-GIN-orca-math-word-problems](https://huggingface.co/lamm-mit/Llama-3.2-3B-Instruct-Sparse-GIN-orca-math-word-problems)\n2) Reasoning and logic: [lamm-mit/Llama-3.2-3B-Instruct-Sparse-GIN-logic](https://huggingface.co/lamm-mit/Llama-3.2-3B-Instruct-Sparse-GIN-logic)\n3) Bio-inspired materials and technology: [lamm-mit/Llama-3.2-3B-Instruct-Sparse-GIN-bio](https://huggingface.co/lamm-mit/Llama-3.2-3B-Instruct-Sparse-GIN-bio)\n\nExample task:\n\n```raw\nExplain the concept of hierarchical structures in materials science, especially relating proteins and music for use as optical device. Think step by step.\n```\n\n```Llama-3.2-3B-Instruct-Sparse-GIN-bio``` model output:\n\n```raw\nHierarchical structures are complex arrangements of materials at multiple scales, from the atomic to the macroscopic level.In materials science, these structures are often found in biological materials, such as proteins, which exhibit a hierarchical organization of their molecular building blocks. This hierarchical structure allows for the creation of materials with unique properties and functions.\n\nOne example of hierarchical structures in materials science is the use of proteins as optical devices. Proteins can be engineered to have specific optical properties, such as refractive index and birefringence, which can be used to create optical devices with unique functions. 
For example, proteins can be used to create optical fibers with high refractive index and low birefringence, which can be used for high-speed data transmission.\n\nThe hierarchical structure of proteins allows for the creation of materials with unique properties and functions. At the atomic level, proteins are composed of amino acids, which are arranged in a specific sequence to form a three-dimensional structure. This structure is then organized into larger units, such as domains and fibrils, which are arranged in a hierarchical manner.\n\nThe hierarchical structure of proteins also allows for the creation of materials with specific optical properties. For example, the refractive index of a protein can be controlled by adjusting the sequence of amino acids and the arrangement of the protein's molecular building blocks. Similarly, the birefringence of a protein can be controlled by adjusting the orientation of the protein's molecular building blocks.\n\nThe use of proteins as optical devices is a promising area of research, with potential applications in fields such as telecommunications, optics, and biomedicine. By engineering proteins with specific optical properties, researchers can create materials with unique functions and properties, such as high-speed data transmission and optical sensing.\n\nIn conclusion, hierarchical structures are a key feature of materials science, particularly in the field of protein-based materials. The hierarchical organization of proteins allows for the creation of materials with unique properties and functions, such as optical devices with high refractive index and low birefringence. 
The use of proteins as optical devices is a promising area of research, with potential applications in fields such as telecommunications, optics, and biomedicine.\u003c|eot_id|\u003e\n```\n\n## Reference\n\n```bibtex\n@article{Buehler2025GraphAwareGPT,\n      title={Graph-Aware Isomorphic Attention for Adaptive Dynamics in Transformers}, \n      author={Markus J. Buehler},\n      year={2025},\n      eprint={2501.02393},\n      archivePrefix={arXiv},\n      primaryClass={cs.LG},\n      url={https://arxiv.org/abs/2501.02393}, \n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flamm-mit%2FGraph-Aware-Transformers","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flamm-mit%2FGraph-Aware-Transformers","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flamm-mit%2FGraph-Aware-Transformers/lists"}