{"id":18549051,"url":"https://github.com/typoverflow/pytorch-crf","last_synced_at":"2025-04-09T21:32:31.075Z","repository":{"id":122589174,"uuid":"331645526","full_name":"typoverflow/pytorch-crf","owner":"typoverflow","description":"A PyTorch implementation of conditional random fields (CRF)","archived":false,"fork":false,"pushed_at":"2021-03-07T13:12:04.000Z","size":874,"stargazers_count":9,"open_issues_count":0,"forks_count":0,"subscribers_count":1,"default_branch":"master","last_synced_at":"2025-03-24T11:56:53.530Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/typoverflow.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-01-21T14:01:33.000Z","updated_at":"2024-01-13T12:17:59.000Z","dependencies_parsed_at":null,"dependency_job_id":"6158c591-f17c-47f0-a61e-7fc9caf87748","html_url":"https://github.com/typoverflow/pytorch-crf","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/typoverflow%2Fpytorch-crf","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/typoverflow%2Fpytorch-crf/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/typoverflow%2Fpytorch-crf/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/typoverflow%2Fpytorch-crf/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/typoverflow","download_url":"https://codeload.github.com/typoverflow/pytorch-c
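rf/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248114858,"owners_count":21050130,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-11-06T20:37:52.027Z","updated_at":"2025-04-09T21:32:31.062Z","avatar_url":"https://github.com/typoverflow.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# pytorch-crf\nA PyTorch implementation of a conditional random field (CRF) applied to an OCR task.\n\n## Algorithm\nThe notation in the derivation below follows Section 11.2.3 of *Statistical Learning Methods* (《统计学习方法》). Specifically, all feature functions and their weights are written in a unified form as $f_k(y_{i-1}, y_i, x, i)$ and $w_k$, where $k=1,2,3,...,K$. Summing each feature over all positions $i$ of the sequence gives the global feature $f_k(y, x)=\\sum_{i}f_k(y_{i-1}, y_i, x, i)$.\n\nThese define the global feature vector $F(y, x)$ and the weight vector $w$:\n$$F(y, x)=\\left(f_1(y, x), f_2(y, x), ...,f_K(y, x)\\right)^\\top$$\n$$w=\\left(w_1, w_2, ..., w_K\\right)^\\top$$\n\nThe conditional random field can then be written in the simplified form\n$$\n\\begin{aligned}\n    P(y|x)\u0026=\\frac 1{Z(x)}\\exp(w^\\top F(y, x))\\\\\n    Z(x)\u0026=\\sum_{y'}\\exp(w^\\top F(y', x))\n\\end{aligned}\n$$\n\n### Log-likelihood and gradient\nGiven a dataset $\\mathcal{D}=\\{(x^{1}, y^{1}), ..., (x^{j}, y^{j}), ..., (x^{J}, y^{J})\\}$, the log-likelihood is\n$$\n\\begin{aligned}\n    LL(\\mathcal{D})\u0026= \\log \\prod_{j=1}^J P(y^{j}|x^{j})\\\\\n    \u0026=\\sum_{j=1}^J\\left(w^\\top F(y^{j}, x^{j})-\\log Z(x^{j})\\right)\n\\end{aligned}\n$$\nDifferentiating with respect to $w$ gives\n$$\n\\begin{aligned}\n    \\frac {\\partial LL(\\mathcal{D})}{\\partial w} \u0026= \\sum_{j=1}^J\\left(F(y^j, x^j)-\\frac 1{Z(x^j)}\\frac {\\partial Z(x^j)}{\\partial w}\\right)\\\\\n    \u0026=\\sum_{j=1}^J\\left(F(y^j, x^j)-\\frac 1{Z(x^j)}\\sum_{y'}\\exp(w^\\top F(y', x^j))F(y', x^j)\\right)\\\\\n    \u0026=\\sum_{j=1}^J\\left(F(y^j, x^j)-\\sum_{y'}P(y'|x^j)F(y', x^j)\\right)\\\\\n    \u0026=\\sum_{j=1}^J\\left(F(y^j, x^j)-\\mathbb{E}_{y'\\sim P(y'|x^j)}\\left[F(y', x^j)\\right]\\right)\n\\end{aligned}\n$$\n\nThe derivation shows that the gradient of the log-likelihood with respect to $w$ points in the direction of the difference between the feature vector of the observed sample and its expectation under the model. After the parameters are updated by gradient ascent, the expected feature vector $\\mathbb{E}_{y'\\sim P(y'|x^j)}\\left[F(y', x^j)\\right]$ moves toward $F(y^j, x^j)$.\n\n### Learning algorithm\nThis project optimizes the parameters $w$ by gradient ascent. Each learning step has two parts: evaluating the log-likelihood, then computing its gradient.\n  + **Evaluating the log-likelihood**\n    The log-likelihood is evaluated with the forward-backward algorithm from Section 11.3.1 of *Statistical Learning Methods*. For a sequence of length $n$, the algorithm first computes the probability matrices $M_i, i=1, 2, ..., n+1$, and then defines the forward vector $\\alpha_i(\\cdot|x)$, whose entry for label $y_i$ is the unnormalized probability of the partial label sequences from position $1$ to $i$ that end with label $y_i$ at position $i$. By dynamic programming, the recursion $\\alpha^\\top_i(\\cdot|x)=\\alpha^\\top_{i-1} (\\cdot|x)M_i$ yields the forward vector at the last position, $\\alpha_n(\\cdot|x)$, and the normalization factor is then $Z(x)=\\mathbf{1}^\\top \\alpha_n(\\cdot|x)$.\n    Once $Z(x)$ is known, substituting it into the log-likelihood expression above gives the value of the log-likelihood.\n  + **Computing the gradient**\n    The derivation above gives an explicit expression for the gradient, but evaluating it requires enumerating every possible $y'$, and the number of such sequences grows exponentially with the sequence length. In the actual implementation, I use PyTorch automatic differentiation to obtain the gradient with respect to $w$.\n\n### Decoding\n+ Once the model parameters have been learned, the Viterbi algorithm finds the most likely state sequence $y^*$ for a given observation $x$.\n+ The Viterbi algorithm used here is exactly the one described on pages 231-233 of *Statistical Learning Methods*, so only a brief outline is given.\n+ Define the Viterbi variable $\\delta_i(l)$ and the backpointer variable $\\Phi_i(l)$ as\n$$\n\\begin{aligned}\n    \\delta_i(l)\u0026=\\max_{1\\leq j\\leq m}\\{\\delta_{i-1}(j)+w^\\top F_i(y_{i-1}=j, y_i=l, x)\\}\\quad\\quad l=1,2,...,m\\\\\n    \\Phi_i(l) \u0026= \\arg\\max_{1\\leq j\\leq m} \\{\\delta_{i-1}(j)+w^\\top F_i(y_{i-1}=j, y_i=l, x)\\}\\quad\\quad l=1,2,...,m\n\\end{aligned}\n$$\nFor each label $l=1,2,...,m$ at position $i$, these are the maximum unnormalized log-probability over all paths from the first position to position $i$ that end in $l$, and the predecessor label that attains this maximum. The Viterbi variables can be computed iteratively by dynamic programming.\n+ Once the Viterbi variables at the last position are known, backtracking from the largest value recovers the most likely path.\n\n---\n\n## Project layout\nThe directory contains three .py files: utils.py defines the data-loading functions; feature\\_functions.py implements three kinds of feature functions behind a unified calling interface; CRF.py is the main module and defines the class CRF\\_OCR used for OCR recognition.\n\nTo run the project, execute `python3 src/CRF.py` from the repository root. This loads the dataset, trains the model, and evaluates it on the test set.\n\n## Results\n\u003ccenter\u003e\u003cimg src=\"./figs/fig2.png\"\u003e\u003c/center\u003e\n\n\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ftypoverflow%2Fpytorch-crf","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Ftypoverflow%2Fpytorch-crf","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Ftypoverflow%2Fpytorch-crf/lists"}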