Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/zetane/torch-xai
ZetaForge XAI Torch Pipeline
https://github.com/zetane/torch-xai
Last synced: 41 minutes ago
JSON representation
ZetaForge XAI Torch Pipeline
- Host: GitHub
- URL: https://github.com/zetane/torch-xai
- Owner: zetane
- License: agpl-3.0
- Created: 2024-07-11T22:04:21.000Z (6 months ago)
- Default Branch: main
- Last Pushed: 2024-07-12T14:42:00.000Z (6 months ago)
- Last Synced: 2024-11-10T14:08:26.503Z (about 2 months ago)
- Language: Python
- Size: 58.6 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
# torch-xai
## ZetaForge XAI Torch Pipeline
![image (32)](https://github.com/user-attachments/assets/34127c39-5e85-4d90-8c0f-99a34e119345)
Block input details:
- test_dataset_file : zip file contaning all the images and xlsx dataframe, where image column name must be "filename".
- model_process_file: pickle file containing model and the pre-process code, see the code below for how to save the model and processor.
- model_architecture_type: three options we have "cnn", "ViT", "SwiT"
- target_layer: name of the target layer on which xai will be performed.
- saving_dir: name of the save dir.Here is the demo script for how to save model and processor.
```
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import pickle# Load the ResNet18 model
model = models.resnet50(pretrained=True)
model.eval()# Create the transform
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# Create a dictionary with the model and transform
data = {
'model': model,
'processor': transform
}# Save the dictionary to a pickle file
with open('resnet18_model_and_transform.pkl', 'wb') as f:
pickle.dump(data, f)
```