{"id":22656526,"url":"https://github.com/mytechnotalent/ragma","last_synced_at":"2026-04-13T22:33:44.822Z","repository":{"id":266744574,"uuid":"899230955","full_name":"mytechnotalent/RAGMA","owner":"mytechnotalent","description":"An interactive, state-of-the-art Retrieval-Augmented Generation Medical Assistant that leverages deep learning and FAISS-based retrieval to provide accurate, context-aware answers to medical queries in real time.","archived":false,"fork":false,"pushed_at":"2024-12-05T21:29:07.000Z","size":4991,"stargazers_count":1,"open_issues_count":0,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-03-29T07:44:57.359Z","etag":null,"topics":["ai","artificial-intelligence","faiss","faiss-vector-database","health","health-check","healthcare","healthcare-analysis","healthcare-application","healthcare-datasets","medical","medical-application","medical-assistant","pytorch","rag","retrieval-augmented-generation"],"latest_commit_sha":null,"homepage":"","language":"Jupyter 
Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/mytechnotalent.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-12-05T21:26:21.000Z","updated_at":"2024-12-07T07:41:04.000Z","dependencies_parsed_at":"2024-12-05T22:36:17.548Z","dependency_job_id":null,"html_url":"https://github.com/mytechnotalent/RAGMA","commit_stats":null,"previous_names":["mytechnotalent/ragma"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mytechnotalent%2FRAGMA","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mytechnotalent%2FRAGMA/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mytechnotalent%2FRAGMA/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/mytechnotalent%2FRAGMA/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/mytechnotalent","download_url":"https://codeload.github.com/mytechnotalent/RAGMA/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":246156029,"owners_count":20732359,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste
.ms/api/v1/hosts/GitHub/owners"}},"keywords":["ai","artificial-intelligence","faiss","faiss-vector-database","health","health-check","healthcare","healthcare-analysis","healthcare-application","healthcare-datasets","medical","medical-application","medical-assistant","pytorch","rag","retrieval-augmented-generation"],"created_at":"2024-12-09T10:14:45.029Z","updated_at":"2026-04-13T22:33:44.781Z","avatar_url":"https://github.com/mytechnotalent.png","language":"Jupyter Notebook","readme":"# Retrieval-Augmented Generation Medical Assistant (RAGMA)\n\n### [dataset](https://www.kaggle.com/datasets/jpmiller/layoutlm)\n\nAuthor: [Kevin Thomas](mailto:ket189@pitt.edu)\n\nLicense: [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0)\n\n## RAG Model Training, Index Creation, and Inference: Explained\n\nThis notebook walks through three Python scripts that together implement a Retrieval-Augmented Generation (RAG) pipeline: `train.py`, `create_faiss_index.py`, and `inference.py`. Each serves a distinct purpose in the RAG workflow: training the model, creating a searchable index, and running inference.\n\n---\n\n### 1. **What is RAG?**\nRetrieval-Augmented Generation (RAG) is a method in Natural Language Processing (NLP) where:\n- **Retrieval:** You fetch relevant information (documents, passages) based on a query.\n- **Generation:** You produce a response by feeding the query together with the retrieved information to a generative model such as T5.\n\n---\n\n### 2. **`train.py`**: Fine-Tuning the T5 Model\n\nThis script fine-tunes a pretrained T5 model on query-context-answer examples. Here’s what it does:\n\n#### a. **Dataset**\n- After preprocessing, the dataset has three columns: `query`, `context`, and `answer` (the source CSV only needs `question` and `answer`).\n  - `query`: The question or input from the user.\n  - `context`: Supporting information for the query (here, simply a copy of `answer`, so during training the model mainly learns to restate the context it is given).\n  - `answer`: The correct response to the query.\n\n#### b. 
**Data Preprocessing**\nThe `preprocess_data` function:\n- Reads the dataset from a CSV file.\n- Keeps only the `question` and `answer` columns and drops rows where either is missing.\n- Renames `question` to `query` (`answer` keeps its name).\n- Creates a `context` column by copying `answer`.\n\n#### c. **Model Fine-Tuning**\n- The T5 model (`T5ForConditionalGeneration`) is fine-tuned on batches produced by the `RAGDataset` class.\n- Each input is formatted as `query: ... context: ...`, tokenized, and padded or truncated to 512 tokens; each answer is padded or truncated to 150 tokens.\n- Training runs for 3 epochs by default, minimizing the model's cross-entropy loss with the AdamW optimizer (learning rate `5e-5`).\n\n#### d. **Output**\n- The fine-tuned model and tokenizer are saved to a directory (`rag_model` by default).\n\n#### Key Takeaways:\n- `train.py` fine-tunes T5 to generate the answer from a query paired with its context.\n- This is the **training phase** of the RAG pipeline.\n\n---\n\n### 3. **`create_faiss_index.py`**: Building the FAISS Index\n\nThis script builds an index for **retrieving** relevant contexts.\n\n#### a. **Dataset**\n- Like `train.py`, the script reads `medquad.csv`; it needs only the `answer` column, plus `context` if one already exists.\n- If the `context` column does not exist, it copies `answer` into a new `context` column and saves the updated dataset to `medquad_with_context.csv`.\n\n#### b. **Embedding Generation**\n- Uses `SentenceTransformer` (model: `all-MiniLM-L6-v2`) to generate dense vector embeddings for the `context` column.\n- Each context becomes a fixed-length vector (384 dimensions for this model).\n\n#### c. **FAISS Index**\n- FAISS (Facebook AI Similarity Search) is a library for efficient similarity search over dense vectors.\n- The script:\n  - Creates an exact, brute-force index (`IndexFlatL2`) that compares vectors by L2 (Euclidean) distance.\n  - Adds the embeddings of all contexts to the FAISS index.\n\n#### d. **Output**\n- The FAISS index is saved to a file (`context.index`).\n\n#### Key Takeaways:\n- `create_faiss_index.py` prepares the **retrieval mechanism** by indexing context embeddings.\n- This is the **retrieval phase** of the RAG pipeline.\n\n---\n\n### 4. 
**`inference.py`**: Running the RAG Model\n\nThis script ties everything together: you pass in a query and receive a generated response.\n\n#### a. **Dataset**\n- Loads the dataset (`medquad.csv`).\n- Ensures the `context` column exists and uses the `answer` column if not.\n\n#### b. **Retrieval**\n- The query is embedded with the same `all-MiniLM-L6-v2` model used to build the index.\n- The FAISS index (`context.index`) returns the nearest context by L2 distance (`top_k = 1` by default).\n\n#### c. **Inference**\n- The query and the retrieved context are combined into one `query: ... context: ...` input, the same format used during training.\n- The fine-tuned T5 model generates the answer from that input with beam search.\n\n#### Key Takeaways:\n- `inference.py` demonstrates the **full RAG pipeline**, combining retrieval and generation.\n- This is the **inference phase** of the RAG workflow.\n\n---\n\n### 5. **Summary**\n\n#### Workflow:\n1. **Train (`train.py`)**:\n   - Fine-tune the T5 model on a dataset of queries, contexts, and answers.\n2. **Create Index (`create_faiss_index.py`)**:\n   - Generate embeddings for the contexts and build a FAISS index for retrieval.\n3. **Inference (`inference.py`)**:\n   - Retrieve the most relevant context using FAISS.\n   - Use the fine-tuned T5 model to generate answers based on the query and context.\n\n---\n\n### 6. **Analogy: A Smart Librarian**\nImagine you’re asking a librarian for help:\n1. **Train**: The librarian learns how to answer questions (fine-tuning T5).\n2. **Index**: The librarian organizes all the books efficiently (FAISS indexing).\n3. 
**Inference**: You ask the librarian a question, they find the relevant book (retrieval), and then summarize the answer for you (generation).\n\nThis combination of retrieval and generation makes RAG a powerful tool for tasks like Q\u0026A systems or chatbots.\n\n## Install Libraries\n\n\n```python\n!pip install pandas torch transformers scikit-learn tqdm faiss-cpu sentence-transformers\n```\n\n    Requirement already satisfied: pandas in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (2.2.2)\n    Requirement already satisfied: torch in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (2.5.1)\n    Requirement already satisfied: transformers in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (4.46.2)\n    Requirement already satisfied: scikit-learn in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (1.5.1)\n    Requirement already satisfied: tqdm in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (4.66.5)\n    Requirement already satisfied: faiss-cpu in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (1.9.0)\n    Requirement already satisfied: sentence-transformers in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (3.3.0)\n    Requirement already satisfied: numpy\u003e=1.26.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from pandas) (1.26.4)\n    Requirement already satisfied: python-dateutil\u003e=2.8.2 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from pandas) (2.9.0.post0)\n    Requirement already satisfied: pytz\u003e=2020.1 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from pandas) (2024.1)\n    Requirement already satisfied: tzdata\u003e=2022.7 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from pandas) (2023.3)\n    Requirement already satisfied: filelock in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) (3.13.1)\n    Requirement already satisfied: typing-extensions\u003e=4.8.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) 
(4.11.0)\n    Requirement already satisfied: networkx in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) (3.3)\n    Requirement already satisfied: jinja2 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) (3.1.4)\n    Requirement already satisfied: fsspec in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) (2024.6.1)\n    Requirement already satisfied: setuptools in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) (75.1.0)\n    Requirement already satisfied: sympy==1.13.1 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from torch) (1.13.1)\n    Requirement already satisfied: mpmath\u003c1.4,\u003e=1.1.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from sympy==1.13.1-\u003etorch) (1.3.0)\n    Requirement already satisfied: huggingface-hub\u003c1.0,\u003e=0.23.2 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (0.26.2)\n    Requirement already satisfied: packaging\u003e=20.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (24.1)\n    Requirement already satisfied: pyyaml\u003e=5.1 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (6.0.1)\n    Requirement already satisfied: regex!=2019.12.17 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (2024.9.11)\n    Requirement already satisfied: requests in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (2.32.3)\n    Requirement already satisfied: safetensors\u003e=0.4.1 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (0.4.5)\n    Requirement already satisfied: tokenizers\u003c0.21,\u003e=0.20 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from transformers) (0.20.3)\n    Requirement already satisfied: scipy\u003e=1.6.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from scikit-learn) (1.13.1)\n    Requirement already satisfied: 
joblib\u003e=1.2.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from scikit-learn) (1.4.2)\n    Requirement already satisfied: threadpoolctl\u003e=3.1.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from scikit-learn) (3.5.0)\n    Requirement already satisfied: Pillow in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from sentence-transformers) (10.4.0)\n    Requirement already satisfied: six\u003e=1.5 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from python-dateutil\u003e=2.8.2-\u003epandas) (1.16.0)\n    Requirement already satisfied: MarkupSafe\u003e=2.0 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from jinja2-\u003etorch) (2.1.3)\n    Requirement already satisfied: charset-normalizer\u003c4,\u003e=2 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from requests-\u003etransformers) (3.3.2)\n    Requirement already satisfied: idna\u003c4,\u003e=2.5 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from requests-\u003etransformers) (3.7)\n    Requirement already satisfied: urllib3\u003c3,\u003e=1.21.1 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from requests-\u003etransformers) (2.2.3)\n    Requirement already satisfied: certifi\u003e=2017.4.17 in /opt/anaconda3/envs/prod/lib/python3.12/site-packages (from requests-\u003etransformers) (2024.8.30)\n\n\n## Train\n\n\n```python\nimport sys\nimport pandas as pd\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nfrom transformers import T5Tokenizer, T5ForConditionalGeneration\nimport argparse\nfrom sklearn.model_selection import train_test_split\nimport torch.optim as optim\nfrom tqdm.auto import tqdm\n\n\nclass RAGDataset(Dataset):\n    \"\"\"\n    Custom Dataset class for loading query-context-answer pairs for training the RAG model.\n    This class handles tokenizing the data and preparing it for PyTorch's DataLoader.\n    \"\"\"\n    def __init__(self, dataframe, tokenizer, source_len, target_len):\n        
\"\"\"\n        Initialize the dataset.\n        \n        Args:\n            dataframe (pd.DataFrame): The dataset containing query, context, and answer columns.\n            tokenizer (transformers.PreTrainedTokenizer): Tokenizer for encoding text.\n            source_len (int): Maximum length for the input sequence.\n            target_len (int): Maximum length for the target sequence.\n        \"\"\"\n        self.data = dataframe.reset_index(drop=True)\n        self.tokenizer = tokenizer\n        self.source_len = source_len\n        self.target_len = target_len\n        self.query = self.data['query']\n        self.context = self.data['context']\n        self.answer = self.data['answer']\n\n    def __len__(self):\n        \"\"\"\n        Returns:\n            int: The number of samples in the dataset.\n        \"\"\"\n        return len(self.data)\n\n    def __getitem__(self, idx):\n        \"\"\"\n        Retrieve a single data point from the dataset.\n        \n        Args:\n            idx (int): Index of the data point.\n        \n        Returns:\n            dict: Dictionary containing tokenized input and target sequences.\n        \"\"\"\n        query = str(self.query[idx])\n        context = str(self.context[idx])\n        answer = str(self.answer[idx])\n\n        # combine query and context into a single input string\n        source_text = f\"query: {query} context: {context}\"\n        \n        # tokenize the input string\n        source = self.tokenizer.encode_plus(\n            source_text, max_length=self.source_len, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n        # tokenize the answer string\n        target = self.tokenizer.encode_plus(\n            answer, max_length=self.target_len, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n\n        return {\n            \"input_ids\": source[\"input_ids\"].squeeze(),\n            \"attention_mask\": source[\"attention_mask\"].squeeze(),\n 
           \"labels\": target[\"input_ids\"].squeeze(),\n        }\n\n\ndef preprocess_data(file_path):\n    \"\"\"\n    Preprocess the dataset to include required columns and handle missing values.\n    \n    Args:\n        file_path (str): Path to the dataset CSV file.\n    \n    Returns:\n        pd.DataFrame: Preprocessed dataframe with 'query', 'context', and 'answer' columns.\n    \"\"\"\n    # load the CSV file\n    df = pd.read_csv(file_path)\n    \n    # retain only the 'question' and 'answer' columns\n    df = df[['question', 'answer']]\n    \n    # drop rows with missing values\n    df = df.dropna(subset=['question', 'answer'])\n    \n    # rename columns for consistency\n    df = df.rename(columns={'question': 'query', 'answer': 'answer'})\n    \n    # add a 'context' column (using the answer as context for now)\n    df['context'] = df['answer']\n    return df\n\n\ndef train_epoch(model, loader, optimizer, device, epoch, logging_steps):\n    \"\"\"\n    Train the model for one epoch.\n    \n    Args:\n        model (torch.nn.Module): The model being trained.\n        loader (DataLoader): DataLoader for the training data.\n        optimizer (torch.optim.Optimizer): Optimizer for updating model weights.\n        device (torch.device): Device to run the model on (CPU, GPU, etc.).\n        epoch (int): Current epoch number.\n        logging_steps (int): Frequency of logging progress during training.\n    \n    Returns:\n        float: The average training loss for the epoch.\n    \"\"\"\n    model.train()  # set the model to training mode\n    total_loss = 0  # initialize total loss\n    progress_bar = tqdm(loader, desc=f\"Epoch {epoch}\", disable=False)  # progress bar for tracking\n\n    for step, batch in enumerate(progress_bar):\n        # move inputs and labels to the specified device\n        input_ids = batch[\"input_ids\"].to(device)\n        attention_mask = batch[\"attention_mask\"].to(device)\n        labels = batch[\"labels\"].to(device)\n\n     
   # forward pass through the model\n        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n        loss = outputs.loss\n        total_loss += loss.item()\n\n        # backward pass and optimization step\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # log the loss every `logging_steps`\n        if (step + 1) % logging_steps == 0:\n            progress_bar.set_postfix({\"loss\": loss.item()})\n\n    # return the average loss for the epoch\n    return total_loss / len(loader)\n\n\ndef main():\n    \"\"\"\n    Main function to fine-tune the T5 model for Retrieval-Augmented Generation (RAG).\n    This version handles Jupyter Notebook's extra arguments gracefully.\n    \"\"\"\n    # simulating command-line arguments for Jupyter Notebook\n    class Args:\n        model_name = \"t5-base\"\n        train_file = \"medquad.csv\"\n        output_dir = \"rag_model\"\n        batch_size = 8\n        epochs = 3\n        lr = 5e-5\n        max_input_length = 512\n        max_output_length = 150\n        device = \"mps\"\n        logging_steps = 10\n\n    args = Args()  # use the custom Args class to store arguments\n\n    # enhanced device selection logic\n    if args.device == \"mps\" and torch.backends.mps.is_available():\n        device = torch.device(\"mps\")\n    elif args.device == \"cuda\" and torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    else:\n        device = torch.device(\"cpu\")\n    print(f\"Using device: {device}\")\n\n    # preprocess the data\n    df = preprocess_data(args.train_file)\n\n    # split data into training and validation sets\n    train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)\n\n    # load the tokenizer and model\n    tokenizer = T5Tokenizer.from_pretrained(args.model_name, legacy=False)\n    model = T5ForConditionalGeneration.from_pretrained(args.model_name).to(device)\n\n    # create DataLoaders for training and 
validation datasets\n    train_dataset = RAGDataset(train_df, tokenizer, args.max_input_length, args.max_output_length)\n    val_dataset = RAGDataset(val_df, tokenizer, args.max_input_length, args.max_output_length)\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)\n\n    # optimizer\n    optimizer = optim.AdamW(model.parameters(), lr=args.lr)\n\n    # training loop\n    for epoch in range(1, args.epochs + 1):\n        train_loss = train_epoch(model, train_loader, optimizer, device, epoch, args.logging_steps)\n        print(f\"Epoch {epoch} Training Loss: {train_loss:.4f}\")\n\n    # save the model and tokenizer\n    model.save_pretrained(args.output_dir)\n    tokenizer.save_pretrained(args.output_dir)\n    print(f\"Model saved to {args.output_dir}\")\n```\n\n\n```python\nmain()\n```\n\n    Using device: mps\n\n\n\n    Epoch 1:   0%|          | 0/1846 [00:00\u003c?, ?it/s]\n\n\n    Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. 
`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.\n\n\n    Epoch 1 Training Loss: 0.0865\n\n\n\n    Epoch 2:   0%|          | 0/1846 [00:00\u003c?, ?it/s]\n\n\n    Epoch 2 Training Loss: 0.0135\n\n\n\n    Epoch 3:   0%|          | 0/1846 [00:00\u003c?, ?it/s]\n\n\n    Epoch 3 Training Loss: 0.0099\n    Model saved to rag_model\n\n\n## Create Faiss Index\n\n\n```python\nimport pandas as pd\nimport faiss\nfrom sentence_transformers import SentenceTransformer\n\n# define parameters\ncsv_file = \"medquad.csv\"  # Path to your dataset\nupdated_csv_file = \"medquad_with_context.csv\"  # Output dataset path\nindex_file = \"context.index\"  # Path to save the FAISS index\n\n# load the dataset\ndf = pd.read_csv(csv_file)\n\n# add a 'context' column if it doesn't exist\nif 'context' not in df.columns:\n    print(\"No 'context' column found. Creating it from the 'answer' column.\")\n    if 'answer' not in df.columns:\n        raise ValueError(\"The dataset must have an 'answer' column to create the 'context'.\")\n    df['context'] = df['answer']  # Use 'answer' as the context\n\n# save the updated dataset with the 'context' column\ndf.to_csv(updated_csv_file, index=False)\nprint(f\"Updated dataset with 'context' column saved to {updated_csv_file}\")\n\n# use SentenceTransformer to generate embeddings for the context column\nembedder = SentenceTransformer(\"all-MiniLM-L6-v2\")\ncontexts = df[\"context\"].tolist()\ncontext_embeddings = embedder.encode(contexts, convert_to_tensor=False).astype(\"float32\")\n\n# create a FAISS index for the embeddings\ndimension = context_embeddings.shape[1]\nindex = faiss.IndexFlatL2(dimension)  # L2 (Euclidean) distance\nindex.add(context_embeddings)\n\n# save the FAISS index\nfaiss.write_index(index, index_file)\nprint(f\"FAISS index saved to {index_file}\")\n```\n\n    No 'context' column found. 
Creating it from the 'answer' column.\n    Updated dataset with 'context' column saved to medquad_with_context.csv\n    FAISS index saved to context.index\n\n\n## Inference\n\n\n```python\nimport torch\nfrom transformers import T5Tokenizer, T5ForConditionalGeneration\nimport faiss\nimport pandas as pd\nfrom sentence_transformers import SentenceTransformer\n\n\ndef load_index(index_file, csv_file):\n    \"\"\"\n    Load the FAISS index and the associated dataset.\n\n    Args:\n        index_file (str): Path to the FAISS index file.\n        csv_file (str): Path to the CSV file containing the dataset.\n\n    Returns:\n        faiss.IndexFlatL2: The loaded FAISS index.\n        pd.DataFrame: The dataset containing queries, contexts, and answers.\n    \"\"\"\n    df = pd.read_csv(csv_file)  # Load the dataset\n    index = faiss.read_index(index_file)  # Load the FAISS index\n    return index, df\n\n\ndef retrieve_context(query, index, df, embedder, top_k=1):\n    \"\"\"\n    Retrieve the most relevant context(s) from the FAISS index based on the query.\n\n    Args:\n        query (str): The user's input question.\n        index (faiss.IndexFlatL2): The FAISS index for retrieval.\n        df (pd.DataFrame): The dataset to retrieve contexts from.\n        embedder (SentenceTransformer): The embedding model to encode the query.\n        top_k (int): Number of top contexts to retrieve.\n\n    Returns:\n        list: A list of retrieved contexts.\n    \"\"\"\n    query_vector = embedder.encode([query]).astype(\"float32\")  # Embed the query\n    distances, indices = index.search(query_vector, top_k)  # Search the FAISS index\n    return [df.iloc[i][\"context\"] for i in indices[0]]  # Retrieve contexts by index\n\n\n# simulated arguments for the Jupyter Notebook\nargs = {\n    \"query\": \"What are the symptoms of diabetes?\",\n    \"model_dir\": \"rag_model\",  # fine-tuned model directory\n    \"csv_file\": \"medquad_with_context.csv\",  # CSV file containing the 
dataset\n    \"index_file\": \"context.index\",  # FAISS index file\n    \"device\": \"mps\",  # device to run the inference on\n    \"top_k\": 1  # number of top contexts to retrieve\n}\n\n# enhanced device selection logic\nif args[\"device\"] == \"mps\" and torch.backends.mps.is_available():\n    device = torch.device(\"mps\")\nelif args[\"device\"] == \"cuda\" and torch.cuda.is_available():\n    device = torch.device(\"cuda\")\nelse:\n    device = torch.device(\"cpu\")\nprint(f\"Using device: {device}\")\n\n# load the fine-tuned model and tokenizer\ntokenizer = T5Tokenizer.from_pretrained(args[\"model_dir\"], legacy=False)\nmodel = T5ForConditionalGeneration.from_pretrained(args[\"model_dir\"]).to(device)\n\n# load the FAISS index and dataset\nembedder = SentenceTransformer(\"all-MiniLM-L6-v2\")  # must be the same model used to build the index\nindex, df = load_index(args[\"index_file\"], args[\"csv_file\"])\n\n# retrieve the most relevant context for the input query\ncontexts = retrieve_context(args[\"query\"], index, df, embedder, args[\"top_k\"])\ninput_text = f\"query: {args['query']} context: {' '.join(contexts)}\"\n\n# generate the answer using the fine-tuned model\ninput_ids = tokenizer.encode(input_text, return_tensors=\"pt\").to(device)\ninput_length = len(input_ids[0])  # length of the input query + context\noutputs = model.generate(\n    input_ids,\n    max_length=input_length + 150,  # for encoder-decoder models this caps only the generated answer; sizing it from the input leaves room to restate the full context\n    num_beams=5,\n    no_repeat_ngram_size=2  # blocks any repeated 2-gram; this likely explains why repeated phrases in the context (e.g. 'feeling very') are reworded in the answer below\n)\nanswer = tokenizer.decode(outputs[0], skip_special_tokens=True)\nif \".\" in answer:\n    answer = answer[:answer.rfind(\".\") + 1]  # trim to the last full sentence\n\n# display the results\nprint(f\"Query: {args['query']}\")\nprint(f\"Retrieved Context: {' '.join(contexts)}\")\nprint(f\"Generated Answer: {answer}\")\n```\n\n    Using device: mps\n    Query: What are the symptoms of diabetes?\n    Retrieved Context: The signs and symptoms of diabetes are\n                    \n    - being 
very thirsty  - urinating often  - feeling very hungry  - feeling very tired  - losing weight without trying  - sores that heal slowly  - dry, itchy skin  - feelings of pins and needles in your feet  - losing feeling in your feet  - blurry eyesight\n                    \n    Some people with diabetes dont have any of these signs or symptoms. The only way to know if you have diabetes is to have your doctor do a blood test.\n    Generated Answer: The signs and symptoms of diabetes are - being very thirsty , urinating often – feeling very hungry ­ feeling extremely tired — losing weight without trying . sores that heal slowly : dry, itchy skin  feelings of pins and needles in your feet _ losing feeling in you feet- blurry eyesight Some people with diabetes dont have any of these signs or symptoms. The only way to know if you have diabetes is to have your doctor do a blood test.\n\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmytechnotalent%2Fragma","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fmytechnotalent%2Fragma","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fmytechnotalent%2Fragma/lists"}