Retrie/Elaborate.hs (105 lines of code) (raw):
-- Copyright (c) Facebook, Inc. and its affiliates.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree.
--
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
module Retrie.Elaborate
  ( defaultElaborations
  , elaborateRewritesInternal
  ) where
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import "list-t" ListT
import Data.Maybe
import Retrie.Context
import Retrie.ExactPrint
import Retrie.Expr
import Retrie.Fixity
import Retrie.GHC
import Retrie.Quantifiers
import Retrie.Rewrites
import Retrie.Subst
import Retrie.Substitution
import Retrie.SYB
import Retrie.Types
import Retrie.Universe
-- | Default set of elaboration rules.
--
-- It holds a single ad-hoc rule that rewrites an application written with
-- @$@ into an ordinary parenthesised application (@f $ x@ becomes @f (x)@).
defaultElaborations :: [RewriteSpec]
defaultElaborations = [Adhoc "forall f x. f $ x = f (x)"]
-- | Expand every rewrite into the list of its elaborated variants.
--
-- All @elaborations@ are merged into a single 'Rewriter' (via 'mkRewriter'
-- and 'foldMap'). That rewriter is applied to each rewrite by 'elaborateOne'.
-- The variants of all rewrites are concatenated in input order. When there
-- are no elaborations, the rewrites are returned untouched.
elaborateRewritesInternal
  :: FixityEnv
  -> [Rewrite Universe]  -- ^ elaboration rules
  -> [Rewrite Universe]  -- ^ rewrites whose patterns get elaborated
  -> IO [Rewrite Universe]
elaborateRewritesInternal fixityEnv elaborations rewrites
  | null elaborations = return rewrites
  | otherwise = concat <$> traverse expand rewrites
  where
    -- Shared across all rewrites, so the combined rewriter is built once.
    expand = elaborateOne fixityEnv (foldMap mkRewriter elaborations)
-- | Elaborate the pattern of a single rewrite.
--
-- The pattern is traversed top-down in the 'ListT' monad, applying
-- 'elaborate' at each node and updating the 'Context' along the way.
-- Nodes may yield several alternatives, so the traversal produces a list of
-- candidate patterns. Each candidate becomes a copy of the rewrite with that
-- pattern substituted in.
elaborateOne :: FixityEnv -> Rewriter -> Rewrite Universe -> IO [Rewrite Universe]
elaborateOne fixityEnv elaborator rr =
  map withPattern . sequenceA
    <$> transformA (qPattern rr) elaborateAll
  where
    -- Only the pattern changes; every other field of the rewrite is kept.
    withPattern pat = rr { qPattern = pat }
    -- Starting context whose rewriter is the combined elaborator.
    ctxt = emptyContext fixityEnv elaborator mempty
    -- Run the traversal and collect every alternative it produces.
    elaborateAll ast =
      toList $
        everywhereMWithContextBut topDown (const False)
          (\c i x -> lift $ updateContext c i x) elaborate ctxt ast
-- | Per-node elaboration step, dispatched on the node's runtime type.
--
-- Located expressions, statements and types go to 'elaborateImpl', and
-- patterns go to 'elaboratePat'. Any other node falls through to the 'mkM'
-- default and is returned as its only alternative.
elaborate
  :: (Data a, MonadIO m) => Context -> a -> ListT (TransformT m) a
elaborate ctxt =
  mkM onExpr `extM` onStmt `extM` onType `extM` elaboratePat ctxt
  where
    onExpr = elaborateImpl @(HsExpr GhcPs) ctxt
    onStmt = elaborateImpl @(Stmt GhcPs (LHsExpr GhcPs)) ctxt
    onType = elaborateImpl @(HsType GhcPs) ctxt
-- | Elaborate a pattern, but only when it carries a top-level location.
--
-- A location must be available at the top level so that annotations can be
-- transferred. A naked 'Pat' without one is therefore passed through
-- unchanged instead of being rewritten.
elaboratePat :: MonadIO m => Context -> LPat GhcPs -> ListT (TransformT m) (LPat GhcPs)
elaboratePat ctxt pat =
  case dLPat pat of
    Just located -> cLPat <$> elaborateImpl ctxt located
    Nothing -> return pat
-- | Elaborate a single located node using the context's rewriter
-- ('ctxtRewriter').
--
-- The alternatives are yielded in 'ListT': first the original node
-- unchanged, then one rewritten node for every match kept by 'allMatches'.
-- Matching runs on @getUnparened e@ (presumably the node with enclosing
-- parentheses stripped; TODO confirm). For each match, the template is
-- grafted in, the substitution is applied, annotations are copied over from
-- @e@, and parentheses are added where needed.
elaborateImpl
  :: forall ast m. (Data ast, ExactPrint ast, Matchable (LocatedA ast), MonadIO m)
  => Context -> LocatedA ast -> ListT (TransformT m) (LocatedA ast)
elaborateImpl ctxt e = do
  -- All matching and rewriting happens in TransformT; the finished
  -- alternatives are only handed to ListT at the end.
  elaborations <- lift $ do
    matches <- runMatcher ctxt (ctxtRewriter ctxt) (getUnparened e)
    validMatches <- allMatches ctxt matches
    forM [ (sub, tmpl) | MatchResult sub tmpl <- validMatches ] $ \(sub, Template{..}) -> do
      -- graft template into target
      t' <- graftA tTemplate
      -- substitute for quantifiers in grafted template
      r <- subst sub ctxt t'
      -- copy appropriate annotations from old expression to template
      r0 <- addAllAnnsT e r
      -- add parens to template if needed
      (mkM (parenify ctxt) `extM` parenifyT ctxt `extM` parenifyP ctxt) r0
  -- The unmodified node is always the first alternative.
  fromFoldable (e : elaborations)
-- | Run the user's 'MatchResultTransformer' on every candidate match and
-- keep all of the results that are still valid (not just the first one).
--
-- A transformed result is kept only when:
--
--   * it is still a 'MatchResult' (other constructors are dropped by the
--     pattern match), and
--   * its substitution binds every quantifier of the rewrite that produced
--     it.
--
-- Kept results are projected from 'Universe' to the caller's @ast@ type.
-- They appear in the same order as the input candidates. An empty input
-- yields an empty output.
allMatches
  :: (Matchable ast, MonadIO m)
  => Context
  -> [(Substitution, RewriterResult Universe)]
  -> TransformT m [MatchResult ast]
allMatches ctxt matchResults = do
  results <-
    forM matchResults $ \(sub, RewriterResult{..}) -> do
      result <- lift $ liftIO $ rrTransformer ctxt $ MatchResult sub rrTemplate
      return (rrQuantifiers, result)
  return
    [ project <$> result
    | (quantifiers, result@(MatchResult sub' _)) <- results
      -- Check that all quantifiers from the original rewrite have mappings
      -- in the resulting substitution. This is mostly to prevent a bad
      -- user-defined MatchResultTransformer from causing havoc.
    , all (\q -> isJust (lookupSubst q sub')) (qList quantifiers)
    ]