def pre_train()

in src/main/python/deeplearning/tf_fm_on_spark.py [0:0]


    def pre_train(env, ratings_path='/moviedata/ratings.csv',
                  movies_path='/moviedata/movies.csv'):
        """Load ratings/movies CSVs and encode them as (user, genres, rating) rows.

        Both files are read as headered CSV without schema inference, so every
        column arrives as a string. Users and genres are each mapped to dense
        integer indices ``0..n-1`` in sorted order. This keeps the encoding the
        same on every run. The vocabulary sizes are written back into
        ``env["algo"]``.

        Args:
            env: Config mapping. ``env.get("algo")`` must return a mutable dict.
                Its ``"users_number"`` and ``"genres_number"`` entries are set
                to the number of distinct users / genres, as strings.
            ratings_path: CSV containing ``userId``, ``movieId`` and ``rating``
                columns.
            movies_path: CSV containing ``movieId`` and ``genres`` columns,
                where ``genres`` is a ``|``-separated list.

        Returns:
            A DataFrame with columns ``userId`` (dense user index), ``genres``
            (``|``-joined genre indices, as a string) and ``rating`` (the
            rating column cast to float). There is one row per rating whose
            movieId matches a movie.
        """
        spark_context = SparkContext.getOrCreate()
        spark = SparkSession.builder.getOrCreate()
        rating_df = spark.read.format('csv').option('header', 'True').load(ratings_path)
        movie_df = spark.read.format('csv').option('header', 'True').load(movies_path)

        def build_index(values):
            # Number the distinct values 0..n-1 in sorted order. Sorting matters:
            # iterating a plain set of str is not reproducible across processes
            # because Python randomizes string hashing.
            return {value: idx for idx, value in enumerate(sorted(set(values)))}

        # Users. userId is a string column (no inferSchema), so the order is
        # lexicographic ("1" < "10" < "2"), which is still a valid dense index.
        user_rows = rating_df.select('userId').distinct().collect()
        users_map = build_index(row['userId'] for row in user_rows)
        # The rows are already collected, so count them locally instead of
        # launching a separate Spark count() job.
        env.get("algo")["users_number"] = str(len(users_map))

        # Genres. Each movie lists several genres joined by '|'.
        genre_rows = movie_df.select('genres').distinct().collect()
        genres_map = build_index(
            genre for row in genre_rows for genre in row['genres'].split('|'))
        env.get("algo")["genres_number"] = str(len(genres_map))

        # Inner join on movieId, keeping only the columns the encoder needs.
        joined_df = rating_df.join(movie_df, 'movieId').select(
            col('userId'), col('genres'), col('rating').cast('float').alias('rating'))

        # Ship the lookup tables to executors once instead of with every task.
        users_map_bc = spark_context.broadcast(users_map)
        genres_map_bc = spark_context.broadcast(genres_map)

        def process_row(row):
            # Runs on executors. Every userId / genre seen here was collected
            # above from the same source DataFrames, so the lookups succeed.
            genre_lookup = genres_map_bc.value
            encoded_genres = "|".join(
                str(genre_lookup[genre]) for genre in row.genres.split("|"))
            return (users_map_bc.value[row.userId], encoded_genres, row.rating)

        return joined_df.rdd.map(process_row).toDF(['userId', 'genres', 'rating'])