The Data Mining Process
A data mining task can be broken down into the following six steps:
- 1. Data preprocessing
- 2. Feature transformation
- 3. Feature selection
- 4. Model training
- 5. Model prediction
- 6. Evaluation of the predictions
Data Preparation
Here is a dataset of 20 people (hobby.csv) from different regions, of different sexes, heights, weights, and so on, each with a hobby:
id,hobby,sex,address,age,height,weight
1,football,male,dalian,12,168,55
2,pingpang,female,yangzhou,21,163,60
3,football,male,dalian,,172,70
4,football,female,,13,167,58
5,pingpang,female,shanghai,63,170,64
6,football,male,dalian,30,177,76
7,basketball,male,shanghai,25,181,90
8,football,male,dalian,15,172,71
9,basketball,male,shanghai,25,179,80
10,pingpang,male,shanghai,55,175,72
11,football,male,dalian,13,169,55
12,pingpang,female,yangzhou,22,164,61
13,football,male,dalian,23,170,71
14,football,female,,12,164,55
15,pingpang,female,shanghai,64,169,63
16,football,male,dalian,30,177,76
17,basketball,male,shanghai,22,180,80
18,football,male,dalian,16,173,72
19,basketball,male,shanghai,23,176,73
20,pingpang,male,shanghai,56,171,71
- Task analysis: predict a person's hobby from the five features sex, address, age, height, and weight.
Data Preprocessing
Before we can read any data, we first need to create a SparkSession.
Defining the SparkSession
Build it with SparkSession.builder(), then set appName and master, and finish with getOrCreate():
// Create the SparkSession
val spark = SparkSession.builder()
  .appName("兴趣预测")
  .master("local[*]")
  .getOrCreate()
Reading the Data
Read the data with spark.read: specify the format as "csv", treat the first row as the header, and give the file path:
val df = spark.read.format("csv")
  .option("header", true)
  .load("C:/Users/35369/Desktop/hobby.csv")
Inspect the data with df.show() and df.printSchema():
df.show()
df.printSchema()
spark.stop() // stop the SparkSession
Output:
+---+----------+------+--------+----+------+------+
| id|     hobby|   sex| address| age|height|weight|
+---+----------+------+--------+----+------+------+
|  1|  football|  male|  dalian|  12|   168|    55|
|  2|  pingpang|female|yangzhou|  21|   163|    60|
|  3|  football|  male|  dalian|null|   172|    70|
|  4|  football|female|    null|  13|   167|    58|
|  5|  pingpang|female|shanghai|  63|   170|    64|
|  6|  football|  male|  dalian|  30|   177|    76|
|  7|basketball|  male|shanghai|  25|   181|    90|
|  8|  football|  male|  dalian|  15|   172|    71|
|  9|basketball|  male|shanghai|  25|   179|    80|
| 10|  pingpang|  male|shanghai|  55|   175|    72|
| 11|  football|  male|  dalian|  13|   169|    55|
| 12|  pingpang|female|yangzhou|  22|   164|    61|
| 13|  football|  male|  dalian|  23|   170|    71|
| 14|  football|female|    null|  12|   164|    55|
| 15|  pingpang|female|shanghai|  64|   169|    63|
| 16|  football|  male|  dalian|  30|   177|    76|
| 17|basketball|  male|shanghai|  22|   180|    80|
| 18|  football|  male|  dalian|  16|   173|    72|
| 19|basketball|  male|shanghai|  23|   176|    73|
| 20|  pingpang|  male|shanghai|  56|   171|    71|
+---+----------+------+--------+----+------+------+
root
 |-- id: string (nullable = true)
 |-- hobby: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- address: string (nullable = true)
 |-- age: string (nullable = true)
 |-- height: string (nullable = true)
 |-- weight: string (nullable = true)
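Note that every column, including the numeric ones, has been read as a string, which is why the columns are cast to proper types later on. As an aside, a minimal sketch of an alternative (not used in this post) that lets Spark infer the types at read time via the CSV reader's inferSchema option:
val dfTyped = spark.read.format("csv")
  .option("header", true)
  .option("inferSchema", true) // infer int/double column types from the data
  .load("C:/Users/35369/Desktop/hobby.csv")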
Filling in the Missing Ages
Filling a numeric column takes three steps:
(1) extract the column with its null rows dropped
(2) compute the mean of that column
(3) fill the mean back into the original table
- (1) Extract the column with its null rows dropped
val ageNaDF = df.select("age").na.drop()
ageNaDF.show()
+---+
|age|
+---+
| 12|
| 21|
| 13|
| 63|
| 30|
| 25|
| 15|
| 25|
| 55|
| 13|
| 22|
| 23|
| 12|
| 64|
| 30|
| 22|
| 16|
| 23|
| 56|
+---+
- (2) Compute the mean of that column
Look at the summary statistics of ageNaDF:
ageNaDF.describe("age").show()
Output:
+-------+-----------------+
|summary|              age|
+-------+-----------------+
|  count|               19|
|   mean|28.42105263157895|
| stddev|17.48432882286206|
|    min|               12|
|    max|               64|
+-------+-----------------+
The mean is 28.42105263157895; we need to extract it:
val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
print(mean) // 28.42105263157895
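Indexing into the describe() output by row position works, but it is brittle: it depends on the row order of the summary table. A sketch of a more direct alternative, assuming the standard avg aggregate from org.apache.spark.sql.functions:
import org.apache.spark.sql.functions.avg
// Compute the mean of the non-null ages directly
val meanAge = df.select("age").na.drop().agg(avg("age")).first().getDouble(0)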
- (3) Fill the mean back into the table. The df.na.fill() method fills null values; to restrict it to the age column, pass List("age") as the second argument:
val ageFilledDF = df.na.fill(mean,List("age"))
ageFilledDF.show()
Output:
+---+----------+------+--------+-----------------+------+------+
| id|     hobby|   sex| address|              age|height|weight|
+---+----------+------+--------+-----------------+------+------+
|  1|  football|  male|  dalian|               12|   168|    55|
|  2|  pingpang|female|yangzhou|               21|   163|    60|
|  3|  football|  male|  dalian|28.42105263157895|   172|    70|
|  4|  football|female|    null|               13|   167|    58|
|  5|  pingpang|female|shanghai|               63|   170|    64|
|  6|  football|  male|  dalian|               30|   177|    76|
|  7|basketball|  male|shanghai|               25|   181|    90|
|  8|  football|  male|  dalian|               15|   172|    71|
|  9|basketball|  male|shanghai|               25|   179|    80|
| 10|  pingpang|  male|shanghai|               55|   175|    72|
| 11|  football|  male|  dalian|               13|   169|    55|
| 12|  pingpang|female|yangzhou|               22|   164|    61|
| 13|  football|  male|  dalian|               23|   170|    71|
| 14|  football|female|    null|               12|   164|    55|
| 15|  pingpang|female|shanghai|               64|   169|    63|
| 16|  football|  male|  dalian|               30|   177|    76|
| 17|basketball|  male|shanghai|               22|   180|    80|
| 18|  football|  male|  dalian|               16|   173|    72|
| 19|basketball|  male|shanghai|               23|   176|    73|
| 20|  pingpang|  male|shanghai|               56|   171|    71|
+---+----------+------+--------+-----------------+------+------+
The null ages have now been filled with the mean.
Dropping Rows with a Null City
Since there is no sensible value to fill in for a missing city, any row whose address is null is simply deleted (a more targeted variant is sketched after the output below),
using the .na.drop() method:
val addressDf = ageFilledDF.na.drop()
addressDf.show()
Output:
+---+----------+------+--------+-----------------+------+------+
| id|     hobby|   sex| address|              age|height|weight|
+---+----------+------+--------+-----------------+------+------+
|  1|  football|  male|  dalian|               12|   168|    55|
|  2|  pingpang|female|yangzhou|               21|   163|    60|
|  3|  football|  male|  dalian|28.42105263157895|   172|    70|
|  5|  pingpang|female|shanghai|               63|   170|    64|
|  6|  football|  male|  dalian|               30|   177|    76|
|  7|basketball|  male|shanghai|               25|   181|    90|
|  8|  football|  male|  dalian|               15|   172|    71|
|  9|basketball|  male|shanghai|               25|   179|    80|
| 10|  pingpang|  male|shanghai|               55|   175|    72|
| 11|  football|  male|  dalian|               13|   169|    55|
| 12|  pingpang|female|yangzhou|               22|   164|    61|
| 13|  football|  male|  dalian|               23|   170|    71|
| 15|  pingpang|female|shanghai|               64|   169|    63|
| 16|  football|  male|  dalian|               30|   177|    76|
| 17|basketball|  male|shanghai|               22|   180|    80|
| 18|  football|  male|  dalian|               16|   173|    72|
| 19|basketball|  male|shanghai|               23|   176|    73|
| 20|  pingpang|  male|shanghai|               56|   171|    71|
+---+----------+------+--------+-----------------+------+------+
Rows 4 and 14 have been removed.
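Calling na.drop() with no arguments drops any row containing a null in any column; it works here only because age has already been filled. A sketch of a more explicit variant that checks only the address column:
val addressDf = ageFilledDF.na.drop(Seq("address")) // drop only rows where address is null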
Casting Each Column to a Sensible Type
// Adjust the schema of the DataFrame
val formatDF = addressDf.select(
col("id").cast("int"),
col("hobby").cast("String"),
col("sex").cast("String"),
col("address").cast("String"),
col("age").cast("Double"),
col("height").cast("Double"),
col("weight").cast("Double"))
formatDF.printSchema()
Output:
root
 |-- id: integer (nullable = true)
 |-- hobby: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- address: string (nullable = true)
 |-- age: double (nullable = true)
 |-- height: double (nullable = true)
 |-- weight: double (nullable = true)
With that, the data preprocessing is complete.
Feature Transformation
To make model training easier, the feature transformation step discretizes and encodes the age, weight, height, address, and sex features.
Bucketing Age
- under 18
- 18-35
- 35-60
- over 60
The Bucketizer class does the bucketing: set the input and output column names, pass the bucket boundaries defined above as the splits, and finally hand it the DataFrame to transform:
// 2.1 Bucket the age column
// Define an array of split points as the bucket boundaries
val ageSplits = Array(Double.NegativeInfinity, 18, 35, 60, Double.PositiveInfinity)
val bucketizerDF = new Bucketizer()
  .setInputCol("age")
  .setOutputCol("ageFeature")
  .setSplits(ageSplits)
  .transform(formatDF)
bucketizerDF.show()
The bucketing result:
+---+----------+------+--------+-----------------+------+------+----------+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|
+---+----------+------+--------+-----------------+------+------+----------+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|
+---+----------+------+--------+-----------------+------+------+----------+
Binarizing Height
The cutoff is 170, using the Binarizer class:
// 2.2 Binarize the height column
val heightDF = new Binarizer()
  .setInputCol("height")
  .setOutputCol("heightFeature")
  .setThreshold(170) // threshold
  .transform(bucketizerDF)
heightDF.show()
The result after binarization:
+---+----------+------+--------+-----------------+------+------+----------+-------------+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|heightFeature|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|          0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|          0.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|          1.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|          0.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|          1.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|          1.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|          1.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|          1.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|          0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|          0.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|          0.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|          0.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|          1.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|          1.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|          1.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|          1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+
Binarizing Weight
The threshold is set to 65:
// 2.3 Binarize the weight column
val weightDF = new Binarizer()
  .setInputCol("weight")
  .setOutputCol("weightFeature")
  .setThreshold(65)
  .transform(heightDF)
weightDF.show()
Encoding Sex, City, and Hobby
These three fields are strings, which machine-learning models cannot work with directly, so they also need a feature transformation (encoding):
// 2.4 Label-encode sex
val sexIndex = new StringIndexer()
  .setInputCol("sex")
  .setOutputCol("sexIndex")
  .fit(weightDF)
  .transform(weightDF)
// 2.5 Label-encode home address
val addIndex = new StringIndexer()
  .setInputCol("address")
  .setOutputCol("addIndex")
  .fit(sexIndex)
  .transform(sexIndex)
// 2.6 One-hot encode the address index
val addOneHot = new OneHotEncoder()
  .setInputCol("addIndex")
  .setOutputCol("addOneHot")
  .fit(addIndex)
  .transform(addIndex)
// 2.7 Label-encode the hobby field
val hobbyIndexDF = new StringIndexer()
  .setInputCol("hobby")
  .setOutputCol("hobbyIndex")
  .fit(addOneHot)
  .transform(addOneHot)
hobbyIndexDF.show()
Address additionally gets a one-hot encoding on top of its label encoding.
Rename the hobbyIndex column to label, since hobby serves as the label during model training.
// 2.8 Rename the column
val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex", "label")
resultDF.show()
The final result of the feature transformation (in the addOneHot column, a value like (2,[0],[1.0]) is Spark's sparse-vector notation: a 2-dimensional vector whose element at index 0 is 1.0):
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
| id|     hobby|   sex| address|              age|height|weight|ageFeature|heightFeature|weightFeature|sexIndex|addIndex|    addOneHot|label|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
|  1|  football|  male|  dalian|             12.0| 168.0|  55.0|       0.0|          0.0|          0.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  2|  pingpang|female|yangzhou|             21.0| 163.0|  60.0|       1.0|          0.0|          0.0|     1.0|     2.0|    (2,[],[])|  1.0|
|  3|  football|  male|  dalian|28.42105263157895| 172.0|  70.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  5|  pingpang|female|shanghai|             63.0| 170.0|  64.0|       3.0|          0.0|          0.0|     1.0|     1.0|(2,[1],[1.0])|  1.0|
|  6|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  7|basketball|  male|shanghai|             25.0| 181.0|  90.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
|  8|  football|  male|  dalian|             15.0| 172.0|  71.0|       0.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
|  9|basketball|  male|shanghai|             25.0| 179.0|  80.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 10|  pingpang|  male|shanghai|             55.0| 175.0|  72.0|       2.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  1.0|
| 11|  football|  male|  dalian|             13.0| 169.0|  55.0|       0.0|          0.0|          0.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 12|  pingpang|female|yangzhou|             22.0| 164.0|  61.0|       1.0|          0.0|          0.0|     1.0|     2.0|    (2,[],[])|  1.0|
| 13|  football|  male|  dalian|             23.0| 170.0|  71.0|       1.0|          0.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 15|  pingpang|female|shanghai|             64.0| 169.0|  63.0|       3.0|          0.0|          0.0|     1.0|     1.0|(2,[1],[1.0])|  1.0|
| 16|  football|  male|  dalian|             30.0| 177.0|  76.0|       1.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 17|basketball|  male|shanghai|             22.0| 180.0|  80.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 18|  football|  male|  dalian|             16.0| 173.0|  72.0|       0.0|          1.0|          1.0|     0.0|     0.0|(2,[0],[1.0])|  0.0|
| 19|basketball|  male|shanghai|             23.0| 176.0|  73.0|       1.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  2.0|
| 20|  pingpang|  male|shanghai|             56.0| 171.0|  71.0|       2.0|          1.0|          1.0|     0.0|     1.0|(2,[1],[1.0])|  1.0|
+---+----------+------+--------+-----------------+------+------+----------+-------------+-------------+--------+--------+-------------+-----+
Feature Selection
Feature transformation leaves many columns, but not all of them should be fed to the model; feature selection picks out the columns to use for machine learning.
Selecting Features
VectorAssembler() gathers the chosen columns into a single feature vector:
// 3.1 Select features
// Note: the label column must not be included among the input features,
// and the one-hot encoded address (addOneHot) is used, as in the full source below.
val vectorAssembler = new VectorAssembler()
  .setInputCols(Array("ageFeature", "heightFeature", "weightFeature", "sexIndex", "addOneHot"))
  .setOutputCol("features")
Standardizing the Features
// 3.2 Standardize the features
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("featureScaler")
  .setWithStd(true)   // scale to unit standard deviation
  .setWithMean(false) // do not center the data with the mean
Feature Screening
// Screen the features using the chi-squared test
val selector = new ChiSqSelector()
  .setFeaturesCol("featureScaler")
  .setLabelCol("label")
  .setOutputCol("featuresSelector")
Building the Logistic Regression Model and the Pipeline
// Logistic regression model
val lr = new LogisticRegression()
  .setLabelCol("label")
  .setFeaturesCol("featuresSelector")
// Build the pipeline
val pipeline = new Pipeline()
  .setStages(Array(vectorAssembler, scaler, selector, lr))
Setting Up the Grid Search for the Best Parameters
// Grid of candidate hyperparameters for the search
val params = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))         // regularization parameter
  .addGrid(selector.numTopFeatures, Array(5, 10)) // number of top features kept by the chi-squared selector
  .build()
Setting Up Cross-Validation
// Cross-validation over the parameter grid.
// hobby has three classes (football, pingpang, basketball), so a multiclass
// evaluator is used here instead of the original BinaryClassificationEvaluator.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new MulticlassClassificationEvaluator())
  .setEstimatorParamMaps(params)
  .setNumFolds(5)
Model Training and Prediction
Before training, the data has to be split into a training set and a test set:
val Array(trainDF, testDF) = resultDF.randomSplit(Array(0.8, 0.2))
The randomSplit method performs the split.
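randomSplit is random on every run; for reproducible experiments it also accepts a seed. A minimal sketch (the seed value is arbitrary):
val Array(trainDF, testDF) = resultDF.randomSplit(Array(0.8, 0.2), seed = 42L)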
- Train and predict
val model = cv.fit(trainDF)
// Model prediction
val prediction = model.bestModel.transform(testDF)
prediction.show()
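The post never reaches step 6, evaluating the predictions, because of the error reported next. For completeness, a sketch of what that step could look like, assuming the prediction DataFrame above and accuracy as the metric:
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// Step 6: evaluate the predictions on the test set
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
println(s"accuracy = ${evaluator.evaluate(prediction)}")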
Help Wanted: Error at Training Time
The call cv.fit(trainDF) throws the following error, and I could not find anything about it online. The most likely cause is the mix of Spark versions in the pom below: org.apache.spark.sql.catalyst.trees.BinaryLike only exists since Spark 3.1, so pairing spark-mllib 3.5.0 with spark-core/spark-sql 3.0.0-preview2 leaves an older catalyst on the classpath that lacks the class; aligning every Spark artifact to a single version should resolve it.
Exception in thread "main" java.lang.NoClassDefFoundError: org/apache/spark/sql/catalyst/trees/BinaryLike
at java.lang.ClassLoader.defineClass1(Native Method)
at java.lang.ClassLoader.defineClass(ClassLoader.java:756)
at java.security.SecureClassLoader.defineClass(SecureClassLoader.java:142)
at java.net.URLClassLoader.defineClass(URLClassLoader.java:473)
at java.net.URLClassLoader.access$100(URLClassLoader.java:74)
at java.net.URLClassLoader$1.run(URLClassLoader.java:369)
at java.net.URLClassLoader$1.run(URLClassLoader.java:363)
at java.security.AccessController.doPrivileged(Native Method)
at java.net.URLClassLoader.findClass(URLClassLoader.java:362)
at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
at org.apache.spark.ml.stat.SummaryBuilderImpl.summary(Summarizer.scala:251)
at org.apache.spark.ml.stat.SummaryBuilder.summary(Summarizer.scala:54)
at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:112)
at org.apache.spark.ml.feature.StandardScaler.fit(StandardScaler.scala:84)
at org.apache.spark.ml.Pipeline.$anonfun$fit$5(Pipeline.scala:151)
at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
at org.apache.spark.ml.Pipeline.$anonfun$fit$4(Pipeline.scala:151)
at scala.collection.Iterator.foreach(Iterator.scala:943)
at scala.collection.Iterator.foreach$(Iterator.scala:943)
at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
at org.apache.spark.ml.Pipeline.$anonfun$fit$2(Pipeline.scala:147)
at org.apache.spark.ml.MLEvents.withFitEvent(events.scala:130)
at org.apache.spark.ml.MLEvents.withFitEvent$(events.scala:123)
at org.apache.spark.ml.util.Instrumentation.withFitEvent(Instrumentation.scala:42)
at org.apache.spark.ml.Pipeline.$anonfun$fit$1(Pipeline.scala:133)
at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
at scala.util.Try$.apply(Try.scala:213)
at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:133)
at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:93)
at org.apache.spark.ml.Estimator.fit(Estimator.scala:59)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$7(CrossValidator.scala:174)
at scala.runtime.java8.JFunction0$mcD$sp.apply(JFunction0$mcD$sp.java:23)
at scala.concurrent.Future$.$anonfun$apply$1(Future.scala:659)
at scala.util.Success.$anonfun$map$1(Try.scala:255)
at scala.util.Success.map(Try.scala:213)
at scala.concurrent.Future.$anonfun$map$1(Future.scala:292)
at scala.concurrent.impl.Promise.liftedTree1$1(Promise.scala:33)
at scala.concurrent.impl.Promise.$anonfun$transform$1(Promise.scala:33)
at scala.concurrent.impl.CallbackRunnable.run(Promise.scala:64)
at org.sparkproject.guava.util.concurrent.MoreExecutors$SameThreadExecutorService.execute(MoreExecutors.java:293)
at scala.concurrent.impl.ExecutionContextImpl$$anon$4.execute(ExecutionContextImpl.scala:138)
at scala.concurrent.impl.CallbackRunnable.executeWithValue(Promise.scala:72)
at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete(Promise.scala:372)
at scala.concurrent.impl.Promise$KeptPromise$Kept.onComplete$(Promise.scala:371)
at scala.concurrent.impl.Promise$KeptPromise$Successful.onComplete(Promise.scala:379)
at scala.concurrent.impl.Promise.transform(Promise.scala:33)
at scala.concurrent.impl.Promise.transform$(Promise.scala:31)
at scala.concurrent.impl.Promise$KeptPromise$Successful.transform(Promise.scala:379)
at scala.concurrent.Future.map(Future.scala:292)
at scala.concurrent.Future.map$(Future.scala:292)
at scala.concurrent.impl.Promise$KeptPromise$Successful.map(Promise.scala:379)
at scala.concurrent.Future$.apply(Future.scala:659)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$6(CrossValidator.scala:182)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$4(CrossValidator.scala:172)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:198)
at scala.collection.TraversableLike.map(TraversableLike.scala:286)
at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:198)
at org.apache.spark.ml.tuning.CrossValidator.$anonfun$fit$1(CrossValidator.scala:166)
at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
at scala.util.Try$.apply(Try.scala:213)
at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
at org.apache.spark.ml.tuning.CrossValidator.fit(CrossValidator.scala:137)
at org.example.SparkML.SparkMl01$.main(SparkMl01.scala:147)
at org.example.SparkML.SparkMl01.main(SparkMl01.scala)
Caused by: java.lang.ClassNotFoundException: org.apache.spark.sql.catalyst.trees.BinaryLike
at java.net.URLClassLoader.findClass(URLClassLoader.java:387)
at java.lang.ClassLoader.loadClass(ClassLoader.java:418)
at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:355)
at java.lang.ClassLoader.loadClass(ClassLoader.java:351)
Full Source Code and pom File
package org.example.SparkML

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{Binarizer, Bucketizer, ChiSqSelector, OneHotEncoder, StandardScaler, StringIndexer, VectorAssembler}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

/**
 * The data mining process:
 * 1. Data preprocessing
 * 2. Feature transformation (encoding, ...)
 * 3. Feature selection
 * 4. Model training
 * 5. Model prediction
 * 6. Evaluation of the predictions
 */
object SparkMl01 {
  def main(args: Array[String]): Unit = {
    // Create the SparkSession
    val spark = SparkSession.builder()
      .appName("兴趣预测")
      .master("local")
      .getOrCreate()
    import spark.implicits._

    val df = spark.read.format("csv")
      .option("header", true)
      .load("C:/Users/35369/Desktop/hobby.csv")

    // 1. Data preprocessing: fill in the missing ages
    val ageNaDF = df.select("age").na.drop()
    val mean = ageNaDF.describe("age").select("age").collect()(1)(0).toString
    val ageFilledDF = df.na.fill(mean, List("age"))
    // Drop rows whose address is null
    val addressDf = ageFilledDF.na.drop()
    // Adjust the schema of the DataFrame
    val formatDF = addressDf.select(
      col("id").cast("int"),
      col("hobby").cast("String"),
      col("sex").cast("String"),
      col("address").cast("String"),
      col("age").cast("Double"),
      col("height").cast("Double"),
      col("weight").cast("Double"))

    // 2. Feature transformation
    // 2.1 Bucket the age column
    // Define an array of split points as the bucket boundaries
    val ageSplits = Array(Double.NegativeInfinity, 18, 35, 60, Double.PositiveInfinity)
    val bucketizerDF = new Bucketizer()
      .setInputCol("age")
      .setOutputCol("ageFeature")
      .setSplits(ageSplits)
      .transform(formatDF)
    // 2.2 Binarize the height column
    val heightDF = new Binarizer()
      .setInputCol("height")
      .setOutputCol("heightFeature")
      .setThreshold(170) // threshold
      .transform(bucketizerDF)
    // 2.3 Binarize the weight column
    val weightDF = new Binarizer()
      .setInputCol("weight")
      .setOutputCol("weightFeature")
      .setThreshold(65)
      .transform(heightDF)
    // 2.4 Label-encode sex
    val sexIndex = new StringIndexer()
      .setInputCol("sex")
      .setOutputCol("sexIndex")
      .fit(weightDF)
      .transform(weightDF)
    // 2.5 Label-encode home address
    val addIndex = new StringIndexer()
      .setInputCol("address")
      .setOutputCol("addIndex")
      .fit(sexIndex)
      .transform(sexIndex)
    // 2.6 One-hot encode the address index
    val addOneHot = new OneHotEncoder()
      .setInputCol("addIndex")
      .setOutputCol("addOneHot")
      .fit(addIndex)
      .transform(addIndex)
    // 2.7 Label-encode the hobby field
    val hobbyIndexDF = new StringIndexer()
      .setInputCol("hobby")
      .setOutputCol("hobbyIndex")
      .fit(addOneHot)
      .transform(addOneHot)
    // 2.8 Rename the column
    val resultDF = hobbyIndexDF.withColumnRenamed("hobbyIndex", "label")

    // 3. Feature selection
    // 3.1 Select features
    val vectorAssembler = new VectorAssembler()
      .setInputCols(Array("ageFeature", "heightFeature", "weightFeature", "sexIndex", "addOneHot"))
      .setOutputCol("features")
    // 3.2 Standardize the features
    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("featureScaler")
      .setWithStd(true)   // scale to unit standard deviation
      .setWithMean(false) // do not center the data with the mean
    // Screen the features using the chi-squared test
    val selector = new ChiSqSelector()
      .setFeaturesCol("featureScaler")
      .setLabelCol("label")
      .setOutputCol("featuresSelector")
    // Logistic regression model
    val lr = new LogisticRegression()
      .setLabelCol("label")
      .setFeaturesCol("featuresSelector")
    // Build the pipeline
    val pipeline = new Pipeline()
      .setStages(Array(vectorAssembler, scaler, selector, lr))
    // Grid of candidate hyperparameters for the search
    val params = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))         // regularization parameter
      .addGrid(selector.numTopFeatures, Array(5, 10)) // number of top features kept by the selector
      .build()
    // Cross-validation (hobby has three classes, so a multiclass evaluator is used)
    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new MulticlassClassificationEvaluator())
      .setEstimatorParamMaps(params)
      .setNumFolds(5)

    // Model training
    val Array(trainDF, testDF) = resultDF.randomSplit(Array(0.8, 0.2))
    trainDF.show()
    testDF.show()
    val model = cv.fit(trainDF) // train the model
    // val model = pipeline.fit(trainDF)
    // val prediction = model.transform(testDF)
    // prediction.show()
    // Model prediction
    // val prediction = model.bestModel.transform(testDF)
    // prediction.show()
    spark.stop()
  }
}
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.example</groupId>
    <artifactId>untitled</artifactId>
    <version>1.0-SNAPSHOT</version>
    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Keep every Spark artifact on one version; the original pom mixed
             3.0.0-preview2, 3.1.2 and 3.5.0, which is the likely cause of the
             NoClassDefFoundError above (BinaryLike only exists since Spark 3.1). -->
        <spark.version>3.5.0</spark.version>
    </properties>
    <dependencies>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>2.12.18</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.12</artifactId>
            <version>${spark.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.12</artifactId>
            <version>${spark.version}</version>
            <!--<scope>provided</scope>-->
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.12</artifactId>
            <version>${spark.version}</version>
            <!--<scope>compile</scope>-->
        </dependency>
        <!--<dependency>-->
        <!--<groupId>mysql</groupId>-->
        <!--<artifactId>mysql-connector-java</artifactId>-->
        <!--<version>8.0.16</version>-->
        <!--</dependency>-->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.12</artifactId>
            <version>${spark.version}</version>
            <!--<scope>compile</scope>-->
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-shade-plugin</artifactId>
                <version>2.4.1</version>
                <executions>
                    <execution>
                        <phase>package</phase>
                        <goals>
                            <goal>shade</goal>
                        </goals>
                        <configuration>
                            <transformers>
                                <transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
                                    <mainClass>com.xxg.Main</mainClass>
                                </transformer>
                            </transformers>
                        </configuration>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
</project>