贝叶斯算法的代码实现
贝叶斯分类算法是统计学的一种分类方法,它是一类利用概率统计知识进行分类的算法。在许多场合,朴素贝叶斯(Naïve Bayes,NB)分类算法可以与决策树和神经网络分类算法相媲美,该算法能运用到大型数据库中,而且方法简单、分类准确率高、速度快。
今天教大家如何用代码实现贝叶斯算法
所需jar包
创建目录存放训练语料
创建类BayesClassifier
朴素贝叶斯分类器, 利用样本数据集计算先验概率和各个文本向量属性在分类中的条件概率,从而计算出各个概率值,最后对各个概率值进行排序,选出最大的概率值,即为所属的分类。
public class BayesClassifier
{
private TrainingDataManager tdm;//训练集管理器
private String trainnigDataPath;//训练集路径
private static double zoomFactor = 10.0f;
默认的构造器,初始化训练集
package com.lzl;
import com.lzl.ChineseSpliter;
import com.lzl.ClassConditionalProbability;
import com.lzl.PriorProbability;
import com.lzl.TrainingDataManager;
import com.lzl.StopWordsHandler;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Vector;
public BayesClassifier()
{
tdm =new TrainingDataManager();
}
计算给定的文本属性向量X在给定的分类Cj中的类条件概率
ClassConditionalProbability
连乘值
@param X 给定的文本属性向量
@param Cj 给定的类别
@return 分类条件概率连乘值,即
float calcProd(String[] X, String Cj)
{
float ret = 1.0F;
// 类条件概率连乘
for (int i = 0; i <X.length; i++)
{
String Xi = X[i];
//因为结果过小,因此在连乘之前放大10倍,这对最终结果并无影响,因为我们只是比较概率大小而已
ret *=ClassConditionalProbability.calculatePxc(Xi, Cj)*zoomFactor;
}
// 再乘以先验概率
ret *= PriorProbability.calculatePc(Cj);
return ret;
}
去掉停用词
@param text 给定的文本
@return 去停用词后结果
public String[] DropStopWords(String[] oldWords)
{
Vector<String> v1 = new Vector<String>();
for(int i=0;i<oldWords.length;++i)
{
if(StopWordsHandler.IsStopWord(oldWords[i])==false)
{//不是停用词
v1.add(oldWords[i]);
}
}
String[] newWords = new String[v1.size()];
v1.toArray(newWords);
return newWords;
}
对给定的文本进行分类
@param text 给定的文本
@return 分类结果
@SuppressWarnings("unchecked")
public String classify(String text)
{
String[] terms = null;
terms= ChineseSpliter.split(text, " ").split(" ");//中文分词处理(分词后结果可能还包含有停用词)
terms = DropStopWords(terms);//去掉停用词,以免影响分词
String[] Classes = tdm.getTraningClassifications();//分类
float probility = 0.0F;
List<ClassifyResult> crs = new ArrayList<ClassifyResult>();//分类结果
for (int i = 0; i <Classes.length; i++)
{
String Ci = Classes[i];//第i个分类
probility = calcProd(terms, Ci);//计算给定的文本属性向量terms在给定的分类Ci中的分类条件概率
//保存分类结果
ClassifyResult cr = new ClassifyResult();
cr.classification = Ci;//分类
cr.probility = probility;//关键字在分类的条件概率
System.out.println("In process.");
System.out.println(Ci + ":" + probility);
crs.add(cr);
}
//对最后概率结果进行排序
java.util.Collections.sort(crs,new Comparator()
{
public int compare(final Object o1,final Object o2)
{
final ClassifyResult m1 = (ClassifyResult) o1;
final ClassifyResult m2 = (ClassifyResult) o2;
final double ret = m1.probility - m2.probility;
if (ret < 0)
{
return 1;
}
else
{
return -1;
}
}
});
//返回概率最大的分类
return crs.get(0).classification;
}
public static void main(String[] args)
{
String text = "微软公司提出以446亿美元的价格收购雅虎中国网2月1日报道 美联社消息,微软公司提出以446亿美元现金加股票的价格收购搜索网站雅虎公司。微软提出以每股31美元的价格收购雅虎。微软的收购报价较雅虎1月31日的收盘价19.18美元溢价62%。微软公司称雅虎公司的股东可以选择以现金或股票进行交易。微软和雅虎公司在2006年底和2007年初已在寻求双方合作。而近两年,雅虎一直处于困境:市场份额下滑、运营业绩不佳、股价大幅下跌。对于力图在互联网市场有所作为的微软来说,收购雅虎无疑是一条捷径,因为双方具有非常强的互补性。(小桥)";
BayesClassifier classifier = new BayesClassifier();//构造Bayes分类器
String result = classifier.classify(text);//进行分类
System.out.println("此项属于["+result+"]");
}
}
创建类ChineseSpliter.java
中文分词器,因为对文本进行分类时,需要计算文本中每个词在各类别中出现的概率,所以需要对文本进行分词,用到了分词器
package com.lzl;
import java.io.IOException;
import jeasy.analysis.MMAnalyzer;
public class ChineseSpliter
{
/**
* 对给定的文本进行中文分词
* @param text 给定的文本
* @param splitToken 用于分割的标记,如"|"
* @return 分词完毕的文本
*/
public static String split(String text,String splitToken)
{
String result = null;
MMAnalyzer analyzer = new MMAnalyzer(); //中文分词工具
try
{
result = analyzer.segment(text, splitToken);
}
catch (IOException e)
{
e.printStackTrace();
}
return result;
}
}
创建类ClassConditionalProbability.java
类条件概率计算,这是另一个影响因子,和先验概率一起来决定最终结果
类条件概率
P(xj|cj)=( N(X=xi, C=cj
)+1 ) / ( N(C=cj)+M+V )
其中,N(X=xi, C=cj)表示类别cj中包含属性x
i的训练文本数量;N(C=cj)表示类别cj中的训练文本数量;M值用于避免
N(X=xi, C=cj)过小所引发的问题;V表示类别的总数。
条件概率
定义 设A, B是两个事件,且P(A)>0 称
P(B∣A)=P(AB)/P(A)
为在条件A下发生的条件事件B发生的条件概率。
package com.lzl;
public class ClassConditionalProbability
{
private static TrainingDataManager tdm = new TrainingDataManager();//训练语料的分类
private static final float M = 0F;
/**
* 计算类条件概率
* @param x 给定的文本属性
* @param c 给定的分类
* @return 给定条件下的类条件概率
*/
public static float calculatePxc(String x, String c)
{
float ret = 0F;
float Nxc = tdm.getCountContainKeyOfClassification(c, x);//返回给定分类c中包含关键字/词x的训练文本的数目
float Nc = tdm.getTrainingFileCountOfClassification(c);//返回训练文本集中在给定分类c下的训练文本数目
float V = tdm.getTraningClassifications().length;//所有训练文本的类别数目
ret = (Nxc + 1) / (Nc + M + V);
return ret;
}
}
创建ClassifyResult.java
分类结果,用来保存各个分类及其计算出的概率值
package com.lzl;
public class ClassifyResult
{
public double probility;//分类的概率
public String classification;//分类
public ClassifyResult()
{
this.probility = 0;
this.classification = null;
}
}
创建PriorProbability.java
先验概率计算
P(cj)=N(C=cj)/N
其中,N(C=cj)表示类别cj中的训练文本数量;
N表示训练文本集总数量。
package com.lzl;
public class PriorProbability
{
private static TrainingDataManager tdm =new TrainingDataManager();
/**
* 先验概率
* @param c 给定的分类
* @return 给定条件下的先验概率
*/
public static float calculatePc(String c)
{
float ret = 0F;
float Nc = tdm.getTrainingFileCountOfClassification(c);//返回训练文本集中在给定分类下的训练文本数目
float N = tdm.getTrainingFileCount();//返回训练文本集中所有的文本数目
ret = Nc / N;
return ret;
}
}
创建StopWordsHandler.java
停用词处理器,去掉文档中无意思的词语也是必须的一项工作,这里简单的定义了一些常见的停用词,
并根据这些常用停用词在分词时进行判断。
相当于预处理,去噪
package com.lzl;
public class StopWordsHandler
{
private static String stopWordsList[] ={"的", "我们","要","自己","之","将","“","”",",","(",")","后","应","到","某","后","个","是","位","新","一","两","在","中","或","有","更","好",""};//常用停用词
public static boolean IsStopWord(String word)
{
for(int i=0;i<stopWordsList.length;++i)
{
if(word.equalsIgnoreCase(stopWordsList[i]))
return true;
}
return false;
}
}
创建TrainingDataManager.java
训练集管理器,首先需要从训练样本集中得到假设的先验概率和给定假设下观察到不同数据的概率。
package com.lzl;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Properties;
import java.util.logging.Level;
import java.util.logging.Logger;
public class TrainingDataManager
{
private String[] traningFileClassifications;//训练语料分类集合
private File traningTextDir;//训练语料存放目录
private static String defaultPath = "F:\\studyfiles\\数据挖掘\\3\\Bayes\\BayesData\\Sample";
public TrainingDataManager()
{
traningTextDir = new File(defaultPath);
if (!traningTextDir.isDirectory())
{
throw new IllegalArgumentException("训练语料库搜索失败! [" +defaultPath + "]");
}
this.traningFileClassifications = traningTextDir.list();
}
返回训练文本类别,这个类别就是目录名
@return 训练文本类别
public String[] getTraningClassifications()
{
return this.traningFileClassifications;
}
根据训练文本类别返回这个类别下的所有训练文本路径(full path)
@param classification 给定的分类
@return 给定分类下所有文件的路径(full path)
public String[] getFilesPath(String classification)
{
File classDir = new File(traningTextDir.getPath() +File.separator +classification);
String[] ret = classDir.list();
for (int i = 0; i < ret.length; i++)
{
ret[i] = traningTextDir.getPath() +File.separator +classification +File.separator +ret[i];
}
return ret;
}
返回给定路径的文本文件内容
@param filePath 给定的文本文件路径
@return 文本内容
@throws java.io.FileNotFoundException
@throws java.io.IOException
public static String getText(String filePath) throws FileNotFoundException,IOException
{
InputStreamReader isReader =new InputStreamReader(new FileInputStream(filePath),"GBK");
BufferedReader reader = new BufferedReader(isReader);
String aline;
StringBuilder sb = new StringBuilder();
while ((aline = reader.readLine()) != null)
{
sb.append(aline + " ");
}
isReader.close();
reader.close();
return sb.toString();
}
返回训练文本集中所有的文本数目
@return 训练文本集中所有的文本数目
public int getTrainingFileCount()
{
int ret = 0;
for (int i = 0; i < traningFileClassifications.length; i++)
{
ret +=getTrainingFileCountOfClassification(traningFileClassifications[i]);
}
return ret;
}
返回训练文本集中在给定分类下的训练文本数目
@param classification 给定的分类
@return 训练文本集中在给定分类下的训练文本数目
public int getTrainingFileCountOfClassification(String classification)
{
File classDir = new File(traningTextDir.getPath() +File.separator +classification);
return classDir.list().length;
}
返回给定分类中包含关键字/词的训练文本的数目
@param classification 给定的分类
@param key 给定的关键字/词
@return 给定分类中包含关键字/词的训练文本的数目
public int getCountContainKeyOfClassification(String classification,String key)
{
int ret = 0;
try
{
String[] filePath = getFilesPath(classification);
for (int j = 0; j < filePath.length; j++)
{
String text = getText(filePath[j]);
if (text.contains(key))
{
ret++;
}
}
}
catch (FileNotFoundException ex)
{
Logger.getLogger(TrainingDataManager.class.getName()).log(Level.SEVERE, null,ex);
}
catch (IOException ex)
{
Logger.getLogger(TrainingDataManager.class.getName()).log(Level.SEVERE, null,ex);
}
return ret;
}
}
运行结果
- 点赞
- 收藏
- 关注作者
评论(0)