def fit()

in contrib/sarplus/python/pysarplus/SARPlus.py


    def fit(self, df):
        """Main fit method for SAR.

        Expects the dataframe to have row_id and col_id index columns, i.e. columns containing
        the sequential integer indexes of the original alphanumeric user and item IDs.
        The dataframe also contains the rating and timestamp as floats; the timestamp is in
        seconds since the Epoch by default.

        Arguments:
            df (pyspark.sql.DataFrame): input dataframe containing the indexed users and items.
        """
        # threshold - cooccurrence counts below this number are filtered out of the cooccurrence matrix

        df.createOrReplaceTempView(self.f("{prefix}df_train_input"))

        if self.timedecay_formula:
            # WARNING: previously we took the last rating in the training dataframe as the
            # affinity matrix entry for each user-item pair. With time decay we instead
            # compute a sum over the ratings given by a user, so that for T = np.inf the
            # user gets the cumulative sum of ratings for an item rather than the last rating.
            #
            # Time decay: group by user-item pairs and apply the decay formula per group.
            # The time decay coefficient T is in days while the input timestamps are in seconds,
            # so dt/60/(T*24*60) = dt/(T*24*3600).
            # The following is the query we want to run:
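            # Worked example of the decay weight (illustrative values): with
            # time_decay_coefficient = 30 days, a rating given exactly 30 days before
            # latest_timestamp (dt = 30 * 24 * 3600 seconds) is weighted by
            # EXP(-log(2) * dt / (30 * 3600 * 24)) = EXP(-log(2)) = 0.5,
            # i.e. the coefficient acts as the half-life of a rating.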

            query = self.f(
                """
            SELECT
                 {col_user}, {col_item}, 
                 SUM({col_rating} * EXP(-log(2) * (latest_timestamp - CAST({col_timestamp} AS long)) / ({time_decay_coefficient} * 3600 * 24))) as {col_rating}
            FROM {prefix}df_train_input,
                 (SELECT CAST(MAX({col_timestamp}) AS long) latest_timestamp FROM {prefix}df_train_input)
            GROUP BY {col_user}, {col_item} 
            CLUSTER BY {col_user} 
            """
            )

            # replace with timedecayed version
            df = self.spark.sql(query)
        else:
            # since SQL is case-insensitive, this check needs to be case-insensitive as well
            if self.header["col_timestamp"].lower() in [
                s.name.lower() for s in df.schema
            ]:
                # de-duplicate user-item pairs by keeping only the latest rating
                query = self.f(
                    """
                SELECT {col_user}, {col_item}, {col_rating}
                FROM
                (
                SELECT
                    {col_user}, {col_item}, {col_rating}, 
                    ROW_NUMBER() OVER (PARTITION BY {col_user}, {col_item} ORDER BY {col_timestamp} DESC) latest
                FROM {prefix}df_train_input
                )
                WHERE latest = 1
                """
                )

                df = self.spark.sql(query)
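                # Worked example of the de-duplication (illustrative values): if a user rated the
                # same item at timestamps 100 and 200, ROW_NUMBER orders the two rows by timestamp
                # DESC, so only the row with timestamp 200 survives the latest = 1 filter.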

        df.createOrReplaceTempView(self.f("{prefix}df_train"))

        log.info("sarplus.fit 1/2: compute item cooccurences...")

        # compute cooccurrence above minimum threshold
        query = self.f(
            """
        SELECT A.{col_item} i1, B.{col_item} i2, COUNT(*) value
        FROM   {prefix}df_train A INNER JOIN {prefix}df_train B
               ON A.{col_user} = B.{col_user} AND A.{col_item} <= B.{col_item}
        GROUP  BY A.{col_item}, B.{col_item}
        HAVING COUNT(*) >= {threshold}
        CLUSTER BY i1, i2
        """
        )
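        # Worked example (illustrative): if users u1 and u2 both interacted with items i1 and i2
        # (with i1 <= i2), the self-join produces the pair (i1, i2) once per user, so its
        # cooccurrence count is 2; pairs with counts below the threshold are dropped by the
        # HAVING clause. Only the upper triangle (i1 <= i2) is materialized here and expanded later.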

        item_cooccurrence = self.spark.sql(query)
        item_cooccurrence.write.mode("overwrite").saveAsTable(
            self.f("{prefix}item_cooccurrence")
        )

        # compute the diagonal used later for Jaccard and Lift
        if self.similarity_type == SIM_LIFT or self.similarity_type == SIM_JACCARD:
            item_marginal = self.spark.sql(
                self.f(
                    "SELECT i1 i, value AS margin FROM {prefix}item_cooccurrence WHERE i1 = i2"
                )
            )
            item_marginal.createOrReplaceTempView(self.f("{prefix}item_marginal"))
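            # The diagonal entries (i1 = i2) are the item occurrence counts, i.e. the number of
            # users who interacted with each item; they serve as the marginals M1/M2 in the
            # Jaccard and lift formulas below.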

        if self.similarity_type == SIM_COOCCUR:
            self.item_similarity = item_cooccurrence
        elif self.similarity_type == SIM_JACCARD:
            query = self.f(
                """
            SELECT i1, i2, value / (M1.margin + M2.margin - value) AS value
            FROM {prefix}item_cooccurrence A 
                INNER JOIN {prefix}item_marginal M1 ON A.i1 = M1.i 
                INNER JOIN {prefix}item_marginal M2 ON A.i2 = M2.i
            CLUSTER BY i1, i2
            """
            )
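            # Jaccard similarity: value(i1, i2) = cooccurrence(i1, i2) /
            #   (occurrence(i1) + occurrence(i2) - cooccurrence(i1, i2))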
            self.item_similarity = self.spark.sql(query)
        elif self.similarity_type == SIM_LIFT:
            query = self.f(
                """
            SELECT i1, i2, value / (M1.margin * M2.margin) AS value
            FROM {prefix}item_cooccurrence A 
                INNER JOIN {prefix}item_marginal M1 ON A.i1 = M1.i 
                INNER JOIN {prefix}item_marginal M2 ON A.i2 = M2.i
            CLUSTER BY i1, i2
            """
            )
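            # Lift similarity: value(i1, i2) = cooccurrence(i1, i2) / (occurrence(i1) * occurrence(i2))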
            self.item_similarity = self.spark.sql(query)
        else:
            raise ValueError(
                "Unknown similarity type: {0}".format(self.similarity_type)
            )

        # store upper triangular
        log.info(
            "sarplus.fit 2/2: compute similiarity metric %s..." % self.similarity_type
        )
        self.item_similarity.write.mode("overwrite").saveAsTable(
            self.f("{prefix}item_similarity_upper")
        )

        # expand upper triangular to full matrix

        query = self.f(
            """
        SELECT i1, i2, value
        FROM
        (
          (SELECT i1, i2, value FROM {prefix}item_similarity_upper)
          UNION ALL
          (SELECT i2 i1, i1 i2, value FROM {prefix}item_similarity_upper WHERE i1 <> i2)
        )
        CLUSTER BY i1
        """
        )
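        # Worked example (illustrative): an upper-triangular row (i1=2, i2=5, value=0.4) yields
        # both (2, 5, 0.4) and (5, 2, 0.4) in the full matrix, while diagonal rows (i1 = i2)
        # are emitted only once because of the i1 <> i2 filter.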

        self.item_similarity = self.spark.sql(query)
        self.item_similarity.write.mode("overwrite").saveAsTable(
            self.f("{prefix}item_similarity")
        )

        # free space
        self.spark.sql(self.f("DROP TABLE {prefix}item_cooccurrence"))
        self.spark.sql(self.f("DROP TABLE {prefix}item_similarity_upper"))

        self.item_similarity = self.spark.table(self.f("{prefix}item_similarity"))
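
A minimal usage sketch (assumptions: a running SparkSession named spark, a SARPlus
constructor accepting these column-name and similarity-type arguments, and a hypothetical
ratings dataframe indexed_df whose user and item IDs have already been mapped to integer
indexes as described in the docstring):

    from pysarplus import SARPlus

    model = SARPlus(
        spark,
        col_user="user_id",        # hypothetical column names matching indexed_df
        col_item="item_id",
        col_rating="rating",
        col_timestamp="timestamp",
        similarity_type="jaccard",
    )
    model.fit(indexed_df)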