An open API service indexing awesome lists of open source software.

https://github.com/pbenner/pycoordinationnet

A neural network model for materials property prediction based on coordination environments
https://github.com/pbenner/pycoordinationnet

Last synced: 7 months ago
JSON representation

A neural network model for materials property prediction based on coordination environments

Awesome Lists containing this project

README

          

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from coordinationnet import CoordinationNet\n",
"from coordinationnet import CoordinationFeaturesData"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Load the Materials Project (MP) oxides data set\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"data = CoordinationFeaturesData.load('data/mpoxides.dill')"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Create a new model instance\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model config:\n",
"-> composition : False\n",
"-> sites : False\n",
"-> sites_oxid : False\n",
"-> sites_ces : False\n",
"-> site_features : True\n",
"-> site_features_ces : True\n",
"-> site_features_oxid : True\n",
"-> site_features_csms : True\n",
"-> site_features_ligands: False\n",
"-> ligands : False\n",
"-> ce_neighbors : False\n",
"\n",
"Creating a transformer model with 12,277,945 parameters\n"
]
}
],
"source": [
"model = CoordinationNet(max_epochs = 10)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"## Training, prediction, and cross-validation\n",
"---"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9: 100%|██████████| 31/31 [00:01<00:00, 22.34it/s, val_loss=0.192, train_loss=0.225]\n",
"Model config:\n",
"-> composition : False\n",
"-> sites : False\n",
"-> sites_oxid : False\n",
"-> sites_ces : False\n",
"-> site_features : True\n",
"-> site_features_ces : True\n",
"-> site_features_oxid : True\n",
"-> site_features_csms : True\n",
"-> site_features_ligands: False\n",
"-> ligands : False\n",
"-> ce_neighbors : False\n",
"\n",
"Creating a transformer model with 12,277,945 parameters\n"
]
},
{
"data": {
"text/plain": [
"{'best_val_error': 0.19231659173965454,\n",
" 'train_error': [0.8821187615394592,\n",
" 0.6863077282905579,\n",
" 0.6475048065185547,\n",
" 0.5556382536888123,\n",
" 0.48192858695983887,\n",
" 0.3381344974040985,\n",
" 0.32037481665611267,\n",
" 0.289189875125885,\n",
" 0.2524196207523346,\n",
" 0.2248552143573761],\n",
" 'val_error': [0.6633234024047852,\n",
" 0.6896966695785522,\n",
" 0.6661754846572876,\n",
" 0.5392132997512817,\n",
" 0.46571817994117737,\n",
" 0.32083022594451904,\n",
" 0.27424806356430054,\n",
" 0.2790601849555969,\n",
" 0.2448316216468811,\n",
" 0.20683790743350983,\n",
" 0.19600436091423035]}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.train(data)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicting DataLoader 0: 100%|██████████| 35/35 [00:00<00:00, 74.62it/s]\n"
]
},
{
"data": {
"text/plain": [
"tensor([[-2.3259],\n",
" [-3.2458],\n",
" [-3.7321],\n",
" ...,\n",
" [-2.1301],\n",
" [-2.1895],\n",
" [-1.7834]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict(data)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training fold 1/3...\n",
"Epoch 9: 100%|██████████| 21/21 [00:00<00:00, 21.63it/s, val_loss=0.182, train_loss=0.176]\n",
"Model config:\n",
"-> composition : False\n",
"-> sites : False\n",
"-> sites_oxid : False\n",
"-> sites_ces : False\n",
"-> site_features : True\n",
"-> site_features_ces : True\n",
"-> site_features_oxid : True\n",
"-> site_features_csms : True\n",
"-> site_features_ligands: False\n",
"-> ligands : False\n",
"-> ce_neighbors : False\n",
"\n",
"Creating a transformer model with 12,277,945 parameters\n",
"Testing DataLoader 0: 100%|██████████| 12/12 [00:00<00:00, 73.37it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" test_loss 0.16995078325271606\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
"Best validation score: 0.1810467690229416\n",
"Training fold 2/3...\n",
"Epoch 9: 100%|██████████| 21/21 [00:00<00:00, 21.10it/s, val_loss=0.167, train_loss=0.163]\n",
"Model config:\n",
"-> composition : False\n",
"-> sites : False\n",
"-> sites_oxid : False\n",
"-> sites_ces : False\n",
"-> site_features : True\n",
"-> site_features_ces : True\n",
"-> site_features_oxid : True\n",
"-> site_features_csms : True\n",
"-> site_features_ligands: False\n",
"-> ligands : False\n",
"-> ce_neighbors : False\n",
"\n",
"Creating a transformer model with 12,277,945 parameters\n",
"Testing DataLoader 0: 100%|██████████| 12/12 [00:00<00:00, 77.25it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" test_loss 0.17509198188781738\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
"Best validation score: 0.1643427014350891\n",
"Training fold 3/3...\n",
"Epoch 9: 100%|██████████| 21/21 [00:00<00:00, 22.19it/s, val_loss=0.158, train_loss=0.173]\n",
"Model config:\n",
"-> composition : False\n",
"-> sites : False\n",
"-> sites_oxid : False\n",
"-> sites_ces : False\n",
"-> site_features : True\n",
"-> site_features_ces : True\n",
"-> site_features_oxid : True\n",
"-> site_features_csms : True\n",
"-> site_features_ligands: False\n",
"-> ligands : False\n",
"-> ce_neighbors : False\n",
"\n",
"Creating a transformer model with 12,277,945 parameters\n",
"Testing DataLoader 0: 100%|██████████| 12/12 [00:00<00:00, 86.33it/s]\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" Test metric DataLoader 0\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
" test_loss 0.17175181210041046\n",
"────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n",
"Best validation score: 0.1518968790769577\n"
]
}
],
"source": [
"test_loss, y, y_hat = model.cross_validation(data, 3)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.1722649782896042"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "crysfeat",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}