(* MySQL database interface for mod_caml programs.
 * Copyright (C) 2003 Merjis Ltd.
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: dbi_mysql.ml,v 1.6 2004/08/03 16:49:41 rwmj Exp $
 *)

open Printf

(* Break [query] at every [?] placeholder.
 *
 * The result is the list of pieces with each placeholder kept as its own
 * "?" element, paired with the number of placeholders found. For example
 * "select name from foo where id = ? and bar = ?" becomes
 * [ "select name from foo where id = "; "?"; " and bar = "; "?"; "" ], 2.
 * Text before the first and after the last [?] is always present, possibly
 * as "".
 *
 * XXX This is deliberately naive: a literal question-mark inside a quoted
 * SQL string is also treated as a placeholder.
 *)
let split_query query =
  let len = String.length query in
  (* [acc] holds the pieces seen so far in reverse order; [holes] counts the
   * placeholders consumed. *)
  let rec scan start acc holes =
    let next =
      try Some (String.index_from query start '?') with Not_found -> None
    in
    match next with
    | None ->
        let tail = String.sub query start (len - start) in
        List.rev (tail :: acc), holes
    | Some q ->
        let chunk = String.sub query start (q - start) in
        scan (q + 1) ("?" :: chunk :: acc) (holes + 1)
  in
  scan 0 [] 0


(* Convert an array to a list with the elements in the same order.
 * Kept as a named helper for existing callers; the work is done by the
 * standard library's [Array.to_list]. *)
let list_of_array a = Array.to_list a


(* A statement prepared against a MySQL connection.
 *
 * [dbh] is the owning Dbi connection, [conn] the underlying Mysql handle,
 * [in_transaction] a flag shared with the connection that records whether
 * a BEGIN has already been sent, and [original_query] the SQL text, whose
 * [?] characters are treated as argument placeholders.
 *)
class statement dbh conn in_transaction original_query =

  (* Split up the query, and calculate the number of placeholders. *)
  let query, nr_args = split_query original_query in

object (self)
  inherit Dbi.statement dbh

  (* Result of the last query that returned a result set, if any. *)
  val mutable tuples = None

  (* Column names of [tuples], cached on first use by [names]. *)
  val mutable name_list = None

  (* Substitute [args] for the placeholders, send the query, and keep the
   * result set (if any) for [fetch1] / [names].
   *
   * Raises [Failure] if the handle is closed or an argument has an
   * unsupported type, [Invalid_argument] if the number of arguments does
   * not match the number of placeholders, and [Dbi.SQL_error] if the
   * server reports an error.
   *)
  method execute args =
    if dbh#closed then
      failwith "Dbi_mysql: execute called on a closed database handle.";

    if List.length args <> nr_args then
      invalid_arg "Dbi_mysql: execute called with wrong number of args.";

    (* Finish previous statement, if any. *)
    self#finish ();

    (* Render a single argument as SQL text. *)
    let sql_of_arg = function
      | `Null -> "null"
      | `Int i -> string_of_int i
      | `Float f -> string_of_float f
      | `String s ->
          (* NOTE(review): the result is spliced into the query verbatim, so
           * this assumes [Mysql.escape] yields a complete quoted literal --
           * confirm against the ocaml-mysql version in use. *)
          Mysql.escape s
      | `Bool b -> if b then "'t'" else "'f'"
      | _ -> failwith "Dbi_mysql: unknown argument type in execute"
    in

    (* Replace each "?" piece with the next argument. The count check above
     * guarantees exactly one argument per placeholder. Non-placeholder
     * pieces never equal "?" because split_query removes every '?'. *)
    let rec substitute pieces remaining =
      match pieces, remaining with
      | [], _ -> []
      | "?" :: pieces, arg :: remaining ->
          let sql = sql_of_arg arg in
          sql :: substitute pieces remaining
      | "?" :: _, [] -> assert false
      | piece :: pieces, remaining -> piece :: substitute pieces remaining
    in

    (* Build the whole query before talking to the server, so that an
     * argument of unsupported type fails without first opening a
     * transaction. *)
    let query = String.concat "" (substitute query args) in

    (* In transaction? If not we need to issue a BEGIN command. The flag is
     * set first so that executing BEGIN does not recurse back here. *)
    if not !in_transaction then begin
      in_transaction := true;
      let sth = dbh#prepare_cached "begin" in
      sth#execute []
    end;

    (* Send the query to the database and check the status. *)
    let res = Mysql.exec conn query in
    match Mysql.status conn with
    | Mysql.StatusEmpty -> ()
    | Mysql.StatusOK ->
        tuples <- Some res;
        name_list <- None
    | Mysql.StatusError _ ->
        let msg =
          match Mysql.errmsg conn with
          | None -> "unknown"
          | Some err -> err
        in
        (* dbh#close (); -- used to do this, not a good idea *)
        raise (Dbi.SQL_error msg)

  (* Fetch the next row of the current result set as a list of values.
   * NULL columns become [`Null]; columns typed [Mysql.IntTy] /
   * [Mysql.FloatTy] are converted with [int_of_string] /
   * [float_of_string]; everything else is returned as [`String].
   * Raises [Failure] if there is no result set and [Not_found] when the
   * rows are exhausted. *)
  method fetch1 () =
    match tuples with
    | None -> failwith "Dbi_mysql.statement#fetch1"
    | Some tuples ->
        match Mysql.fetch tuples with
        | None -> raise Not_found
        | Some row ->
            let types = Mysql.types tuples in
            let convert i = function
              | None -> `Null
              | Some v ->
                  match types.(i) with
                  | Mysql.IntTy -> `Int (int_of_string v)
                  | Mysql.FloatTy -> `Float (float_of_string v)
                  | _ -> `String v
            in
            Array.to_list (Array.mapi convert row)

  (* Column names of the current result set, computed once per result set.
   * Raises [Failure] if there is no result set. *)
  method names =
    match tuples with
    | None -> failwith "Dbi_mysql.statement#names"
    | Some tuples ->
        begin match name_list with
        | Some l -> l
        | None ->
            let l = list_of_array (Mysql.names tuples) in
            name_list <- Some l;
            l
        end

  method serial _seq =
    (* Is it possible to get the serial column from a previous INSERT statement
     * with MySQL? XXX *)
    failwith "XXX not implemented yet"

  (* Drop the current result set and its cached column names.
   * Can we free up the resources used by a query? XXX *)
  method finish () =
    tuples <- None;
    name_list <- None

end

and connection ?host ?port ?user ?password database =

  (* Open the MySQL connection. [port] is a string here and is converted
   * with [int_of_string] before being passed to [Mysql.quick_connect];
   * the superclass below still receives the original string.
   *)
  let conn =
    let port =
      match port with
      | None -> None
      | Some str -> Some (int_of_string str) in
    Mysql.quick_connect ?host ?port ?user ?password ~database () in

  (* We pass this reference around to the statement class so that all
   * statements belonging to this connection can keep track of our
   * transaction state and issue the "begin" command at the right time.
   *)
  let in_transaction = ref false in

object (self)
  inherit Dbi.connection ?host ?port ?user ?password database as super

  method database_type = "mysql"

  (* Create a statement for [query] on this connection.
   * Raises [Failure] if the connection has been closed. *)
  method prepare query =
    if self#closed then
      failwith "Dbi_mysql: prepare called on closed database handle.";
    new statement
      (self : #Dbi.connection :> Dbi.connection)
      conn in_transaction query

  (* Send COMMIT and clear the shared transaction flag, so that the next
   * executed statement starts a new transaction. *)
  method commit () =
    super#commit ();
    let sth = self#prepare_cached "commit" in
    sth#execute [];
    in_transaction := false

  (* Send ROLLBACK and clear the shared transaction flag, so that the next
   * executed statement starts a new transaction. *)
  method rollback () =
    let sth = self#prepare_cached "rollback" in
    sth#execute [];
    in_transaction := false;
    super#rollback ()

  method close () =
    Mysql.disconnect conn;
    super#close ()

  (* Return [true] if the server answers a ping, [false] otherwise.
   * Per the ocaml-mysql interface, [Mysql.ping] reports a dead connection
   * by raising [Mysql.Error]. Without this handler the method could never
   * return [false]. *)
  method ping () =
    try Mysql.ping conn; true
    with Mysql.Error _ -> false

  (* If the handle reports an error status right after connecting, raise
   * it as [Dbi.SQL_error]. *)
  initializer
    match Mysql.status conn with
    | Mysql.StatusError _ ->
        let msg =
          match Mysql.errmsg conn with
          | None -> "unknown"
          | Some err -> err
        in
        raise (Dbi.SQL_error msg)
    | _ -> ()
end

(* Open a new MySQL connection. The arguments are passed unchanged to the
 * [connection] class. *)
let connect ?host ?port ?user ?password database : connection =
  new connection ?host ?port ?user ?password database

(* Function-style shortcuts for the corresponding [connection] methods. *)
let close = fun (dbh : connection) -> dbh#close ()
let closed = fun (dbh : connection) -> dbh#closed
let commit = fun (dbh : connection) -> dbh#commit ()
let ping = fun (dbh : connection) -> dbh#ping ()
let rollback = fun (dbh : connection) -> dbh#rollback ()
