(*  Title:      Tools/Compute_Oracle/am_sml.ML
    Author:     Steven Obua

    TODO: "parameterless rewrite cannot be used in pattern": In a lot of cases
    it CAN be used, and these cases should be handled properly; right now, all
    cases throw an exception.
*)

(* Abstract machine that compiles rewrite rules into generated Standard ML
   source, loads that source, and evaluates terms with the loaded rewriter. *)
signature AM_SML =
sig
  include ABSTRACT_MACHINE
  val save_result : (string * term) -> unit
  val set_compiled_rewriter : (term -> term) -> unit
  val list_nth : 'a list * int -> 'a
  val dump_output : (string option) ref
end

structure AM_SML : AM_SML = struct

open AbstractMachine;

(* When SOME path, compile also writes the generated source to that file. *)
val dump_output = ref (NONE: string option)

(* (module name, link code, arity table, toplevel arity table,
    inlined parameterless rewrites, compiled rewriter) *)
type program = string * string * (int Inttab.table) * (int Inttab.table) * (term Inttab.table) * (term -> term)

(* Mailbox through which generated code hands back a result, tagged with the
   link code of the program that produced it. *)
val saved_result = ref (NONE:(string*term)option)

fun save_result r = (saved_result := SOME r)
fun clear_result () = (saved_result := NONE)

(* Exported so that generated code can look up bound variables. *)
val list_nth = List.nth

(*fun list_nth (l,n) = (writeln (makestring ("list_nth", (length l,n))); List.nth (l,n))*)

(* Filled in by the generated code while it is being loaded; read by compile. *)
val compiled_rewriter = ref (NONE:(term -> term)Option.option)

fun set_compiled_rewriter r = (compiled_rewriter := SOME r)

(* Number of variable positions in a pattern; every PVar counts once. *)
fun count_patternvars PVar = 1
  | count_patternvars (PConst (_, ps)) =
      List.foldl (fn (p, count) => (count_patternvars p)+count) 0 ps

(* Record arity a for constant code, keeping the maximum seen so far. *)
fun update_arity arity code a =
    (case Inttab.lookup arity code of
         NONE => Inttab.update_new (code, a) arity
       | SOME (a': int) => if a > a' then Inttab.update (code, a) arity else arity)

(* We have to find out the maximal arity of each constant *)
fun collect_pattern_arity PVar arity = arity
  | collect_pattern_arity (PConst (c, args)) arity =
      fold collect_pattern_arity args (update_arity arity c (length args))

(* We also need to find out the maximal toplevel arity of each function constant *)
fun collect_pattern_toplevel_arity PVar arity = raise Compile "internal error: collect_pattern_toplevel_arity"
  | collect_pattern_toplevel_arity (PConst (c, args)) arity = update_arity arity c (length args)

(* Maximal number of arguments each constant is applied to inside a term.
   applevel counts the arguments met so far along the current application spine. *)
local
fun collect applevel (Var _) arity = arity
  | collect applevel (Const c) arity = update_arity arity c applevel
  | collect applevel (Abs m) arity = collect 0 m arity
  | collect applevel (App (a,b)) arity = collect 0 b (collect (applevel + 1) a arity)
in
fun collect_term_arity t arity = collect 0 t arity
end

fun collect_guard_arity (Guard (a,b)) arity = collect_term_arity b (collect_term_arity a arity)

(* rep n x is the list of n copies of x; raises Compile for negative n. *)
fun rep n x = if n < 0 then raise Compile "internal error: rep" else if n = 0 then [] else x::(rep (n-1) x)

(* Beta normalisation of de Bruijn terms.  Computed subterms are opaque values:
   the generated convert_computed rejects bound variables inside them, so they
   are never lifted, unlifted, or substituted into (consistent with beta, which
   returns them unchanged). *)
fun beta (Const c) = Const c
  | beta (Var i) = Var i
  | beta (App (Abs m, b)) = beta (unlift 0 (subst 0 m (lift 0 b)))
  | beta (App (a, b)) =
    (case beta a of
         Abs m => beta (App (Abs m, b))
       | a => App (a, beta b))
  | beta (Abs m) = Abs (beta m)
  | beta (Computed t) = Computed t
(* subst x m t: replace de Bruijn index x in m by t *)
and subst x (Const c) t = Const c
  | subst x (Var i) t = if i = x then t else Var i
  | subst x (App (a,b)) t = App (subst x a t, subst x b t)
  | subst x (Abs m) t = Abs (subst (x+1) m (lift 0 t))
  | subst x (Computed c) t = Computed c
(* lift level m: increment every index >= level by one *)
and lift level (Const c) = Const c
  | lift level (App (a,b)) = App (lift level a, lift level b)
  | lift level (Var i) = if i < level then Var i else Var (i+1)
  | lift level (Abs m) = Abs (lift (level + 1) m)
  | lift level (Computed c) = Computed c
(* unlift level m: decrement every index >= level by one *)
and unlift level (Const c) = Const c
  | unlift level (App (a, b)) = App (unlift level a, unlift level b)
  | unlift level (Abs m) = Abs (unlift (level+1) m)
  | unlift level (Var i) = if i < level then Var i else Var (i-1)
  | unlift level (Computed c) = Computed c

(* nlift level n m: add n to every index >= level *)
fun nlift level n (Var m) = if m < level then Var m else Var (m+n)
  | nlift level n (Const c) = Const c
  | nlift level n (App (a,b)) = App (nlift level n a, nlift level n b)
  | nlift level n (Abs b) = Abs (nlift (level+1) n b)
  | nlift level n (Computed c) = Computed c

(* subst_const (c, t) m: replace every occurrence of constant c in m by t *)
fun subst_const (c, t) (Const c') = if c = c' then t else Const c'
  | subst_const _ (Var i) = Var i
  | subst_const ct (App (a, b)) = App (subst_const ct a, subst_const ct b)
  | subst_const ct (Abs m) = Abs (subst_const ct m)
  | subst_const _ (Computed t) = Computed t

(* Remove all rules that are just parameterless rewrites. This is necessary
   because SML does not allow functions with no parameters.
*) fun inline_rules rules = let fun term_contains_const c (App (a, b)) = term_contains_const c a orelse term_contains_const c b | term_contains_const c (Abs m) = term_contains_const c m | term_contains_const c (Var i) = false | term_contains_const c (Const c') = (c = c') fun find_rewrite [] = NONE | find_rewrite ((prems, PConst (c, []), r) :: _) = if check_freevars 0 r then if term_contains_const c r then raise Compile "parameterless rewrite is caught in cycle" else if not (null prems) then raise Compile "parameterless rewrite may not be guarded" else SOME (c, r) else raise Compile "unbound variable on right hand side or guards of rule" | find_rewrite (_ :: rules) = find_rewrite rules fun remove_rewrite (c,r) [] = [] | remove_rewrite (cr as (c,r)) ((rule as (prems', PConst (c', args), r'))::rules) = (if c = c' then if null args andalso r = r' andalso null (prems') then remove_rewrite cr rules else raise Compile "incompatible parameterless rewrites found" else rule :: (remove_rewrite cr rules)) | remove_rewrite cr (r::rs) = r::(remove_rewrite cr rs) fun pattern_contains_const c (PConst (c', args)) = (c = c' orelse exists (pattern_contains_const c) args) | pattern_contains_const c (PVar) = false fun inline_rewrite (ct as (c, _)) (prems, p, r) = if pattern_contains_const c p then raise Compile "parameterless rewrite cannot be used in pattern" else (map (fn (Guard (a,b)) => Guard (subst_const ct a, subst_const ct b)) prems, p, subst_const ct r) fun inline inlined rules = (case find_rewrite rules of NONE => (Inttab.make inlined, rules) | SOME ct => let val rules = map (inline_rewrite ct) (remove_rewrite ct rules) val inlined = ct :: (map (fn (c', r) => (c', subst_const ct r)) inlined) in inline inlined rules end) in inline [] rules end (* Calculate the arity, the toplevel_arity, and adjust rules so that all toplevel pattern constants have maximal arity. Also beta reduce the adjusted right hand side of a rule. 
*) fun adjust_rules rules = let val arity = fold (fn (prems, p, t) => fn arity => fold collect_guard_arity prems (collect_term_arity t (collect_pattern_arity p arity))) rules Inttab.empty val toplevel_arity = fold (fn (_, p, t) => fn arity => collect_pattern_toplevel_arity p arity) rules Inttab.empty fun arity_of c = the (Inttab.lookup arity c) fun toplevel_arity_of c = the (Inttab.lookup toplevel_arity c) fun test_pattern PVar = () | test_pattern (C as PConst (c, args)) = if (length args <> arity_of c) then raise Compile ("Constant inside pattern must have maximal arity") else (map test_pattern args; ()) fun adjust_rule (_, PVar, _) = raise Compile ("pattern may not be a variable") | adjust_rule (_, PConst (_, []), _) = raise Compile ("cannot deal with rewrites that take no parameters") | adjust_rule (rule as (prems, p as PConst (c, args),t)) = let val patternvars_counted = count_patternvars p fun check_fv t = check_freevars patternvars_counted t val _ = if not (check_fv t) then raise Compile ("unbound variables on right hand side of rule") else () val _ = if not (forall (fn (Guard (a,b)) => check_fv a andalso check_fv b) prems) then raise Compile ("unbound variables in guards") else () val _ = map test_pattern args val len = length args val arity = arity_of c val lift = nlift 0 fun addapps_tm n t = if n=0 then t else addapps_tm (n-1) (App (t, Var (n-1))) fun adjust_term n t = addapps_tm n (lift n t) fun adjust_guard n (Guard (a,b)) = Guard (lift n a, lift n b) in if len = arity then rule else if arity >= len then (map (adjust_guard (arity-len)) prems, PConst (c, args @ (rep (arity-len) PVar)), adjust_term (arity-len) t) else (raise Compile "internal error in adjust_rule") end fun beta_rule (prems, p, t) = ((prems, p, beta t) handle Match => raise Compile "beta_rule") in (arity, toplevel_arity, map (beta_rule o adjust_rule) rules) end fun print_term module arity_of toplevel_arity_of pattern_var_count pattern_lazy_var_count = let fun str x = string_of_int x fun 
protect_blank s = if exists_string Symbol.is_ascii_blank s then "(" ^ s ^")" else s val module_prefix = (case module of NONE => "" | SOME s => s^".") fun print_apps d f [] = f | print_apps d f (a::args) = print_apps d (module_prefix^"app "^(protect_blank f)^" "^(protect_blank (print_term d a))) args and print_call d (App (a, b)) args = print_call d a (b::args) | print_call d (Const c) args = (case arity_of c of NONE => print_apps d (module_prefix^"Const "^(str c)) args | SOME 0 => module_prefix^"C"^(str c) | SOME a => let val len = length args in if a <= len then let val strict_a = (case toplevel_arity_of c of SOME sa => sa | NONE => a) val _ = if strict_a > a then raise Compile "strict" else () val s = module_prefix^"c"^(str c)^(concat (map (fn t => " "^(protect_blank (print_term d t))) (List.take (args, strict_a)))) val s = s^(concat (map (fn t => " (fn () => "^print_term d t^")") (List.drop (List.take (args, a), strict_a)))) in print_apps d s (List.drop (args, a)) end else let fun mk_apps n t = if n = 0 then t else mk_apps (n-1) (App (t, Var (n - 1))) fun mk_lambdas n t = if n = 0 then t else mk_lambdas (n-1) (Abs t) fun append_args [] t = t | append_args (c::cs) t = append_args cs (App (t, c)) in print_term d (mk_lambdas (a-len) (mk_apps (a-len) (nlift 0 (a-len) (append_args args (Const c))))) end end) | print_call d t args = print_apps d (print_term d t) args and print_term d (Var x) = if x < d then "b"^(str (d-x-1)) else let val n = pattern_var_count - (x-d) - 1 val x = "x"^(str n) in if n < pattern_var_count - pattern_lazy_var_count then x else "("^x^" ())" end | print_term d (Abs c) = module_prefix^"Abs (fn b"^(str d)^" => "^(print_term (d + 1) c)^")" | print_term d t = print_call d t [] in print_term 0 end fun section n = if n = 0 then [] else (section (n-1))@[n-1] fun print_rule gnum arity_of toplevel_arity_of (guards, p, t) = let fun str x = Int.toString x fun print_pattern top n PVar = (n+1, "x"^(str n)) | print_pattern top n (PConst (c, [])) = (n, (if 
top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "")) | print_pattern top n (PConst (c, args)) = let val f = (if top then "c" else "C")^(str c)^(if top andalso gnum > 0 then "_"^(str gnum) else "") val (n, s) = print_pattern_list 0 top (n, f) args in (n, s) end and print_pattern_list' counter top (n,p) [] = if top then (n,p) else (n,p^")") | print_pattern_list' counter top (n, p) (t::ts) = let val (n, t) = print_pattern false n t in print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^", "^t) ts end and print_pattern_list counter top (n, p) (t::ts) = let val (n, t) = print_pattern false n t in print_pattern_list' (counter + 1) top (n, if top then p^" (a"^(str counter)^" as ("^t^"))" else p^" ("^t) ts end val c = (case p of PConst (c, _) => c | _ => raise Match) val (n, pattern) = print_pattern true 0 p val lazy_vars = the (arity_of c) - the (toplevel_arity_of c) fun print_tm tm = print_term NONE arity_of toplevel_arity_of n lazy_vars tm fun print_guard (Guard (a,b)) = "term_eq ("^(print_tm a)^") ("^(print_tm b)^")" val else_branch = "c"^(str c)^"_"^(str (gnum+1))^(concat (map (fn i => " a"^(str i)) (section (the (arity_of c))))) fun print_guards t [] = print_tm t | print_guards t (g::gs) = "if ("^(print_guard g)^")"^(concat (map (fn g => " andalso ("^(print_guard g)^")") gs))^" then ("^(print_tm t)^") else "^else_branch in (if null guards then gnum else gnum+1, pattern^" = "^(print_guards t guards)) end fun group_rules rules = let fun add_rule (r as (_, PConst (c,_), _)) groups = let val rs = (case Inttab.lookup groups c of NONE => [] | SOME rs => rs) in Inttab.update (c, r::rs) groups end | add_rule _ _ = raise Compile "internal error group_rules" in fold_rev add_rule rules Inttab.empty end fun sml_prog name code rules = let val buffer = ref "" fun write s = (buffer := (!buffer)^s) fun writeln s = (write s; write "\n") fun writelist [] = () | writelist (s::ss) = (writeln s; writelist ss) 
fun str i = Int.toString i val (inlinetab, rules) = inline_rules rules val (arity, toplevel_arity, rules) = adjust_rules rules val rules = group_rules rules val constants = Inttab.keys arity fun arity_of c = Inttab.lookup arity c fun toplevel_arity_of c = Inttab.lookup toplevel_arity c fun rep_str s n = concat (rep n s) fun indexed s n = s^(str n) fun string_of_tuple [] = "" | string_of_tuple (x::xs) = "("^x^(concat (map (fn s => ", "^s) xs))^")" fun string_of_args [] = "" | string_of_args (x::xs) = x^(concat (map (fn s => " "^s) xs)) fun default_case gnum c = let val leftargs = concat (map (indexed " x") (section (the (arity_of c)))) val rightargs = section (the (arity_of c)) val strict_args = (case toplevel_arity_of c of NONE => the (arity_of c) | SOME sa => sa) val xs = map (fn n => if n < strict_args then "x"^(str n) else "x"^(str n)^"()") rightargs val right = (indexed "C" c)^" "^(string_of_tuple xs) val message = "(\"unresolved lazy call: "^(string_of_int c)^", \"^(makestring x"^(string_of_int (strict_args - 1))^"))" val right = if strict_args < the (arity_of c) then "raise AM_SML.Run "^message else right in (indexed "c" c)^(if gnum > 0 then "_"^(str gnum) else "")^leftargs^" = "^right end fun eval_rules c = let val arity = the (arity_of c) val strict_arity = (case toplevel_arity_of c of NONE => arity | SOME sa => sa) fun eval_rule n = let val sc = string_of_int c val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section n) ("AbstractMachine.Const "^sc) fun arg i = let val x = indexed "x" i val x = if i < n then "(eval bounds "^x^")" else x val x = if i < strict_arity then x else "(fn () => "^x^")" in x end val right = "c"^sc^" "^(string_of_args (map arg (section arity))) val right = fold_rev (fn i => fn s => "Abs (fn "^(indexed "x" i)^" => "^s^")") (List.drop (section arity, n)) right val right = if arity > 0 then right else "C"^sc in " | eval bounds ("^left^") = "^right end in map eval_rule (rev (section (arity + 1))) end 
fun convert_computed_rules (c: int) : string list = let val arity = the (arity_of c) fun eval_rule () = let val sc = string_of_int c val left = fold (fn i => fn s => "AbstractMachine.App ("^s^(indexed ", x" i)^")") (section arity) ("AbstractMachine.Const "^sc) fun arg i = "(convert_computed "^(indexed "x" i)^")" val right = "C"^sc^" "^(string_of_tuple (map arg (section arity))) val right = if arity > 0 then right else "C"^sc in " | convert_computed ("^left^") = "^right end in [eval_rule ()] end fun mk_constr_type_args n = if n > 0 then " of Term "^(rep_str " * Term" (n-1)) else "" val _ = writelist [ "structure "^name^" = struct", "", "datatype Term = Const of int | App of Term * Term | Abs of (Term -> Term)", " "^(concat (map (fn c => " | C"^(str c)^(mk_constr_type_args (the (arity_of c)))) constants)), ""] fun make_constr c argprefix = "(C"^(str c)^" "^(string_of_tuple (map (fn i => argprefix^(str i)) (section (the (arity_of c)))))^")" fun make_term_eq c = " | term_eq "^(make_constr c "a")^" "^(make_constr c "b")^" = "^ (case the (arity_of c) of 0 => "true" | n => let val eqs = map (fn i => "term_eq a"^(str i)^" b"^(str i)) (section n) val (eq, eqs) = (List.hd eqs, map (fn s => " andalso "^s) (List.tl eqs)) in eq^(concat eqs) end) val _ = writelist [ "fun term_eq (Const c1) (Const c2) = (c1 = c2)", " | term_eq (App (a1,a2)) (App (b1,b2)) = term_eq a1 b1 andalso term_eq a2 b2"] val _ = writelist (map make_term_eq constants) val _ = writelist [ " | term_eq _ _ = false", "" ] val _ = writelist [ "fun app (Abs a) b = a b", " | app a b = App (a, b)", ""] fun defcase gnum c = (case arity_of c of NONE => [] | SOME a => if a > 0 then [default_case gnum c] else []) fun writefundecl [] = () | writefundecl (x::xs) = writelist ((("and "^x)::(map (fn s => " | "^s) xs))) fun list_group c = (case Inttab.lookup rules c of NONE => [defcase 0 c] | SOME rs => let val rs = fold (fn r => fn rs => let val (gnum, l, rs) = (case rs of [] => (0, [], []) | (gnum, l)::rs => (gnum, l, rs)) 
val (gnum', r) = print_rule gnum arity_of toplevel_arity_of r in if gnum' = gnum then (gnum, r::l)::rs else let val args = concat (map (fn i => " a"^(str i)) (section (the (arity_of c)))) fun gnumc g = if g > 0 then "c"^(str c)^"_"^(str g)^args else "c"^(str c)^args val s = gnumc (gnum) ^ " = " ^ gnumc (gnum') in (gnum', [])::(gnum, s::r::l)::rs end end) rs [] val rs = (case rs of [] => [(0,defcase 0 c)] | (gnum,l)::rs => (gnum, (defcase gnum c)@l)::rs) in rev (map (fn z => rev (snd z)) rs) end) val _ = map (fn z => (map writefundecl z; writeln "")) (map list_group constants) val _ = writelist [ "fun convert (Const i) = AM_SML.Const i", " | convert (App (a, b)) = AM_SML.App (convert a, convert b)", " | convert (Abs _) = raise AM_SML.Run \"no abstraction in result allowed\""] fun make_convert c = let val args = map (indexed "a") (section (the (arity_of c))) val leftargs = case args of [] => "" | (x::xs) => "("^x^(concat (map (fn s => ", "^s) xs))^")" val args = map (indexed "convert a") (section (the (arity_of c))) val right = fold (fn x => fn s => "AM_SML.App ("^s^", "^x^")") args ("AM_SML.Const "^(str c)) in " | convert (C"^(str c)^" "^leftargs^") = "^right end val _ = writelist (map make_convert constants) val _ = writelist [ "", "fun convert_computed (AbstractMachine.Abs b) = raise AM_SML.Run \"no abstraction in convert_computed allowed\"", " | convert_computed (AbstractMachine.Var i) = raise AM_SML.Run \"no bound variables in convert_computed allowed\""] val _ = map (writelist o convert_computed_rules) constants val _ = writelist [ " | convert_computed (AbstractMachine.Const c) = Const c", " | convert_computed (AbstractMachine.App (a, b)) = App (convert_computed a, convert_computed b)", " | convert_computed (AbstractMachine.Computed a) = raise AM_SML.Run \"no nesting in convert_computed allowed\""] val _ = writelist [ "", "fun eval bounds (AbstractMachine.Abs m) = Abs (fn b => eval (b::bounds) m)", " | eval bounds (AbstractMachine.Var i) = AM_SML.list_nth 
(bounds, i)"] val _ = map (writelist o eval_rules) constants val _ = writelist [ " | eval bounds (AbstractMachine.App (a, b)) = app (eval bounds a) (eval bounds b)", " | eval bounds (AbstractMachine.Const c) = Const c", " | eval bounds (AbstractMachine.Computed t) = convert_computed t"] val _ = writelist [ "", "fun export term = AM_SML.save_result (\""^code^"\", convert term)", "", "val _ = AM_SML.set_compiled_rewriter (fn t => (convert (eval [] t)))", "", "end"] in (arity, toplevel_arity, inlinetab, !buffer) end val guid_counter = ref 0 fun get_guid () = let val c = !guid_counter val _ = guid_counter := !guid_counter + 1 in (LargeInt.toString (Time.toMicroseconds (Time.now ()))) ^ (string_of_int c) end fun writeTextFile name s = File.write (Path.explode name) s fun use_source src = use_text ML_Context.local_context (1, "") false src fun compile cache_patterns const_arity eqs = let val guid = get_guid () val code = Real.toString (random ()) val module = "AMSML_"^guid val (arity, toplevel_arity, inlinetab, source) = sml_prog module code eqs val _ = case !dump_output of NONE => () | SOME p => writeTextFile p source val _ = compiled_rewriter := NONE val _ = use_source source in case !compiled_rewriter of NONE => raise Compile "broken link to compiled function" | SOME f => (module, code, arity, toplevel_arity, inlinetab, f) end fun run' (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = let val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) | inline (Var i) = Var i | inline (App (a, b)) = App (inline a, inline b) | inline (Abs m) = Abs (inline m) val t = beta (inline t) fun arity_of c = Inttab.lookup arity c fun toplevel_arity_of c = Inttab.lookup toplevel_arity c val term = print_term NONE arity_of toplevel_arity_of 0 0 t val source = "local open "^module^" in val _ = export ("^term^") end" val _ = writeTextFile "Gencode_call.ML" 
source val _ = clear_result () val _ = use_source source in case !saved_result of NONE => raise Run "broken link to compiled code" | SOME (code', t) => (clear_result (); if code' = code then t else raise Run "link to compiled code was hijacked") end fun run (module, code, arity, toplevel_arity, inlinetab, compiled_fun) t = let val _ = if check_freevars 0 t then () else raise Run ("can only compute closed terms") fun inline (Const c) = (case Inttab.lookup inlinetab c of NONE => Const c | SOME t => t) | inline (Var i) = Var i | inline (App (a, b)) = App (inline a, inline b) | inline (Abs m) = Abs (inline m) | inline (Computed t) = Computed t in compiled_fun (beta (inline t)) end fun discard p = () end