def wraprepo()

in eden/scm/edenscm/hgext/hgsql.py
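
A minimal sketch of the wrapping pattern used below. This assumes the
usual class-swap tail of wraprepo, which this excerpt does not show:

    class sqllocalrepo(repo.__class__):
        ...  # the methods below shadow the originals on this instance
    repo.__class__ = sqllocalrepo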


def wraprepo(repo):
    class sqllocalrepo(repo.__class__):
        def sqlconnect(self):
            if self.sqlconn:
                return

            retry = 3
            while True:
                try:
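                    # Older mysql-connector versions do not accept the
                    # ssl_disabled argument and raise AttributeError;
                    # fall back to connecting without it.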
                    try:
                        self.sqlconn = mysql.connector.connect(
                            force_ipv6=True, ssl_disabled=True, **self.sqlargs
                        )
                    except AttributeError:
                        self.sqlconn = mysql.connector.connect(
                            force_ipv6=True, **self.sqlargs
                        )

                    # The default behavior is to return byte arrays when we
                    # need strings. This custom converter returns strings.
                    self.sqlconn.set_converter_class(CustomConverter)
                    break
                except mysql.connector.errors.Error:
                    # mysql can be flaky occasionally, so do some minimal
                    # retrying.
                    retry -= 1
                    if retry == 0:
                        raise
                    time.sleep(0.2)

            waittimeout = self.ui.config("hgsql", "waittimeout")
            sqltimeout = self.ui.configint("hgsql", "sqltimeout") * 1000
            waittimeout = self.sqlconn.converter.escape("%s" % waittimeout)

            self.engine = self.ui.config("hgsql", "engine")
            self.locktimeout = self.ui.config("hgsql", "locktimeout")
            self.locktimeout = self.sqlconn.converter.escape("%s" % self.locktimeout)

            self.sqlcursor = self.sqlconn.cursor()
            self._sqlreadrows = 0
            self._sqlwriterows = 0

            # Patch sqlcursor so it updates the read write counters.
            def _fetchallupdatereadcount(orig):
                result = orig()
                self._sqlreadrows += self.sqlcursor.rowcount
                return result

            def _executeupdatewritecount(orig, sql, *args, **kwargs):
                result = orig(sql, *args, **kwargs)
                # naive way to detect "writes"
                if sql.split(" ", 1)[0].upper() in {"DELETE", "UPDATE", "INSERT"}:
                    self._sqlwriterows += self.sqlcursor.rowcount
                return result

            wrapfunction(self.sqlcursor, "fetchall", _fetchallupdatereadcount)
            wrapfunction(self.sqlcursor, "execute", _executeupdatewritecount)

            self.sqlcursor.execute("SET wait_timeout=%s" % waittimeout)
            self.sqlcursor.execute("SET SESSION MAX_STATEMENT_TIME=%s" % sqltimeout)

            if self.engine == "rocksdb":
                self.sqlcursor.execute(
                    "SET rocksdb_lock_wait_timeout=%s" % self.locktimeout
                )
            elif self.engine == "innodb":
                self.sqlcursor.execute(
                    "SET innodb_lock_wait_timeout=%s" % self.locktimeout
                )
            else:
                raise RuntimeError("unsupported hgsql.engine %s" % self.engine)

        def sqlclose(self):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.sqlcursor.close()
                self.sqlconn.close()
            self.sqlcursor = None
            self.sqlconn = None

        def sqlreporeadonlystate(self):
            NO_WRITE = 0
            MONONOKE_WRITE = 2
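            # (State 1 is not referenced here; it presumably marks the repo
            # as writable through hg.)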
            DEFAULT_REASON = "no reason was provided"
            MONONOKE_REASON = "writes are being served by Mononoke (fburl.com/mononoke)"
            query = "SELECT state, reason FROM repo_lock WHERE repo = %s"

            self.sqlconnect()

            self.sqlcursor.execute(query, (self.sqlreponame,))
            rows = self.sqlcursor.fetchall()

            if not rows:
                # If there isn't an entry for this repo, let's treat it as
                # unlocked.
                return (False, DEFAULT_REASON)

            state, reason = rows[0]

            readonly = state in (NO_WRITE, MONONOKE_WRITE)

            if reason is None:
                reason = {MONONOKE_WRITE: MONONOKE_REASON}.get(state, DEFAULT_REASON)
            else:
                reason = decodeutf8(reason)

            return (readonly, reason)

        def sqlisreporeadonly(self):
            """deprecated: use sqlreporeadonlystate() to also get the reason"""
            return self.sqlreporeadonlystate()[0]

        def _hgsqlnote(self, message):
            if self.ui.configbool("hgsql", "verbose"):
                self.ui.write_err("[hgsql] %s\n" % message)
            self.ui.debug("%s\n" % message)

        def _lockname(self, name):
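            # MySQL GET_LOCK() names are global to the server, so the repo
            # name is baked in to scope the lock to this repo.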
            lockname = "%s_%s" % (name, self.sqlreponame)
            return self.sqlconn.converter.escape(lockname)

        def _sqllock(self, name, trysync):
            """If trysync is True, try to sync the repo outside the lock so it
            stays closer to the actual repo when the lock is acquired.
            """
            lockname = self._lockname(name)
            syncinterval = float(self.ui.config("hgsql", "syncinterval") or -1)
            if syncinterval < 0:
                trysync = False

            if trysync:
                minwaittime = self.ui.configint("hgsql", "debugminsqllockwaittime")
                # SELECT GET_LOCK(...) will block. Break the lock attempt into
                # smaller attempts so we can sync in between them.
                starttime = time.time()
                locktimeout = float(self.locktimeout)
                while True:
                    elapsed = time.time() - starttime
                    if elapsed >= locktimeout:
                        raise util.Abort(
                            "timed out waiting for mysql repo lock (%s)" % lockname
                        )

                    # Sync outside the SQL lock, hoping the repo is closer to
                    # the SQL state by the time we get the lock.
                    self.pullfromdb(enforcepullfromdb=True)
                    if elapsed < minwaittime:
                        # Pretend we waited and timed out without actually
                        # taking the SQL lock. This is useful for testing.
                        time.sleep(syncinterval)
                    else:
                        # Try to acquire the SQL lock with a small timeout so
                        # "forcesync" can be executed more frequently.
                        self.sqlcursor.execute(
                            "SELECT GET_LOCK('%s', %s)" % (lockname, syncinterval)
                        )
                        result = int(self.sqlcursor.fetchall()[0][0])
                        if result == 1:
                            break
            else:
                self.sqlcursor.execute(
                    "SELECT GET_LOCK('%s', %s)" % (lockname, self.locktimeout)
                )
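                # GET_LOCK returns 1 when the lock is acquired, 0 when the
                # attempt times out, and NULL on error.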
                # cast to int to prevent passing bad sql data
                result = int(self.sqlcursor.fetchall()[0][0])
                if result != 1:
                    raise util.Abort(
                        "timed out waiting for mysql repo lock (%s)" % lockname
                    )
            self.heldlocks.add(name)

        def sqlwritelock(self, trysync=False):
            self._enforcelocallocktaken()
            self._sqllock(writelock, trysync)

        def _hassqllock(self, name, checkserver=True):
            if name not in self.heldlocks:
                return False

            if not checkserver:
                return True

            lockname = self._lockname(name)
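            # IS_USED_LOCK returns the connection id of the current holder
            # (NULL if the lock is free). Comparing it with our own
            # CONNECTION_ID() tells us whether this session holds the lock.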
            self.sqlcursor.execute("SELECT IS_USED_LOCK('%s')" % (lockname,))
            lockheldby = self.sqlcursor.fetchall()[0][0]
            if lockheldby is None:
                raise Exception("unable to check %s lock" % lockname)

            self.sqlcursor.execute("SELECT CONNECTION_ID()")
            myconnectid = self.sqlcursor.fetchall()[0][0]
            if myconnectid is None:
                raise Exception("unable to read connection id")

            return lockheldby == myconnectid

        def hassqlwritelock(self, checkserver=True):
            return self._hassqllock(writelock, checkserver)

        def _sqlunlock(self, name):
            lockname = self._lockname(name)
            self.sqlcursor.execute("SELECT RELEASE_LOCK('%s')" % (lockname,))
            self.sqlcursor.fetchall()
            self.heldlocks.discard(name)

            for callback in self.sqlpostrelease:
                callback()
            self.sqlpostrelease = []

        def sqlwriteunlock(self):
            self._enforcelocallocktaken()
            self._sqlunlock(writelock)

        def _enforcelocallocktaken(self):
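            # Enforce lock ordering: the SQL lock may only be taken while
            # the local repo lock is held (or while a sync is in progress).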
            if self._issyncing:
                return
            if self._currentlock(self._lockref):
                return
            raise error.ProgrammingError("invalid lock order")

        def transaction(self, *args, **kwargs):
            tr = super(sqllocalrepo, self).transaction(*args, **kwargs)
            if tr.count > 1:
                return tr

            validator = tr.validator

            def pretxnclose(tr):
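                # Runs at pre-txn-close: validate as before, then write the
                # pending revisions to MySQL before the local transaction
                # commits.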
                validator(tr)
                self.committodb(tr)
                del self.pendingrevs[:]

            tr.validator = pretxnclose

            def transactionabort(orig):
                del self.pendingrevs[:]
                return orig()

            wrapfunction(tr, "_abort", transactionabort)

            tr.repo = self
            return tr

        def needsyncfast(self):
            """Returns True if the local repo might be out of sync.
            False otherwise.

            Faster than needsync. But do not return bookmarks or heads.
            """
            # Calculate local checksum in background.
            localsynchashes = []
            localthread = threading.Thread(
                target=lambda results, repo: results.append(repo._localsynchash()),
                args=(localsynchashes, self),
            )
            localthread.start()
            # Let MySQL do the same calculation on its side
            sqlsynchash = self._sqlsynchash()
            localthread.join()
            return sqlsynchash != localsynchashes[0]

        def _localsynchash(self):
            refs = dict(self._bookmarks)
            refs["tip"] = self["tip"].rev()
            sha = ""
            for k, v in sorted(pycompat.iteritems(refs)):
                if k != "tip":
                    v = hex(v)
                sha = hashlib.sha1(encodeutf8("%s%s%s" % (sha, k, v))).hexdigest()
            return sha

        def _sqlsynchash(self):
            sql = """
            SET @sha := '', @id = 0;
            SELECT sha FROM (
                SELECT
                    @id := @id + 1 as id,
                    @sha := sha1(concat(@sha, name, value)) as sha
                FROM revision_references
                WHERE repo = %s AND namespace IN ('bookmarks', 'tip') ORDER BY name
            ) AS t ORDER BY id DESC LIMIT 1;
            """

            sqlresults = [
                sqlresult.fetchall()
                for sqlresult in self.sqlcursor.execute(
                    sql, (self.sqlreponame,), multi=True
                )
                if sqlresult.with_rows
            ]
            # is it a new repo with empty references?
            if sqlresults == [[]]:
                return hashlib.sha1(encodeutf8("%s%s" % ("tip", -1))).hexdigest()
            # sqlresults looks like [[('59237a7416a6a1764ea088f0bc1189ea58d5b592',)]]
            sqlsynchash = sqlresults[0][0][0]
            if len(sqlsynchash) != 40:
                raise RuntimeError("malicious SHA1 returned by MySQL: %r" % sqlsynchash)
            return decodeutf8(sqlsynchash)

        def needsync(self):
            """Returns True if the local repo is not in sync with the database.
            If it returns False, the heads and bookmarks match the database.

            Also returns bookmarks and heads.
            """
            self.sqlcursor.execute(