backend/app/utils/Neo4jHelper.scala (269 lines of code) (raw):

package utils

import org.neo4j.driver.v1.Values.parameters

import java.util.UUID
import java.util.concurrent.CompletionStage
import org.neo4j.driver.v1.exceptions.{Neo4jException, NoSuchRecordException, TransientException}
import org.neo4j.driver.v1._
import org.neo4j.driver.v1.types.TypeSystem
import play.api.Logger
import services.Neo4jQueryLoggingConfig
import utils.attempt.{Attempt, Failure, Neo4JFailure, Neo4JTransientFailure, NotFoundFailure, UnknownFailure}

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import scala.jdk.CollectionConverters._
import utils.Logging

/**
  * Helpers for running Neo4J transactions through the supplied driver, with query logging and
  * conversion of errors into `Failure` values.
  *
  * @param driver             the Neo4J driver used to open sessions
  * @param executionContext   the execution context on which [[attemptTransaction]] runs its work
  * @param queryLoggingConfig settings read by [[LoggingTransaction]] (`logAllQueries`, `slowQueryThreshold`)
  */
class Neo4jHelper(driver: Driver, executionContext: ExecutionContext, queryLoggingConfig: Neo4jQueryLoggingConfig) extends Logging {
  val slowQueryLogger = Logger("slowqueries")

  /**
    * Helper that wraps a Neo4J operation and converts any uncaught exceptions to Neo4JFailures.
    *
    * @param f the operation being called
    * @return an Attempt containing the result (or failure)
    */
  def attemptNeo4J(f: => StatementResult)(): Attempt[StatementResult] = Attempt.catchNonFatal(f) {
    case NonFatal(t) => Neo4JFailure(t)
  }

  /**
    * This is a useful wrapper around a Neo4J transaction. It has most of the same methods, but instead of blocking
    * and eventually returning StatementResult objects, it returns a (synchronous) Attempt[StatementResult].
    *
    * @param underlying       the statement runner every `run` overload delegates to
    * @param executionContext not used by any method of this class
    */
  class AttemptWrappedTransaction(underlying: StatementRunner, executionContext: ExecutionContext) {
    def run(statementTemplate: String, statementParameters: Record): Attempt[StatementResult] =
      attemptNeo4J({ underlying.run(statementTemplate, statementParameters) })()

    def run(statementTemplate: String): Attempt[StatementResult] =
      attemptNeo4J({ underlying.run(statementTemplate) })()

    def run(statementTemplate: String, parameters: Value): Attempt[StatementResult] =
      attemptNeo4J({ underlying.run(statementTemplate, parameters) })()

    def run(statementTemplate: String, statementParameters: java.util.Map[String, AnyRef]): Attempt[StatementResult] =
      attemptNeo4J({ underlying.run(statementTemplate, statementParameters) })()

    def run(statement: Statement): Attempt[StatementResult] =
      attemptNeo4J({ underlying.run(statement) })()

    /**
      * Convenience overload taking the parameters as Scala `(name, value)` pairs.
      *
      * @param statementTemplate the statement text
      * @param parameters        name/value pairs, flattened into the alternating key, value sequence
      *                          passed to `Values.parameters`
      */
    def run(statementTemplate: String, parameters: (String, AnyRef)*): Attempt[StatementResult] = {
      val flattenedParameters = parameters.flatMap { case (k, v) => Vector[AnyRef](k, v) }
      val parametersObj = Values.parameters(flattenedParameters: _*)

      run(statementTemplate, parametersObj)
    }
  }

  /**
    * This is a wrapper around a Neo4J transaction that logs slow queries.
    *
    * @param underlying the transaction every call is delegated to
    * @param config     supplies `logAllQueries` and `slowQueryThreshold`
    */
  class LoggingTransaction(underlying: Transaction, config: Neo4jQueryLoggingConfig) extends Transaction {
    // a UUID to uniquely identify this transaction if it is slow
    private lazy val uuid = UUID.randomUUID().toString
    // collect all of the statements run in this transaction; stored as thunks so the description
    // string is only built if it is actually logged
    private val statements = mutable.Buffer.empty[() => String]

    /**
      * Runs `f`, timing it and logging according to `config`.
      *
      * @param logData description of the statement (by-name, evaluated only when a message is logged)
      * @param f       the blocking statement execution
      * @return whatever `f` returns; nothing is logged after `f` if it throws
      */
    def logNeo4J(logData: => String)(f: => StatementResult): StatementResult = {
      statements += (() => logData)

      val start = System.currentTimeMillis()
      if (config.logAllQueries) {
        logger.info(s"$uuid START NEO4J QUERY: $logData")
      }

      val result = f

      val timeTaken = System.currentTimeMillis() - start
      if (timeTaken >= config.slowQueryThreshold.toMillis) {
        slowQueryLogger.warn(s"$uuid SLOW NEO4J QUERY - ${timeTaken}ms: $logData")
      } else if (config.logAllQueries) {
        logger.info(s"$uuid FINISHED NEO4J QUERY - ${timeTaken}ms: $logData")
      }

      result
    }

    def run(statementTemplate: String, statementParameters: Record): StatementResult =
      logNeo4J(s"$statementTemplate [${statementParameters.fields().asScala}]") {
        underlying.run(statementTemplate, statementParameters)
      }

    def run(statementTemplate: String): StatementResult =
      logNeo4J(s"$statementTemplate") {
        underlying.run(statementTemplate)
      }

    def run(statementTemplate: String, parameters: Value): StatementResult =
      logNeo4J(s"$statementTemplate [$parameters]") {
        underlying.run(statementTemplate, parameters)
      }

    def run(statementTemplate: String, statementParameters: java.util.Map[String, AnyRef]): StatementResult =
      logNeo4J(s"$statementTemplate [${statementParameters.asScala}]") {
        underlying.run(statementTemplate, statementParameters)
      }

    def run(statement: Statement): StatementResult =
      logNeo4J(s"$statement") {
        underlying.run(statement)
      }

    // TODO MRB: log slow async queries? (we would have to wait for all the results to complete)
    def runAsync(statementTemplate: String, statementParameters: Record): CompletionStage[StatementResultCursor] =
      underlying.runAsync(statementTemplate, statementParameters)

    def runAsync(statementTemplate: String): CompletionStage[StatementResultCursor] =
      underlying.runAsync(statementTemplate)

    def runAsync(statementTemplate: String, parameters: Value): CompletionStage[StatementResultCursor] =
      underlying.runAsync(statementTemplate, parameters)

    def runAsync(statementTemplate: String, statementParameters: java.util.Map[String, AnyRef]): CompletionStage[StatementResultCursor] =
      underlying.runAsync(statementTemplate, statementParameters)

    override def runAsync(statement: Statement): CompletionStage[StatementResultCursor] = {
      underlying.runAsync(statement)
    }

    override def commitAsync(): CompletionStage[Void] = underlying.commitAsync()

    override def rollbackAsync(): CompletionStage[Void] = underlying.rollbackAsync()

    override def typeSystem(): TypeSystem = underlying.typeSystem()

    override def close(): Unit = underlying.close()

    /**
      * Delegates to `underlying.success()`, timing the call and logging every statement recorded
      * so far when it is slow (or when all queries are logged).
      */
    // NOTE(review): in the 1.x driver `success()` appears to only mark the transaction and the
    // commit itself seems to happen in `close()`, so this timing may not reflect commit latency -
    // confirm against the driver documentation before relying on the "SLOW NEO4J COMMIT" log line.
    override def success(): Unit = {
      val start = System.currentTimeMillis()

      try {
        underlying.success()
      } finally {
        val timeTaken = System.currentTimeMillis() - start

        if (timeTaken >= config.slowQueryThreshold.toMillis) {
          slowQueryLogger.warn(s"$uuid SLOW NEO4J COMMIT - ${timeTaken}ms: transaction statements: ${statements.map(s => s()).mkString("\n")}")
        } else if (config.logAllQueries) {
          logger.info(s"$uuid NEO4J COMMIT - ${timeTaken}ms: transaction statements: ${statements.map(s => s()).mkString("\n")}")
        }
      }
    }

    override def failure(): Unit = underlying.failure()

    override def isOpen: Boolean = underlying.isOpen
  }

  /**
    * Begins a transaction on `session`. If that throws, the session is closed before the original
    * exception is rethrown, so a failed `beginTransaction()` does not leak the session.
    */
  private def beginTransactionOrCloseSession(session: Session): Transaction =
    try {
      session.beginTransaction()
    } catch {
      case NonFatal(t) =>
        try session.close() catch {
          // keep the original failure as the primary exception
          case NonFatal(closeFailure) => t.addSuppressed(closeFailure)
        }
        throw t
    }

  /**
    * Runs `f` inside a new session and transaction on the helper's execution context.
    *
    * The transaction is marked successful when the attempt produced a `Right` and failed when it
    * produced a `Left`; the transaction and session are closed once the future completes, whether
    * it succeeded or failed.
    *
    * @param f the work to perform against the wrapped transaction
    * @return the result of `f`, re-wrapped in an Attempt
    */
  def attemptTransaction[T](f: AttemptWrappedTransaction => Attempt[T]): Attempt[T] = {
    val session = driver.session()
    val tx = beginTransactionOrCloseSession(session)

    val future = Future {
      // do the whole of the transaction inside a future
      f(new AttemptWrappedTransaction(new LoggingTransaction(tx, queryLoggingConfig), executionContext))
    }(executionContext)
    .flatMap {
      // strip off the attempt and flatten it
      _.underlying
    }(executionContext)
    .map { either =>
      // notify the transaction whether the attempt was successful or not
      either.fold[Unit](
        _ => tx.failure(),
        _ => tx.success()
      )
      either
    }(executionContext)
    .transform { value =>
      // ensure the transaction and session are closed when we've finished
      try {
        try {
          if (tx.isOpen) tx.close()
        } finally {
          // runs even when tx.close() throws, so the session is never left open
          if (session.isOpen) session.close()
        }
      } catch {
        case NonFatal(t) => {
          logger.warn("Failed to close session", t)
          throw t
        }
      }
      value
    }(executionContext)

    // re-wrap the result in Attempt
    Attempt(future)
  }

  /**
    * Extracts the ids captured by group 1 of `pattern` from the exception message.
    *
    * Ids are parsed as `Long`; captures that are not a valid `Long` (for example an empty
    * capture) are skipped. A null message yields no ids.
    */
  private def extractLockedIds(pattern: scala.util.matching.Regex, transientException: TransientException): List[Long] =
    pattern
      .findAllMatchIn(Option(transientException.getMessage).getOrElse(""))
      .flatMap(_.group(1).toLongOption)
      .toList

  /**
    * Describes the nodes whose ids appear as `RWLock[NODE(<id>)` in the exception message.
    *
    * @param tx                 transaction used to look the nodes up
    * @param transientException the exception whose message is scanned
    * @return one description per node returned by the lookup, containing its labels and id
    */
  def getDeadlockedNodes(tx: Transaction, transientException: TransientException): List[String] = {
    val nodePattern = """RWLock\[NODE\(([0-9]*)\)""".r
    val nodeIds = extractLockedIds(nodePattern, transientException)

    val nodesStatementResult = tx.run(
      """MATCH (n) WHERE id(n) IN {nodeIds} RETURN n""",
      parameters("nodeIds", nodeIds.map(Long.box).asJava)
    )

    val nodes = nodesStatementResult
      .list()
      .asScala
      .toList
      .map(_.get("n").asNode())

    nodes.map(n => s"node (:${n.labels().asScala.toList.mkString(":")} {id: ${n.id()}})")
  }

  /**
    * Describes the relationships whose ids appear as `RWLock[RELATIONSHIP(<id>)` in the exception message.
    *
    * @param tx                 transaction used to look the relationships up
    * @param transientException the exception whose message is scanned
    * @return one description per relationship returned by the lookup, containing its start node id,
    *         type and end node id
    */
  def getDeadlockedRelationships(tx: Transaction, transientException: TransientException): List[String] = {
    val relationshipPattern = """RWLock\[RELATIONSHIP\(([0-9]*)\)""".r
    val relationshipIds = extractLockedIds(relationshipPattern, transientException)

    val relationshipsStatementResult = tx.run(
      """MATCH ()-[r]->() WHERE id(r) IN {relationshipIds} RETURN r""",
      parameters("relationshipIds", relationshipIds.map(Long.box).asJava)
    )

    val relationships = relationshipsStatementResult
      .list()
      .asScala
      .toList
      .map(_.get("r").asRelationship())

    relationships.map(r => s"relationship ({id: ${r.startNodeId()}})-[:${r.`type`()}]->({id: ${r.endNodeId()}})")
  }

  /**
    * Best-effort lookup of the relationships and nodes named in a deadlock exception, using a
    * fresh transaction on `session`.
    *
    * Any non-fatal error (including a failure to begin or close the transaction) is logged and
    * results in an empty list, so this never replaces the failure the caller is reporting.
    *
    * @return relationship descriptions followed by node descriptions, or an empty list on error
    */
  def getDeadlockedNodesAndRelationships(session: Session, transientException: TransientException): List[String] = {
    try {
      val tx = session.beginTransaction()

      try {
        List(
          getDeadlockedRelationships(tx, transientException),
          getDeadlockedNodes(tx, transientException)
        ).flatten
      } catch {
        case NonFatal(ex) =>
          tx.failure()
          throw ex
      } finally {
        tx.close()
      }
    } catch {
      case NonFatal(ex) => {
        logger.error("Error attempting to get deadlocked nodes and relationships", ex)
        List()
      }
    }
  }

  /**
    * Runs `f` synchronously inside a new session and transaction.
    *
    * The transaction is marked successful when `f` returns a `Right` and failed otherwise.
    * A `TransientException` is logged together with the deadlocked nodes and relationships and
    * returned as `Neo4JTransientFailure`; any other non-fatal exception is returned as
    * `UnknownFailure`. The transaction and session are closed on every path.
    *
    * @param f the work to perform against the (logging) transaction
    * @return the result of `f`, or a `Left` describing the exception that was caught
    */
  def transaction[T](f: StatementRunner => Either[Failure, T]): Either[Failure, T] = {
    val session = driver.session()
    val tx = beginTransactionOrCloseSession(session)

    try {
      val result = f(new LoggingTransaction(tx, queryLoggingConfig))

      if (result.isRight) {
        tx.success()
      } else {
        tx.failure()
      }

      tx.close()
      session.close()

      result
    } catch {
      case transientException: TransientException =>
        // example exception message:
        // Caught error from neo4j: LockClient[680704] can't wait on resource RWLock[NODE(4751773), hash=582829124] since => LockClient[680704] <-[:HELD_BY]- RWLock[RELATIONSHIP(10943051), hash=1236839988] <-[:WAITING_FOR]- LockClient[680747] <-[:HELD_BY]- RWLock[NODE(4751773), hash=582829124])

        // We need to close off the original transaction, because if we try and run further statements we get:
        // "Cannot run more statements in this transaction, because previous statements
        // in the transaction has failed and the transaction has been rolled back.
        // Please start a new transaction to run another statement."
        tx.failure()
        tx.close()

        val deadlockedNodesAndRelationships = getDeadlockedNodesAndRelationships(session, transientException)
        logger.info(s"""Deadlocked on ${deadlockedNodesAndRelationships.mkString(", ")}""")

        Left(Neo4JTransientFailure(transientException))

      case NonFatal(ex) =>
        tx.failure()
        // We want to get a stack trace here otherwise we only get async transport level frames from the neo4j client
        val wrappingException = new IllegalStateException(s"Caught error from neo4j: ${ex.getMessage}", ex)
        Left(UnknownFailure(wrappingException))
    } finally {
      try {
        if (tx.isOpen) tx.close()
      } finally {
        // runs even when tx.close() throws, so the session is never left open
        if (session.isOpen) session.close()
      }
    }
  }
}

/** Extension methods for checking that Neo4J results contain an expected key. */
object Neo4jHelper {

  implicit class RichRecords[T <: Iterable[Record]](result: T) {
    /**
      * @return the records when at least one of them contains `expectedKey`, otherwise a
      *         `NotFoundFailure` carrying `errorMessage`
      */
    def hasKeyOrFailure(expectedKey: String, errorMessage: String): Either[Failure, T] = {
      if (result.exists(_.containsKey(expectedKey))) {
        Right(result)
      } else {
        Left(NotFoundFailure(errorMessage))
      }
    }

    /**
      * @return the records when at least one of them contains `expectedKey`, otherwise `error`
      */
    def hasKeyOrFailure(expectedKey: String, error: Failure): Attempt[T] = {
      if (result.exists(_.containsKey(expectedKey))) {
        Attempt.Right(result)
      } else {
        Attempt.Left(error)
      }
    }
  }

  implicit class RichRecord(record: Record) {
    /** @return the record when it contains `expectedKey`, otherwise `error` */
    def hasKeyOrFailure(expectedKey: String, error: Failure): Attempt[Record] = {
      if (record.containsKey(expectedKey)) {
        Attempt.Right(record)
      } else {
        Attempt.Left(error)
      }
    }
  }

  implicit class RichStatementResult(result: StatementResult) {
    /**
      * Takes the single record of the result and checks that it contains `expectedKey`.
      *
      * @return the record; `error` when `single()` throws `NoSuchRecordException` or the key is
      *         missing; a `Neo4JFailure` when `single()` throws a `Neo4jException`
      */
    def hasKeyOrFailure(expectedKey: String, error: Failure)(implicit executionContext: ExecutionContext): Attempt[Record] = {
      Attempt.catchNonFatal(result.single()) {
        case _: NoSuchRecordException => error
        case neo4j: Neo4jException => Neo4JFailure(neo4j)
      }.flatMap {
        _.hasKeyOrFailure(expectedKey, error)
      }
    }

    /** A Scala iterator that reads directly from the underlying result. */
    def iterator: Iterator[Record] = new Iterator[Record] {
      override def hasNext: Boolean = result.hasNext

      override def next(): Record = result.next()
    }
  }
}