
Is there a simple way to visualize PySpark's LDA model (pyspark.ml.clustering.LDA)?

Calling ldamodel.transform(result).show() gives:

+--------------------+---+--------------------+--------------------+
|            filtered| id|            features|   topicDistribution|
+--------------------+---+--------------------+--------------------+
|    [problem, popul]|  0|(18054,[49,493],[...|[0.03282220322786...|
|[tyler, note, glo...|  1|(18054,[40,52,57,...|[0.00440868073429...|
|[mani, economist,...|  2|(18054,[12,17,25,...|[0.00404065731437...|
|[probabl, correct...|  3|(18054,[0,4,7,21,...|[0.00485107317270...|
|[even, popul, ass...|  4|(18054,[10,12,49,...|[0.00334279689625...|
|[sake, argument, ...|  5|(18054,[1,9,12,61...|[0.00285045818525...|
|[much, tougher, p...|  6|(18054,[27,32,49,...|[0.00485107690380...|
+--------------------+---+--------------------+--------------------+
lpt

1 Answer


This notebook helped me visualize PySpark LDA topics; it uses a D3 bubble chart to visualize the clusters. You could also use pyLDAvis for interactive topic-model visualization.
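If you want to try pyLDAvis with a Spark model, most of the work is getting the numbers into the shape pyLDAvis.prepare() expects: a k × vocabulary topic-term matrix whose rows sum to 1, the per-document topic distributions, the document lengths, the vocabulary, and the corpus-wide term frequencies. The sketch below assumes these pieces have already been collected to the driver as plain Python lists; the helper name pyldavis_inputs is made up for illustration.

```python
# Hypothetical helper: reshape collected Spark LDA outputs into the five inputs
# pyLDAvis.prepare() takes. Everything here is plain Python on the driver.
from typing import Dict, List, Sequence, Tuple


def normalize_rows(matrix: Sequence[Sequence[float]]) -> List[List[float]]:
    """Scale every row to sum to 1 (a row that sums to 0 becomes all zeros)."""
    out = []
    for row in matrix:
        total = sum(row)
        out.append([x / total if total else 0.0 for x in row])
    return out


def transpose(matrix: Sequence[Sequence[float]]) -> List[List[float]]:
    return [list(col) for col in zip(*matrix)]


def pyldavis_inputs(
    topics_matrix: Sequence[Sequence[float]],  # vocabSize x k, one column per topic
    doc_topic: Sequence[Sequence[float]],      # one topicDistribution per document
    doc_terms: Sequence[Tuple[Sequence[int], Sequence[float]]],  # sparse (indices, counts)
    vocab: Sequence[str],
) -> Dict[str, list]:
    term_frequency = [0.0] * len(vocab)
    doc_lengths = []
    for indices, counts in doc_terms:
        doc_lengths.append(sum(counts))        # total tokens in this document
        for i, c in zip(indices, counts):
            term_frequency[i] += c             # corpus-wide count of term i
    return {
        # pyLDAvis wants k x vocabSize, each topic row summing to 1
        "topic_term_dists": normalize_rows(transpose(topics_matrix)),
        "doc_topic_dists": normalize_rows(doc_topic),
        "doc_lengths": doc_lengths,
        "vocab": list(vocab),
        "term_frequency": term_frequency,
    }
```

On the Spark side, model.topicsMatrix().toArray().tolist() should give the vocabSize × k matrix (which may not be normalized, hence the normalization above), r.topicDistribution.toArray().tolist() for each collected row of the transformed DataFrame gives doc_topic, and (v.indices.tolist(), v.values.tolist()) for each sparse features vector gives doc_terms. The vocabulary would come from CountVectorizerModel.vocabulary if the features were built with CountVectorizer. The result could then be passed as pyLDAvis.prepare(**inputs). Since everything is collected to the driver, this is only practical for a modest corpus and vocabulary.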

Here is PySpark code that shows how to turn the topic distribution returned by the .transform API into DataFrame columns. I am using the Spark LDA example dataset, which is in LIBSVM format.

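# Code to train an LDA model using Spark ML
from pyspark.sql import SparkSession
from pyspark.ml.clustering import LDA
from pyspark.sql.types import DoubleType
from pyspark.sql import functions as F

# Reuse the active SparkSession (already available as `spark` in the pyspark shell)
spark = SparkSession.builder.getOrCreate()

# Load the sample LDA data in LIBSVM format
dataset = spark.read.format("libsvm").load("file:///usr/sample_lda_libsvm_data.txt")
dataset.show(truncate=False)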

Example data

+-----+---------------------------------------------------------------+
|label|features                                                       |
+-----+---------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |
|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |
|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |
|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |
|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |
|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |
|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |
|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|
|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |
|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |
+-----+---------------------------------------------------------------+
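For readers unfamiliar with the LIBSVM text format: each line is a label followed by index:value pairs, with 1-based indices in the file, and Spark's libsvm reader turns them into 0-based sparse vectors like the ones above. A minimal pure-Python sketch of that conversion follows; the function name is hypothetical, and the sample line in the usage note is one I constructed to match the first row above, not a line copied from the actual file.

```python
# Hypothetical sketch of parsing one LIBSVM line: "<label> <index>:<value> ...".
# Indices are 1-based in the file and shifted to 0-based, as Spark's reader does.
from typing import List, Tuple


def parse_libsvm_line(line: str, num_features: int) -> Tuple[float, int, List[int], List[float]]:
    label_str, *pairs = line.split()
    indices: List[int] = []
    values: List[float] = []
    for pair in pairs:
        idx_str, val_str = pair.split(":")
        indices.append(int(idx_str) - 1)  # convert 1-based file index to 0-based
        values.append(float(val_str))
    # (label, size, indices, values) mirrors the (11,[...],[...]) sparse vector display
    return float(label_str), num_features, indices, values
```

For example, parse_libsvm_line("0 1:1 2:2 3:6 5:2 6:3 7:1 8:1 11:3", 11) yields the same indices and values as the label 0.0 row in the table. Spark infers the vector size from the largest index unless the numFeatures option is set, which is why the sketch takes it as a parameter.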

Train an LDA model

# Train an LDA model with 10 topics
lda = LDA(k=10, maxIter=10)
model = lda.fit(dataset)

# Describe topics.
topics = model.describeTopics(3)
print("The topics described by their top-weighted terms:")
topics.show(truncate=False)

The topics described by their top-weighted terms:

+-----+-----------+---------------------------------------------------------------+
|topic|termIndices|termWeights                                                    |
+-----+-----------+---------------------------------------------------------------+
|0    |[4, 7, 10] |[0.10782284792565977, 0.09748059037449146, 0.09623493647157101]|
|1    |[1, 6, 9]  |[0.16755678146051728, 0.14746675884135615, 0.12291623854765772]|
|2    |[3, 10, 6] |[0.2365737123772152, 0.10497827056720986, 0.0917840535687615]  |
|3    |[1, 3, 7]  |[0.1015758016249506, 0.09974496621850018, 0.09902599541011434] |
|4    |[9, 10, 3] |[0.10479879348457938, 0.10207370742688827, 0.09818478669740321]|
|5    |[8, 5, 7]  |[0.10843493028120557, 0.0970150424500599, 0.09334497822531877] |
|6    |[8, 5, 0]  |[0.09874156962344234, 0.09654280831555884, 0.09565956823827508]|
|7    |[9, 4, 7]  |[0.11252483000458603, 0.09755087587088286, 0.09643430900592685]|
|8    |[4, 1, 2]  |[0.10994283713713536, 0.09410686873447463, 0.0937471573628509] |
|9    |[5, 4, 0]  |[0.15265940066441183, 0.14015412109446546, 0.13878634876078264]|
+-----+-----------+---------------------------------------------------------------+
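describeTopics only returns term indices. If your features came from a CountVectorizer, as the filtered/features columns in the question suggest, the indices can be mapped back to words through the fitted model's vocabulary. A small sketch, assuming the describeTopics rows have been collected and vocab is that vocabulary list (label_topics is a hypothetical name):

```python
# Hypothetical helper: turn (topic, termIndices, termWeights) rows into readable labels.
from typing import List, Sequence, Tuple


def label_topics(
    rows: Sequence[Tuple[int, Sequence[int], Sequence[float]]],
    vocab: Sequence[str],
) -> List[str]:
    labels = []
    for topic, term_indices, term_weights in rows:
        # Look up each term index in the vocabulary and show its weight
        terms = ", ".join(
            f"{vocab[i]} ({w:.3f})" for i, w in zip(term_indices, term_weights)
        )
        labels.append(f"topic {topic}: {terms}")
    return labels
```

With Spark you would build the input as rows = [(r.topic, r.termIndices, r.termWeights) for r in topics.collect()] and vocab = cv_model.vocabulary, where cv_model stands for your fitted CountVectorizerModel. The LIBSVM sample has no words, so placeholder names such as term0 to term10 are the best you can do there.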

View topic distribution for every document

# view topic distribution for every document
transformed = model.transform(dataset)
transformed.show(truncate=False)

+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features                                                       |topicDistribution                                                                                                                                                                                                       |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.004830688509084788,0.9563375886321935,0.004924669693727129,0.004830693291141946,0.004830675601199576,0.004830690970098452,0.004830731737552684,0.004830674902568036,0.004830730786933749,0.004922855875500012]       |
|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.008057778755383592,0.3149188541525326,0.00821568856074705,0.008057899973735082,0.00805773202965193,0.00805773219443841,0.00805772753178338,0.008057790266770967,0.008057845264839285,0.6204609512701176]             |
|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.004199741171245032,0.9620401773226402,0.004281469704273017,0.004199769097486346,0.004199807571784884,0.004199819505813106,0.004199835506062414,0.004199781772904878,0.004199800982100323,0.004279797365689855]       |
|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.003714896800546591,0.5070516557688054,0.4631584573147577,0.003714914880264338,0.0037150085177011572,0.003714949896828997,0.0037149846555122436,0.003714886267751718,0.003714909060953893,0.003785336836878225]       |
|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.004024716198633711,0.004348960756766257,0.9633765414688664,0.004024715826289515,0.0040247523412803785,0.004024714760590197,0.004024750967476446,0.004024750137766685,0.004024763598734582,0.004101333943595805]      |
|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.003714916720108325,0.004014106400247752,0.0037876992243613913,0.0037149522531312196,0.0037149927030871474,0.0037149587146134535,0.0037149750439419123,0.0037150099006180567,0.003714963609773339,0.9661934254301174] |
|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.003863637584067354,0.44120209378688086,0.5278152614977222,0.0038636593932357263,0.003863751204372584,0.0038636970054184935,0.003863731528120536,0.0038636169190041057,0.003863652151710295,0.003936898929468125]     |
|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.004390955723890411,0.004745014492795635,0.9600436030532219,0.004390986523517605,0.004391013571891052,0.004390968206875746,0.004391003804300225,0.004390998289212864,0.0043910030406065104,0.004474453293687847]      |
|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.004391082468515706,0.004744799620819518,0.004477230286216996,0.004391179034422902,0.004391083385391976,0.0043911102087152145,0.004391108242443274,0.0043911476110250714,0.0043911508747108575,0.9600401082677386]    |
|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.0033302167739046973,0.9698998050463385,0.0033949933226572675,0.0033302031974203014,0.0033302208173504686,0.003330228671311114,0.0033302277108795157,0.003330230056473623,0.0033302455331591036,0.0033936288705052665]|
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.0041998552715806015,0.004538086674649772,0.9617828003374762,0.0041998854155415434,0.004199964563679233,0.004199898040748559,0.004199948969028732,0.004199941207400563,0.004199894377993083,0.004279725141901989]     |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.0048305604098789244,0.005219225001032762,0.004924487214200011,0.004830543265675906,0.00483056515654878,0.004830577688731923,0.004830590528195045,0.004830599936989683,0.004830615233900232,0.9560422355648467]       |
+-----+---------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
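A common first step before plotting is to find each document's dominant topic, i.e. the index of the largest entry in topicDistribution. A minimal sketch on collected values (dominant_topic is a hypothetical helper):

```python
# Hypothetical helper: index and weight of the most probable topic in one distribution.
from typing import Sequence, Tuple


def dominant_topic(dist: Sequence[float]) -> Tuple[int, float]:
    best = max(range(len(dist)), key=lambda i: dist[i])  # first index of the maximum
    return best, dist[best]
```

With Spark, something like [(r.label,) + dominant_topic(r.topicDistribution.toArray().tolist()) for r in transformed.collect()] applies it on the driver. On Spark 3.0 and later, pyspark.ml.functions.vector_to_array can also turn the vector column into an array column, which avoids a Python UDF for element access.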

Schema of the transformed DataFrame

transformed.printSchema()
root
|-- label: double (nullable = true)
|-- features: vector (nullable = true)
|-- topicDistribution: vector (nullable = true)

As you can see, topicDistribution is a vector column. The helper function below extracts the i-th element of a vector.

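def ith_(v, i):
    # Return the i-th component of an ML vector as a Python float,
    # or None if the vector is null or the index is out of range.
    try:
        return float(v[i])
    except (IndexError, TypeError, ValueError):
        return None

# Wrap the helper as a Spark SQL UDF that returns a double
ith = F.udf(ith_, DoubleType())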

Show each document's topic distribution with one column per topic

df = transformed.select(["label"] + [ith("topicDistribution", F.lit(i)).alias('topic_'+str(i)) for i in range(10)] )
df.show(truncate=False)
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|label|topic_0              |topic_1              |topic_2              |topic_3              |topic_4              |topic_5              |topic_6              |topic_7              |topic_8              |topic_9              |
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
|0.0  |0.004830687791450502 |0.9563377999372255   |0.004830652446299898 |0.004830693203685635 |0.004924680975321234 |0.004830690324650106 |0.004830724790894176 |0.004830674545741453 |0.004830728328369402 |0.00492266765636222  |
|1.0  |0.00805777782592821  |0.3150888304586096   |0.008057821375392899 |0.008057900091752447 |0.00821563090347786  |0.008057731378987427 |0.008057716226340182 |0.00805778996991863  |0.008057841440203276 |0.6202909603293896   |
|2.0  |0.004199740539975822 |0.9620403414727842   |0.004199830281319767 |0.004199769011855544 |0.004281446354869374 |0.004199818930938506 |0.004199829456280457 |0.004199781450899189 |0.004199798835689997 |0.00427964366538733  |
|3.0  |0.003714883352496639 |0.39438266523895776  |0.0037149161634889914|0.003714899290148889 |0.5758276298046127   |0.003714939245435922 |0.0037149657297638815|0.003714878209574761 |0.0037148981104253493|0.0037853248550950695|
|4.0  |0.00402472343811409  |0.0043486720544167945|0.0040247584323080295|0.004024726616022349 |0.9633767817635327   |0.004024722506471514 |0.004024749723387701 |0.004024759068339994 |0.00402477228684825  |0.0041013341105585275|
|5.0  |0.0037149161731463167|0.00401410657859215  |0.0037150318186438148|0.003714952190974752 |0.0037876713720541993|0.003714958223027372 |0.003714969707955506 |0.0037150096299263177|0.003714961725756829 |0.9661934225799228   |
|6.0  |0.0038636235465470963|0.32506932380193027  |0.0038636563625666425|0.003863644344443025 |0.6439482136665527   |0.0038636867164242353|0.003863712160357752 |0.003863609226073573 |0.003863641557265962 |0.00393688861783849  |
|7.0  |0.004390963901259502 |0.004744419369141901 |0.004391020228883301 |0.00439099927884862  |0.9600441405838983   |0.004390977425037901 |0.004391002809855065 |0.004391008592998927 |0.004391013090740394 |0.004474454719336111 |
|8.0  |0.004391081853379135 |0.004744865767572997 |0.004391206214702098 |0.004391178993516226 |0.004477132667794462 |0.0043911096593825015|0.0043911019675074445|0.004391147323286589 |0.0043911486798455125|0.960040026873013    |
|9.0  |0.003330216240957084 |0.9698999783457445   |0.00333023738785573  |0.0033302030986131904|0.003394973102900875 |0.0033302280874212362|0.0033302228867079335|0.0033302291785187624|0.0033302391644247616|0.003393472506855918 |
|10.0 |0.004199858865711682 |0.004538534384183169 |0.004199958349762097 |0.004199894260340701 |0.9617823390796781   |0.004199903494953782 |0.0041999446501473445|0.004199945557171458 |0.004199899755712464 |0.004279721602339041 |
|11.0 |0.00483055973980833  |0.005219211145215135 |0.004830592303351509 |0.004830543225945144 |0.004924458988916403 |0.004830577090650675 |0.004830583633398643 |0.004830599625982923 |0.004830612825588896 |0.9560422614211423   |
+-----+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+---------------------+
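Many plotting libraries prefer a long layout, one (document, topic, weight) record per row, over the wide table above. A sketch that reshapes collected rows, assuming each row has been turned into a dict with row.asDict() (to_long is a hypothetical name):

```python
# Hypothetical helper: wide rows {label, topic_0, ..., topic_{k-1}} -> long records.
from typing import Any, Dict, List, Sequence, Tuple


def to_long(rows: Sequence[Dict[str, Any]], num_topics: int) -> List[Tuple[Any, int, float]]:
    records = []
    for row in rows:
        for i in range(num_topics):
            # One (label, topic index, weight) record per document/topic pair
            records.append((row["label"], i, row[f"topic_{i}"]))
    return records
```

For example, records = to_long([r.asDict() for r in df.collect()], 10). If pandas is available, df.toPandas().melt(id_vars="label", var_name="topic", value_name="weight") produces the same layout.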

You can use these results to visualize the topic distribution of each document, or each topic's top-weighted terms.
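For a quick look without any plotting library, a document's topic distribution can be rendered as text bars. This is only a dependency-free sketch (text_bars is a hypothetical name); for real charts you would pass the same numbers to a plotting library instead.

```python
# Hypothetical helper: render one topic distribution as '#' bars proportional to weight.
from typing import List, Sequence


def text_bars(dist: Sequence[float], width: int = 40, min_weight: float = 0.0) -> List[str]:
    lines = []
    for i, weight in enumerate(dist):
        if weight < min_weight:  # optionally hide near-zero topics
            continue
        bar = "#" * round(weight * width)
        lines.append(f"topic_{i:<2} {weight:6.3f} {bar}")
    return lines
```

Usage example: for line in text_bars(transformed.first().topicDistribution.toArray().tolist(), min_weight=0.05): print(line)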

raj
  • Thanks rajesh. Do you have an example to share? (This notebook is in Scala, which I do not know.) – lpt Jul 18 '18 at 17:54
  • That notebook is based on RDDs; I am trying with DataFrames. – lpt Jul 19 '18 at 00:25
  • Updated answer with example LDA code using pyspark and dataframes – raj Jul 19 '18 at 06:15
  • Thanks rajesh. This is a great help. Here is another solution as well: https://stackoverflow.com/questions/51456838/match-index-from-pyspark-dataframe-in-pandas/51457137#51457137 – lpt Jul 21 '18 at 16:20