{-# LANGUAGE TypeFamilies #-}

-- |
-- Module      : AutoProof.Internal.AST
-- Copyright   : (c) Artem Mavrin, 2021
-- License     : BSD3
-- Maintainer  : artemvmavrin@gmail.com
-- Stability   : experimental
-- Portability : POSIX
--
-- Defines an abstract syntax tree class and related functions.
module AutoProof.Internal.AST
  ( -- * Abstract syntax tree class and metadata type
    AST (Root, root, children, height, size, metadata),
    ASTMetadata (ASTMetadata, getHeight, getSize),

    -- * AST functions
    subtrees,
    properSubtrees,

    -- * Helper functions for creating AST constructors
    atomicASTConstructor,
    unaryASTConstructor,
    binaryASTConstructor,
    unaryRootedASTConstructor,
    binaryRootedASTConstructor,
    ternaryRootedASTConstructor,
  )
where

import Data.Set (Set)
import qualified Data.Set as Set

-- | Structural properties of an AST bundled into one record, so that they can
-- be looked up in constant time rather than recomputed by traversal.
--
-- The derived 'Ord' instance compares 'getHeight' first and breaks ties with
-- 'getSize' (the declaration order of the fields).
data ASTMetadata = ASTMetadata
  { -- | The AST's height (see also 'height').
    getHeight :: !Int,
    -- | The AST's size (see also 'size').
    getSize :: !Int
  }
  deriving (Eq, Ord)

-- | Class of abstract syntax trees. Every node exposes the value stored at it,
-- its immediate children, and an 'ASTMetadata' record.
class AST t where
  -- | Type of the value stored at each node.
  type Root t

  -- | Value stored at the root node of the AST.
  root :: t -> Root t

  -- | Immediate child ASTs of the root node.
  children :: t -> [t]

  -- | Metadata record for the AST.
  metadata :: t -> ASTMetadata

  -- | Number of edges on the longest path from the root of the AST to a leaf.
  -- By default this is read from the 'metadata' record.
  height :: t -> Int
  height tree = getHeight (metadata tree)

  -- | Number of nodes in the AST.
  -- By default this is read from the 'metadata' record.
  size :: t -> Int
  size tree = getSize (metadata tree)

-- Helper functions for constructing ASTs

-- | Turn a metadata-taking constructor into a leaf constructor. The given
-- function is applied to the metadata of a single-node AST (height 0,
-- size 1), and the remaining argument is left for the caller.
atomicASTConstructor :: (ASTMetadata -> a -> t) -> a -> t
atomicASTConstructor mk = mk atomicMetadata

-- | Turn a metadata-taking constructor into a one-child constructor. The
-- metadata passed to the given function is derived from the child: one
-- level taller and one node larger.
unaryASTConstructor :: AST t => (ASTMetadata -> t -> t) -> t -> t
unaryASTConstructor mk child = mk meta child
  where
    meta = unaryMetadata child

-- | Turn a metadata-taking constructor into a two-child constructor. The
-- metadata passed to the given function is one level taller than the taller
-- child and one node larger than both children together.
binaryASTConstructor :: AST t => (ASTMetadata -> t -> t -> t) -> t -> t -> t
binaryASTConstructor mk lhs rhs = mk meta lhs rhs
  where
    meta = binaryMetadata lhs rhs

-- | Like 'unaryASTConstructor', but the built node also carries a root value.
-- The root value is forwarded unchanged and plays no part in the metadata,
-- which is derived from the single child alone.
unaryRootedASTConstructor ::
  AST t =>
  (ASTMetadata -> a -> t -> t) ->
  a ->
  t ->
  t
unaryRootedASTConstructor mk label child = mk meta label child
  where
    meta = unaryMetadata child

-- | Like 'binaryASTConstructor', but the built node also carries a root value.
-- The root value is forwarded unchanged; the metadata depends only on the two
-- children.
binaryRootedASTConstructor ::
  AST t =>
  (ASTMetadata -> a -> t -> t -> t) ->
  a ->
  t ->
  t ->
  t
binaryRootedASTConstructor mk label lhs rhs = mk meta label lhs rhs
  where
    meta = binaryMetadata lhs rhs

-- | Build a constructor for a node that carries a root value and has three
-- children. The root value is forwarded unchanged. The metadata handed to
-- the given function is one level taller than the tallest child and one node
-- larger than the three children combined.
ternaryRootedASTConstructor ::
  AST t =>
  (ASTMetadata -> a -> t -> t -> t -> t) ->
  a ->
  t ->
  t ->
  t ->
  t
ternaryRootedASTConstructor mk label c1 c2 c3 = mk meta label c1 c2 c3
  where
    meta = ternaryMetadata c1 c2 c3

-- Internal helper functions for computing metadata for new ASTs

-- | Metadata of an AST consisting of a single node: height 0, size 1.
atomicMetadata :: ASTMetadata
atomicMetadata = ASTMetadata 0 1 -- fields in order: height, size

-- | Metadata of a node whose only child is the given AST: one level taller and
-- one node larger than that child.
unaryMetadata :: AST t => t -> ASTMetadata
unaryMetadata child = ASTMetadata h s
  where
    h = height child + 1
    s = size child + 1

-- | Metadata of a node with the two given children. Its height is one more
-- than the taller child's, and its size is one more than the children's sizes
-- added together.
binaryMetadata :: AST t => t -> t -> ASTMetadata
binaryMetadata lhs rhs = ASTMetadata h s
  where
    h = 1 + max (height lhs) (height rhs)
    s = 1 + size lhs + size rhs

-- | Metadata of a node with the three given children. Its height is one more
-- than the tallest child's, and its size is one more than the children's
-- sizes added together.
ternaryMetadata :: AST t => t -> t -> t -> ASTMetadata
ternaryMetadata c1 c2 c3 = ASTMetadata h s
  where
    h = 1 + (height c1 `max` height c2 `max` height c3)
    s = 1 + size c1 + size c2 + size c3

-- | @('subtrees' t)@ is the set of every subtree of the AST @t@, with @t@
-- itself included.
subtrees :: (AST t, Ord t) => t -> Set t
subtrees t = foldMap subtrees (children t) <> Set.singleton t
-- 'foldMap' over the children combines their subtree sets with the 'Set'
-- monoid (union), then @t@ is added on the right, so earlier children take
-- precedence in the left-biased unions, exactly as a right fold would.

-- | @('properSubtrees' t)@ is the set of all /proper/ subtrees of the AST
-- @t@, that is, the subtrees of its children, not counting @t@ itself.
properSubtrees :: (AST t, Ord t) => t -> Set t
properSubtrees = foldMap subtrees . children