{"id":15032894,"url":"https://github.com/chenglongchen/tensorflow-deepfm","last_synced_at":"2025-04-08T00:38:14.373Z","repository":{"id":38290652,"uuid":"110348942","full_name":"ChenglongChen/tensorflow-DeepFM","owner":"ChenglongChen","description":"Tensorflow implementation of DeepFM for CTR prediction.","archived":false,"fork":false,"pushed_at":"2018-06-10T11:10:10.000Z","size":149,"stargazers_count":2045,"open_issues_count":46,"forks_count":808,"subscribers_count":66,"default_branch":"master","last_synced_at":"2025-04-08T00:37:52.813Z","etag":null,"topics":["click-through-rate","ctr","ctr-prediction","deep-ctr","deepfm","factorization-machine"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/ChenglongChen.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2017-11-11T13:49:34.000Z","updated_at":"2025-04-05T12:43:55.000Z","dependencies_parsed_at":"2022-07-12T02:02:13.765Z","dependency_job_id":null,"html_url":"https://github.com/ChenglongChen/tensorflow-DeepFM","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChenglongChen%2Ftensorflow-DeepFM","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChenglongChen%2Ftensorflow-DeepFM/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChenglongChen%2Ftensorflow-DeepFM/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ChenglongChen%2Ftensorflow-DeepFM/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHu
b/owners/ChenglongChen","download_url":"https://codeload.github.com/ChenglongChen/tensorflow-DeepFM/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247755560,"owners_count":20990620,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["click-through-rate","ctr","ctr-prediction","deep-ctr","deepfm","factorization-machine"],"created_at":"2024-09-24T20:19:41.122Z","updated_at":"2025-04-08T00:38:14.356Z","avatar_url":"https://github.com/ChenglongChen.png","language":"Python","readme":"# tensorflow-DeepFM\n\nThis project includes a Tensorflow implementation of DeepFM [1].\n\n# NEWS\n- A modified version of DeepFM was used to win 4th place in the [Mercari Price Suggestion Challenge on Kaggle](https://www.kaggle.com/c/mercari-price-suggestion-challenge). 
See the slides [here](https://github.com/ChenglongChen/tensorflow-XNN/blob/master/doc/Mercari_Price_Suggesion_Competition_ChenglongChen_4th_Place.pdf) for how we deal with fields containing sequences and how we incorporate various FM components into a deep model.\n\n# Usage\n## Input Format\nThis implementation requires the input data in the following format:\n- [ ] **Xi**: *[[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...]*\n    - *indi_j* is the feature index of feature field *j* of sample *i* in the dataset\n- [ ] **Xv**: *[[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...]*\n    - *vali_j* is the feature value of feature field *j* of sample *i* in the dataset\n    - *vali_j* can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features)\n- [ ] **y**: target of each sample in the dataset (1/0 for classification, a numeric value for regression)\n\nPlease see `example/DataReader.py` for an example of how to prepare the data in the required format for DeepFM.\n\n## Init and train a model\n```\nimport tensorflow as tf\nfrom sklearn.metrics import roc_auc_score\nfrom DeepFM import DeepFM  # assumes DeepFM.py from this repository is on the import path\n\n# params\ndfm_params = {\n    \"use_fm\": True,\n    \"use_deep\": True,\n    \"embedding_size\": 8,\n    \"dropout_fm\": [1.0, 1.0],\n    \"deep_layers\": [32, 32],\n    \"dropout_deep\": [0.5, 0.5, 0.5],\n    \"deep_layers_activation\": tf.nn.relu,\n    \"epoch\": 30,\n    \"batch_size\": 1024,\n    \"learning_rate\": 0.001,\n    \"optimizer_type\": \"adam\",\n    \"batch_norm\": 1,\n    \"batch_norm_decay\": 0.995,\n    \"l2_reg\": 0.01,\n    \"verbose\": True,\n    \"eval_metric\": roc_auc_score,\n    \"random_seed\": 2017\n}\n\n# prepare training and validation data in the required format\nXi_train, Xv_train, y_train = prepare(...)\nXi_valid, Xv_valid, y_valid = prepare(...)\n\n# init a DeepFM model\ndfm = DeepFM(**dfm_params)\n\n# fit a DeepFM model\ndfm.fit(Xi_train, Xv_train, y_train)\n\n# make 
prediction\ndfm.predict(Xi_valid, Xv_valid)\n\n# evaluate a trained model\ndfm.evaluate(Xi_valid, Xv_valid, y_valid)\n```\n\nYou can use early stopping during training as follows:\n```\ndfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)\n```\n\nYou can refit the model on the whole training and validation set as follows:\n```\ndfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)\n```\n\nYou can use only the FM part or only the DNN part by setting the parameter `use_deep` or `use_fm`, respectively, to `False`.\n\n## Regression\nThis implementation also supports regression tasks. To use DeepFM for regression, set `loss_type` to `mse`. Accordingly, you should use an `eval_metric` suited to regression, e.g., MSE or MAE.\n\n# Example\nFolder `example` includes an example usage of DeepFM/FM/DNN models for [Porto Seguro's Safe Driver Prediction competition on Kaggle](https://www.kaggle.com/c/porto-seguro-safe-driver-prediction).\n\nPlease download the data from the competition website and put it into the `example/data` folder.\n\nTo train a DeepFM model on this dataset, run\n\n```\n$ cd example\n$ python main.py\n```\nPlease see `example/DataReader.py` for how to parse the raw dataset into the required format for DeepFM.\n\n## Performance\n\n### DeepFM\n\n![dfm](example/fig/DeepFM.png)\n\n### FM\n\n![fm](example/fig/FM.png)\n\n### DNN\n\n![dnn](example/fig/DNN.png)\n\n## Some tips\n- [ ] You should tune the parameters for each model in order to get reasonable performance.\n- [ ] You can also try to ensemble these models or ensemble them with other models (e.g., XGBoost or LightGBM).\n\n# Reference\n[1] *DeepFM: A Factorization-Machine based Neural Network for CTR Prediction*, Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He.\n\n# Acknowledgments\nThis project draws inspiration from the following projects:\n- [ ] He Xiangnan's [neural_factorization_machine](https://github.com/hexiangnan/neural_factorization_machine)\n- 
[ ] Jian Zhang's [YellowFin](https://github.com/JianGoForIt/YellowFin) (yellowfin optimizer is taken from here)\n\n# License\nMIT","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fchenglongchen%2Ftensorflow-deepfm","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fchenglongchen%2Ftensorflow-deepfm","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fchenglongchen%2Ftensorflow-deepfm/lists"}