用 Spark 实现 TF-IDF 的一个案例

Spark 提供了机器学习框架（Spark ML）。由于 Spark 是一个分布式计算与数据处理引擎，用它实现 TF-IDF 时，分词、特征提取等操作都可以并行执行，从而极大地提高了效率。

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.feature.NGram
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.SaveMode
import spark.implicits._

// Punctuation tokens dropped during character-level tokenization of goods names.
// NOTE(review): multi-character entries such as ". " and "" can never equal a
// single-character token; they are kept so this value stays unchanged.
val stopwords = Seq("【","】","(",")",",","","[","]",". ","\"","'",":","。","|",".",".","-","/")

// Splits a goods name into single-character tokens:
//   1. removes ASCII letters, digits and whitespace,
//   2. turns every remaining char into its own lower-cased token,
//   3. drops tokens listed in `stopwords`.
// A null goods_name yields an empty array instead of failing the whole job with an NPE.
val strToList = {
  val stopwordSet = stopwords.toSet // constant-time lookup inside the per-row UDF
  udf { (goods_name: String) =>
    Option(goods_name)
      .map { name =>
        name.replaceAll("[a-zA-Z0-9\\s]+", "")
          .toCharArray
          .map(_.toString.toLowerCase)
          .filterNot(stopwordSet.contains)
      }
      .getOrElse(Array.empty[String])
  }
}
spark.udf.register("strToList", strToList)

// Load the goods table and tokenize goods_name with the strToList UDF;
// `category` is selected both as `id` and under its own name.
val df = spark.sql("select category as id,goods_id,category,img_url,goods_name,strToList(goods_name) goods_name_word  from datacenter.pdd_good_category_title_cat200up")


// Print the schema of the DataFrame just loaded (wordcountdf does not exist yet at this point).
df.printSchema()


// Bigram transformer (n = 2): reads the per-character tokens in goods_name_word
// and writes the resulting terms to the `ngrams` column.
val ngram = {
  val bigram = new NGram()
  bigram.setN(2)
  bigram.setInputCol("goods_name_word")
  bigram.setOutputCol("ngrams")
  bigram
}
val ngramDataFrame = ngram.transform(df)

// Vocabulary size: the number of distinct bigram terms over all rows,
// counted in SQL through the temp view `ct`.
ngramDataFrame.createOrReplaceTempView("ct")
val wordcountdf = spark.sql("select count(distinct goods_name_words) cw from (select explode(ngrams) goods_name_words from ct ) t")
val feature_num: Long = wordcountdf.first().getLong(0)

// Hash each bigram term into a fixed-size term-frequency vector. The bucket count is
// 30% of the distinct vocabulary size, so distinct terms will share buckets. It is
// clamped to at least 1 because HashingTF rejects a non-positive numFeatures, which
// the bare `.toInt` produced for vocabularies smaller than 4 terms.
val hashingTF = new HashingTF()
  .setInputCol("ngrams")
  .setOutputCol("rawFeatures")
  .setNumFeatures(math.max(1, (feature_num * 0.3).toInt))
val featurizedData = hashingTF.transform(ngramDataFrame)

// Persist the configured transformer. overwrite() makes the script re-runnable;
// a plain save() fails when the path already exists.
hashingTF.write.overwrite().save("/data/goods_name_wordsTF")


// Fit IDF weights on the raw term frequencies, rescale them into TF-IDF `features`,
// and store (id, ngrams, features) as a table, replacing any previous contents.
val idf = new IDF()
  .setInputCol("rawFeatures")
  .setOutputCol("features")
val idfModel = idf.fit(featurizedData)

val rescaledData = idfModel.transform(featurizedData)

rescaledData
  .select("id", "ngrams", "features")
  .write
  .mode(SaveMode.Overwrite)
  .saveAsTable("datacenter.pdd_good_category_title_words")

// Explode each document's TF-IDF vector into one (term, weight) pair per distinct bigram.
//
// A term's weight is read at the vector index HashingTF assigns to that term
// (hashingTF.indexOf, available since Spark 3.0). Zipping `ngrams` (document order,
// with duplicates) against `features.toSparse.values` (ascending index order, one entry
// per non-zero bucket) paired terms with the wrong weights and silently truncated.
// Terms that collide into the same hash bucket necessarily share a weight.
// The vector type is fully qualified: a bare `Vector` resolves to Scala's collection Vector.
val vtdf = {
  val tf = hashingTF // capture only the transformer in the UDF closure
  udf { (ngrams: Seq[String], features: org.apache.spark.ml.linalg.Vector) =>
    (for {
      terms <- Option(ngrams)
      vec <- Option(features)
    } yield terms.distinct.map(term => (term, vec(tf.indexOf(term)))))
      .getOrElse(Seq.empty[(String, Double)])
  }
}
spark.udf.register("vtdf", vtdf)
val vdf = spark.sql("select explode(vtdf(ngrams,features)) as word_weight from datacenter.pdd_good_category_title_words t ")
// Columns stay `_1` (term) and `_2` (TF-IDF weight) so the output table schema is unchanged.
vdf.select("word_weight._1","word_weight._2").write.mode(SaveMode.Overwrite).saveAsTable("datacenter.pdd_good_category_title_words_weight")
Share