(* syntax.sml *)

structure Syntax =
struct

local
open Error
in


(* ---------------------------------------------------------------------- *)
(* Datatypes *)

(* Simple types: the base type Bool, and arrow types TyArr (T1, T2) for
   functions from T1 to T2 (printed as "T1->T2" by printty). *)
datatype ty
  = TyBool
  | TyArr of ty * ty

(* Terms in nameless (de Bruijn) form.
   TmVar (x, n): x is the de Bruijn index (0 = innermost binder); n is the
     context length the index is expected to be interpreted in. Shifting
     adjusts n together with x, and printtm_ATerm checks n against the
     actual context length before looking the name up.
   TmAbs (x, tyT, body): x is the binder's original name, which the printer
     uses as a naming hint; tyT is the annotated parameter type.
   TmIf (c, t, e) and TmApp (t1, t2) are conditional and application. *)
datatype term
  = TmTrue
  | TmFalse
  | TmIf of term * term * term
  | TmVar of int * int
  | TmAbs of string * ty * term
  | TmApp of term * term

(* What a name in the context is bound to.
   NameBind carries no type. It is pushed when only the name matters, e.g. by
   addname and by pickfreshname when the printer enters a lambda.
   VarBind tyT records a term variable of type tyT. *)
datatype binding
  = NameBind       (* name-only binding: no type information attached *)
  | VarBind of ty

(* A context is a stack of (name, binding) pairs. The head of the list is
   the innermost binder, i.e. de Bruijn index 0 (see index2name and
   name2index). *)
type context = (string * binding) list

(* Top-level commands.
   Eval wraps a term; presumably the driver type-checks and evaluates it
   (the driver is not in this file).
   Bind introduces a name with the given binding. *)
datatype command
  = Eval of term
  | Bind of string * binding

(* equalTy (a, b): structural equality on types. Bool equals Bool; two arrow
   types are equal exactly when their domains and their codomains are
   pairwise equal; any other pairing is unequal. *)
fun equalTy (lhs, rhs) =
    case (lhs, rhs)
      of (TyBool, TyBool) => true
       | (TyArr (dom1, cod1), TyArr (dom2, cod2)) =>
           equalTy (dom1, dom2) andalso equalTy (cod1, cod2)
       | _ => false

(* ---------------------------------------------------------------------- *)
(* Context management *)

(* The context with nothing in scope. *)
val emptycontext = nil

(* Number of bindings currently in scope. *)
val ctxlength = List.length

(* addbinding (ctx, name, b): push (name, b) as the new innermost binding. *)
fun addbinding (ctx, name, b) =
    let val entry = (name, b)
    in entry :: ctx
    end

(* addname (ctx, x): push x as a name-only binding (no type attached). *)
fun addname (ctx, name) = (name, NameBind) :: ctx

(* isnamebound (ctx, x): true when some entry of ctx carries the name x. *)
fun isnamebound (ctx, x) =
    List.exists (fn (name, _) => name = x) ctx

(* pickfreshname (ctx, x): choose a name based on x that is not yet bound in
   ctx, by appending primes (x, x', x'', ...) until there is no clash.
   Returns ctx extended with that name as a NameBind, paired with the name. *)
fun pickfreshname (ctx, x) =
    let
      fun fresh candidate =
          if isnamebound (ctx, candidate) then fresh (candidate ^ "'")
          else candidate
      val chosen = fresh x
    in
      ((chosen, NameBind) :: ctx, chosen)
    end

(* index2name (ctx, x): the name bound at de Bruijn index x in ctx, where
   index 0 is the head of the list.
   An invalid index is reported through `error` (from the opened Error
   structure) with a "Variable lookup failure" message.
   List.nth raises Subscript -- not Size -- when x < 0 or x >= length ctx,
   so Subscript is the exception that must be handled here. *)
fun index2name (ctx: context, x) =
    #1 (List.nth (ctx, x))
    handle Subscript =>
      error (concat ["Variable lookup failure: offset: ", Int.toString x,
                     ", ctx size: ", Int.toString (List.length ctx)])

(* name2index (ctx, x): the de Bruijn index of the innermost binding named x,
   i.e. the position of the first matching entry counting from 0.
   Reports "Identifier x is unbound" via `error` when no entry matches. *)
fun name2index (ctx, x) =
    let
      fun scan ([], _) = error ("Identifier " ^ x ^ " is unbound")
        | scan ((name, _) :: rest, pos) =
            if name = x then pos else scan (rest, pos + 1)
    in
      scan (ctx, 0)
    end

(* getbinding (ctx, i): the binding stored at de Bruijn index i of ctx, where
   index 0 is the head of the list.
   An invalid index is reported through `error` (from the opened Error
   structure) with a "Variable lookup failure" message.
   List.nth raises Subscript -- not Size -- when i < 0 or i >= length ctx,
   so Subscript is the exception that must be handled here. *)
fun getbinding (ctx: context, i) =
    #2 (List.nth (ctx, i))
    handle Subscript =>
      error (concat ["Variable lookup failure: offset: ", Int.toString i,
                     ", ctx size: ", Int.toString (List.length ctx)])

(* getTypeFromContext (ctx, i): the type of the variable at index i.
   Only VarBind entries carry a type; a NameBind entry at that index is
   reported via `error`, naming the offending variable. *)
fun getTypeFromContext (ctx, i) =
    case getbinding (ctx, i)
      of NameBind =>
           error ("getTypeFromContext: Wrong kind of binding for variable "
                  ^ index2name (ctx, i))
       | VarBind tyT => tyT


(* ---------------------------------------------------------------------- *)
(* Shifting *)

(* tmmap onvar c t: rebuild t, replacing each variable TmVar (x, n) with
   onvar (depth, x, n). depth starts at c and grows by one under each TmAbs.
   All other constructors are copied structurally, with subterms visited
   left to right. *)
fun tmmap onvar c t =
    let
      fun go depth tm =
          case tm
            of TmVar (x, n) => onvar (depth, x, n)
             | TmAbs (x, tyT1, body) => TmAbs (x, tyT1, go (depth + 1) body)
             | TmApp (t1, t2) => TmApp (go depth t1, go depth t2)
             | TmIf (t1, t2, t3) => TmIf (go depth t1, go depth t2, go depth t3)
             | TmTrue => TmTrue
             | TmFalse => TmFalse
    in
      go c t
    end

(* termShiftAbove (d, c, t): add d to every variable index that is >= the
   current cutoff (starting at c, one higher under each binder). Every
   variable's expected context length n also grows by d. *)
fun termShiftAbove (d, c, t) =
    let
      fun shiftVar (cutoff, x, n) =
          TmVar (if x >= cutoff then x + d else x, n + d)
    in
      tmmap shiftVar c t
    end

(* termShift (d, t): shift the free variables of t by d (cutoff 0). *)
fun termShift (amount, tm) = termShiftAbove (amount, 0, tm)

(* ---------------------------------------------------------------------- *)
(* Substitution *)

(* termSubst (j, s, t): the substitution [j |-> s] t on de Bruijn terms.
   tmmap starts the cutoff at j and adds one under each binder. So at cutoff
   c the variable being replaced has index c, and s must be shifted past the
   (c - j) binders crossed so far. Shifting by the full c would only be
   correct when j = 0 (the case used by termSubstTop). *)
fun termSubst (j, s, t) =
  tmmap
    (fn (c, x, n) => if x = c then termShift (c - j, s) else TmVar (x, n))
    j t

(* termSubstTop (s, t): substitute s for variable 0 of t and drop that binder.
   s is shifted up by one so that it lives under the binder being removed;
   the result is then shifted down by one to account for that binder. *)
fun termSubstTop (s, t) =
    let
      val lifted = termShift (1, s)
      val replaced = termSubst (0, lifted, t)
    in
      termShift (~1, replaced)
    end
(* ---------------------------------------------------------------------- *)
(* Printing *)

(* The printing functions call these utility functions to insert grouping
  information and line-breaking hints for the pretty-printing library:
     obox   Open a "box" whose contents will be indented by two spaces if
            the whole box cannot fit on the current line
     obox0  Same but indent continuation lines to the same column as the
            beginning of the box rather than 2 more columns to the right
     cbox   Close the current box
     break  Insert a breakpoint indicating where the line may be broken if
            necessary.
  See the documentation for the Format module in the OCaml library for
  more details. 
*)

open PPUtil

(* Thin wrappers over the PPUtil pretty-printing primitives. *)

(* Open a box whose continuation lines align with the start of the box. *)
fun obox0 () = open_hvbox 0

(* Open a box whose continuation lines are indented two more columns. *)
fun obox () = open_hvbox 2

(* Close the innermost open box. *)
fun cbox () = close_box ()

(* Hint a possible line break; 0 0 are the print_break arguments (see PPUtil
   for their exact meaning). *)
fun break () = print_break 0 0

(* Emit a string verbatim. *)
fun pr s = print_string s

(* small t: true exactly when t is a variable. The lambda printer uses this
   to decide whether a body may follow the "." without a space. *)
fun small (TmVar _) = true
  | small _ = false

(* Type printer, one function per precedence level:
   printty_Type      -- full type (delegates to the arrow level)
   printty_ArrowType -- right-nested arrows T1->T2->...
   printty_AType     -- atoms: Bool, or a parenthesized type
   `outer` selects the roomier layout (spaces around "->"). *)
fun printty_Type (outer, tyT) = printty_ArrowType (outer, tyT)

and printty_ArrowType (outer, TyArr (dom, cod)) =
      ( obox0 ()
      ; printty_AType (false, dom)
      ; if outer then pr " " else ()
      ; pr "->"
      ; if outer then print_space () else break ()
      ; printty_ArrowType (outer, cod)
      ; cbox () )
  | printty_ArrowType (outer, other) = printty_AType (outer, other)

and printty_AType (_, TyBool) = pr "Bool"
  | printty_AType (outer, other) =
      ( pr "("; printty_Type (outer, other); pr ")" )

(* Print a type at top level. *)
fun printty tyT = printty_Type (true, tyT)

(* Term printer, one function per precedence level:
   printtm_Term    -- if/then/else and lambda abstractions
   printtm_AppTerm -- left-nested applications
   printtm_ATerm   -- atoms: constants, variables, or a parenthesized term
   ctx supplies the names for de Bruijn indices. Lambdas extend it with a
   fresh name chosen by pickfreshname. *)
fun printtm_Term (_, ctx, TmIf (t1, t2, t3)) =
      ( obox0 ()
      ; pr "if "
      ; printtm_Term (false, ctx, t1)
      ; print_space ()
      ; pr "then "
      ; printtm_Term (false, ctx, t2)
      ; print_space ()
      ; pr "else "
      ; printtm_Term (false, ctx, t3)
      ; cbox () )
  | printtm_Term (outer, ctx, TmAbs (x, tyT1, t2)) =
      let
        val (innerCtx, shownName) = pickfreshname (ctx, x)
        (* a variable body of a nested lambda may sit right after the "." *)
        val tightBody = small t2 andalso not outer
      in
        obox ();
        pr "lambda ";
        pr shownName;
        pr ":";
        printty_Type (false, tyT1);
        pr ".";
        if tightBody then break () else print_space ();
        printtm_Term (outer, innerCtx, t2);
        cbox ()
      end
  | printtm_Term (outer, ctx, t) = printtm_AppTerm (outer, ctx, t)

and printtm_AppTerm (_, ctx, TmApp (rator, rand)) =
      ( obox0 ()
      ; printtm_AppTerm (false, ctx, rator)
      ; print_space ()
      ; printtm_ATerm (false, ctx, rand)
      ; cbox () )
  | printtm_AppTerm (outer, ctx, t) = printtm_ATerm (outer, ctx, t)

and printtm_ATerm (_, _, TmTrue) = pr "true"
  | printtm_ATerm (_, _, TmFalse) = pr "false"
  | printtm_ATerm (_, ctx, TmVar (x, n)) =
      (* a context-length mismatch means the index cannot be trusted, so dump
         the raw index and the names in scope instead of looking it up *)
      if ctxlength ctx <> n then
        let
          val names =
              List.foldl (fn ((name, _), acc) => acc ^ " " ^ name) "" ctx
        in
          pr ("[bad index: " ^ Int.toString x ^ "/" ^ Int.toString n
              ^ " in {" ^ names ^ " }]")
        end
      else pr (index2name (ctx, x))
  | printtm_ATerm (outer, ctx, t) =
      ( pr "("; printtm_Term (outer, ctx, t); pr ")" )

(* Print a term at top level in context ctx. *)
fun printtm (ctx, t) = printtm_Term (true, ctx, t)

(* prbinding (ctx, b): print the annotation for a binding. A VarBind prints
   ": T"; a NameBind prints nothing. ctx is not consulted. *)
fun prbinding (_, NameBind) = ()
  | prbinding (_, VarBind tyT) = (pr ": "; printty tyT)


end (* local *)
end (* structure Syntax *)
