in src/main/python/deeplearning/tf_fm_on_spark.py [0:0]
def _build_index(values):
    """Map each distinct value in ``values`` to a dense 0-based index.

    Values are numbered in ascending ``sorted`` order so the encoding is the same
    on every run; iterating a ``set`` of strings is not, because Python
    randomises string hashes per process.
    """
    return {value: index for index, value in enumerate(sorted(set(values)))}


def pre_train(env):
    """Load MovieLens ratings/movies and encode users and genres as dense indices.

    Reads ``/moviedata/ratings.csv`` and ``/moviedata/movies.csv`` (CSV with a
    header row). Every distinct ``userId`` and every distinct genre gets a 0-based
    index. The vocabulary sizes are written into ``env["algo"]``.

    Args:
        env: Configuration mapping whose ``"algo"`` entry is a mutable mapping.
            ``"users_number"`` and ``"genres_number"`` are stored into it as
            strings.

    Returns:
        A DataFrame with columns ``userId`` (user index), ``genres`` (the movie's
        genre indices joined with ``|``, as a string) and ``rating`` (float). It
        has one row per rating whose ``movieId`` also appears in movies.csv.
    """
    spark_context = SparkContext.getOrCreate()
    spark = SparkSession.builder.getOrCreate()
    rating_df = spark.read.format('csv').option('header', 'True').load('/moviedata/ratings.csv')
    movie_df = spark.read.format('csv').option('header', 'True').load('/moviedata/movies.csv')
    algo_conf = env.get("algo")

    # Users: a single collect of the distinct ids provides both the count and the
    # index, instead of also running count() over the same distinct() plan.
    # No schema is inferred when reading, so ids are strings and are ordered
    # lexicographically ("1" < "10" < "2").
    user_rows = rating_df.select('userId').distinct().collect()
    users_map = _build_index(row['userId'] for row in user_rows)
    algo_conf["users_number"] = str(len(users_map))

    # Genres: each movie's "genres" field is a '|'-separated list; index every
    # individual genre that occurs in any movie.
    genre_rows = movie_df.select("genres").distinct().collect()
    genres_map = _build_index(
        genre for row in genre_rows for genre in row['genres'].split('|')
    )
    algo_conf["genres_number"] = str(len(genres_map))

    # Inner join: ratings whose movieId is absent from movies.csv are dropped.
    # Both sides carry a movieId column, but only unambiguous columns are kept.
    joined_df = rating_df.join(movie_df, rating_df.movieId == movie_df.movieId)
    joined_df = joined_df.select(
        col('userId'), col('genres'), col('rating').cast('float').alias('rating')
    )

    # Broadcast the lookup tables; process_row reads them through `.value`.
    users_map_bc = spark_context.broadcast(users_map)
    genres_map_bc = spark_context.broadcast(genres_map)

    def process_row(row):
        """Encode one joined row as (user index, '|'-joined genre indices, rating)."""
        user_index = users_map_bc.value
        genre_index = genres_map_bc.value
        encoded_genres = "|".join(
            str(genre_index[genre]) for genre in row.genres.split("|")
        )
        return (user_index[row.userId], encoded_genres, row.rating)

    return joined_df.rdd.map(process_row).toDF(['userId', 'genres', 'rating'])