-- Very simple generic programming with a sum-of-products representation, based on
-- "Type-Directed Diffing of Structured Data"
-- (Miraldo et al., 2017)

open import Data.Unit
open import Data.Empty
open import Data.Product
open import Data.Sum
open import Data.List
open import Data.Nat
open import Data.Bool
open import Data.Fin
open import Relation.Nullary
open import Relation.Binary.PropositionalEquality

{--
 Codes for sums-of-products
 (fixed base types, singly-recursive)
 --}

-- Codes for the constant (base) types: one code per supported base type.
-- Both constructors share a single type signature.
data Konst : Set where
  bool nat : Konst

-- Codes for atoms: either a constant of some base type (K), or the
-- recursive variable (I) that stands for the type being defined.
data Atom : Set where
  K : Konst → Atom
  I :         Atom -- (the recursive variable)

-- A product code lists the atoms of one constructor's fields;
-- a sum code lists the product codes of all constructors.
Prod : Set
Prod = List Atom

Sum : Set
Sum = List Prod

{--
 Interpretations of codes
 (large eliminations, i.e. functions from data to types)
 --}

-- Interpretation of constant codes: each code denotes a base Agda type.
⟦_〛ₖ : Konst → Set
⟦ bool 〛ₖ = Bool
⟦ nat  〛ₖ = ℕ

-- Interpretation of atoms.
-- There are two parameters: the code, and the type X used to interpret
-- the recursive variable. Constants ignore X; the recursive variable I
-- is interpreted as X itself.
⟦_〛ₐ : Atom → (Set → Set)
⟦ K k 〛ₐ _ = ⟦ k 〛ₖ
⟦ I   〛ₐ X = X

-- The interpretation of a product [a₁, a₂, …, aₙ] is the right-nested tuple
-- 〚a₁〛 × (〚a₂〛 × (… × (〚aₙ〛 × ⊤))): each field is paired with the
-- interpretation of the remaining fields, and the empty product is ⊤.
⟦_〛ₚ : Prod → (Set → Set)
⟦ []    〛ₚ X = ⊤
⟦ α ∷ p 〛ₚ X = ⟦ α 〛ₐ X × ⟦ p 〛ₚ X

-- The interpretation of a sum [p₁, p₂, …, pₙ] is the right-nested sum
-- 〚p₁〛 ⊎ (〚p₂〛 ⊎ (… ⊎ (〚pₙ〛 ⊎ ⊥))): a value is one of the
-- constructors' field tuples, and the empty sum is ⊥ (no values).
⟦_〛ₛ : Sum → (Set → Set)
⟦ []    〛ₛ X = ⊥
⟦ π ∷ s 〛ₛ X = ⟦ π 〛ₚ X ⊎ ⟦ s 〛ₛ X

-- Fixed points: a value of Fix σ is one layer of the code σ whose recursive
-- positions (the I atoms) are themselves values of Fix σ.
data Fix (σ : Sum) : Set where
  ⟨_⟩ : ⟦ σ 〛ₛ (Fix σ) → Fix σ

-- Constructors are accessed by position, so the type of constructors for
-- the code of an n-ary sum is a set with n elements (i.e. Fin n).
Constr : Sum → Set
Constr σ = Fin (length σ)

-- typeOf looks up the product code (the field layout) of the constructor
-- at the given position in a sum code. The empty sum has no constructors,
-- so that case is absurd.
typeOf : (σ : Sum) → Constr σ → Prod
typeOf []      ()
typeOf (π ∷ _) zero    = π
typeOf (_ ∷ σ) (suc c) = typeOf σ c

-- inj turns a constructor index (C) into a constructor function:
--   its argument is the interpretation of that constructor's product code
--   (⟦ typeOf σ C 〛ₚ (Fix σ')), and its result is the interpretation of the
--   whole sum code (⟦ σ 〛ₛ (Fix σ')).
-- σ is the sum being walked, while σ' stays fixed: it is the code of the
-- recursive type that interprets I. Position zero is injected with inj₁;
-- each suc skips one constructor with inj₂.
inj : (σ : Sum) → (σ' : Sum) → (C : Constr σ)
    → ⟦ typeOf σ C 〛ₚ (Fix σ') → ⟦ σ 〛ₛ (Fix σ')
inj []      _  ()
inj (_ ∷ _) _  zero    p = inj₁ p
inj (_ ∷ σ) σ' (suc c) p = inj₂ (inj σ σ' c p)

-- The code for lists of nats: constructor 0 has no fields (nil);
-- constructor 1 has a nat field and a recursive field (cons).
listF : Sum
listF = [] ∷ (K nat ∷ I ∷ []) ∷ []

-- The list type, obtained by interpreting the list code.
-- It is one layer of the code; the recursive tail inside it has type Fix listF.
list : Set
list = ⟦ listF 〛ₛ (Fix listF)

-- The nil and cons constructors, built from the codes by inj.
-- nil selects the first alternative of listF; its product code is empty,
-- so the only field value to supply is tt.
nil : list
nil = inj listF listF nilIx tt
  where
    nilIx : Constr listF
    nilIx = zero

-- cons selects the second alternative of listF: a nat head and a recursive
-- tail. The tail must be wrapped with ⟨_⟩ to become a Fix listF.
cons : ℕ → list → list
cons h t = inj listF listF (suc zero) (h , ⟨ t ⟩ , tt)

-- The code for a pair of natural numbers: a single constructor with two
-- nat fields.
natPairF : Sum
natPairF = (K nat ∷ K nat ∷ []) ∷ []

-- The pair type, obtained by interpreting the code.
natPair : Set
natPair = ⟦ natPairF 〛ₛ (Fix natPairF)

-- The constructor for pairs of natural numbers, built from the code.
-- mkpairℕ n₁ n₂ stores n₁ as the first field and n₂ as the second.
mkpairℕ : ℕ → ℕ → natPair
mkpairℕ n₁ n₂ = inj natPairF natPairF zero (n₁ , n₂ , tt)

-- Regression check (verified by the type checker): the second field is n₂,
-- not a second copy of n₁.
mkpairℕ-check : mkpairℕ zero (suc zero) ≡ inj₁ (zero , suc zero , tt)
mkpairℕ-check = refl

-- Generic equality, indexed by code.
--   Each eq⋆ function takes a code and two values of the type obtained by
--   interpreting that code, and returns whether they are equal.
-- eqₖ compares constants by turning the base type's decidable equality
-- (_≟_, qualified because Data.Bool and Data.Nat both export one) into a Bool.
eqₖ : (κ : Konst) → ⟦ κ 〛ₖ → ⟦ κ 〛ₖ → Bool
eqₖ bool x y with x Data.Bool.≟ y
... | yes _ = true
... | no  _ = false
eqₖ nat x y with x Data.Nat.≟ y
... | yes _ = true
... | no  _ = false


-- The signatures are listed first, to support mutual recursion
--   (eqₛ calls eqₚ, which calls eqₐ, which calls eqₛ).
-- The Sum argument σ (σ' in eqₛ) is the code of the recursive type. It is
-- needed to compare the Fix σ values that appear at I positions.
eqₐ : (α : Atom) → (σ : Sum) → ⟦ α 〛ₐ (Fix σ) → ⟦ α 〛ₐ (Fix σ) → Bool
eqₚ : (π : Prod) → (σ : Sum) → ⟦ π 〛ₚ (Fix σ) → ⟦ π 〛ₚ (Fix σ) → Bool
eqₛ : (σ : Sum) → (σ' : Sum) → ⟦ σ 〛ₛ (Fix σ') → ⟦ σ 〛ₛ (Fix σ') → Bool

-- Atoms: constants go to eqₖ. Recursive values are unwrapped from ⟨_⟩ and
-- compared as sums at the recursive type's own code.
eqₐ (K k) σ x     y     = eqₖ k x y
eqₐ I     σ ⟨ x ⟩ ⟨ y ⟩ = eqₛ σ σ x y

-- Products: equal when every field is equal (the empty product is ⊤).
eqₚ []      σ _       _       = true
eqₚ (α ∷ π) σ (u , v) (w , x) = eqₐ α σ u w ∧ eqₚ π σ v x

-- Sums: equal only when both values use the same constructor (the same
-- run of inj₂s ending in inj₁) and that constructor's fields are equal.
eqₛ []      _  ()       ()
eqₛ (π ∷ _) σ' (inj₁ x) (inj₁ y) = eqₚ π σ' x y
eqₛ (_ ∷ _) _  (inj₁ _) (inj₂ _) = false
eqₛ (_ ∷ _) _  (inj₂ _) (inj₁ _) = false
eqₛ (_ ∷ σ) σ' (inj₂ x) (inj₂ y) = eqₛ σ σ' x y

-- Equality for lists: generic sum equality, instantiated at the list code
-- both for the top layer and for the recursive tails.
eqList : list → list → Bool
eqList xs ys = eqₛ listF listF xs ys

-- Equality for pairs of natural numbers: generic sum equality, instantiated
-- at the pair code.
eqℕPair : natPair → natPair → Bool
eqℕPair p q = eqₛ natPairF natPairF p q

-- The type checker will run our tests for us: each check is proved by refl,
-- so the file only type-checks if the comparison evaluates to the stated Bool.
-- check₁: lists that differ in their second element are unequal.
check₁ : eqList
           (cons zero (cons (suc zero) nil))
           (cons zero (cons (suc (suc zero)) nil))
         ≡ false
check₁ = refl

-- check₂: identical lists are equal.
check₂ : eqList
           (cons zero (cons (suc (suc zero)) nil))
           (cons zero (cons (suc (suc zero)) nil))
         ≡ true
check₂ = refl

-- check₃: identical pairs are equal.
check₃ : eqℕPair (mkpairℕ (suc (suc zero)) zero)
                 (mkpairℕ (suc (suc zero)) zero)
         ≡ true
check₃ = refl

-- check₄: pairs that differ in their first component are unequal.
check₄ : eqℕPair (mkpairℕ (suc (suc zero)) zero)
                 (mkpairℕ (suc (suc (suc zero))) zero)
         ≡ false
check₄ = refl