/* $Id: sfsconnect.C,v 1.13 2001/08/19 00:27:15 dm Exp $ */

/*
 *
 * Copyright (C) 1999 David Mazieres (dm@uun.org)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include "sfskey.h"
#include "srp.h"

// State carried across the asynchronous steps of one SRP connection
// attempt.  An srpcon is heap allocated and destroys itself from
// fail () or succeed (), each of which invokes the callback first.
struct srpcon {
  const sfs_connect_cb cb;	// Completion callback
  str user;			// User part of [user]@host
  str host;			// Host part of [user]@host
  str pwd;			// Passphrase most recently read from the user
  ptr<sfscon> sc;		// Connection handed to getcon ()
  ptr<aclnt> c;			// RPC client allocated on sc->x
  sfsauth_srpres sres;		// Result buffer for the SRP RPCs
  srp_client *srpp;		// SRP client state supplied by the caller
  str *pwdp;			// If non-NULL, receives pwd on success
  str *userp;			// If non-NULL, receives a user@host string
  int ntries;			// Incremented on each rejected passphrase

  srpcon (const sfs_connect_cb &callback)
    : cb (callback), srpp (NULL), pwdp (NULL), userp (NULL), ntries (0)
  {}

  // Report failure: call cb with a NULL connection and msg, then
  // destroy this object.
  void fail (str msg)
  {
    (*cb) (NULL, msg);
    delete this;
  }

  // Report success: drop the RPC client, call cb with the connection
  // and a NULL error, then destroy this object.
  void succeed ()
  {
    c = NULL;
    (*cb) (sc, NULL);
    delete this;
  }

  void start (str &u);
  void getcon (ptr<sfscon> sc, str err);
  void initsrp ();
  void srpcb (clnt_stat err);
};

// Begin a connection attempt for u, which must have the form
// [user]@host.  When the user part is omitted, the result of
// myusername () is used instead.  On success this kicks off
// sfs_connect_host (), whose result is delivered to getcon ().  On any
// error fail () is called, which invokes the callback and deletes
// this object -- so nothing may touch members after fail () returns.
void
srpcon::start (str &u)
{
  // Group 1 is the optional user name, group 2 everything after '@'.
  static rxx usrhost ("^([^@]+)?@(.*)$");
  if (!usrhost.search (u)) {
    // userp is an optional out-parameter (the constructor leaves it
    // NULL), so it must be checked before being written through.
    if (userp)
      *userp = u;
    fail ("not of form [user]@host");
    return;
  }

  user = usrhost[1];
  host = usrhost[2];
  if (!user && !(user = myusername ())) {
    fail ("Could not get local username");
    return;
  }

  random_start ();

  // NOTE(review): the meaning of the final `true' argument is defined
  // by sfs_connect_host's declaration, which is not visible here.
  sfs_connect_host (host, SFS_AUTHSERV, wrap (this, &srpcon::getcon), true);
}

// Callback for sfs_connect_host ().  Records the connection; if one
// was obtained, allocates an RPC client for sfsauth_program_1 on its
// transport and starts the SRP exchange, otherwise reports err
// through fail ().
void
srpcon::getcon (ptr<sfscon> s, str err)
{
  sc = s;
  if (sc) {
    c = aclnt::alloc (sc->x, sfsauth_program_1);
    initsrp ();
  }
  else
    fail (err);
}

// Start (or restart) the SRP exchange: have the SRP client build its
// initial message for user, then send it with SFSAUTHPROC_SRP_INIT.
// The reply lands in sres and is handled by srpcb ().
void
srpcon::initsrp ()
{
  sfssrp_init_arg arg;
  arg.username = user;

  if (srpp->init (&arg.msg, sc->authid, user))
    c->call (SFSAUTHPROC_SRP_INIT, &arg, &sres,
	     wrap (this, &srpcon::srpcb));
  else
    fail ("SRP client initialization failed");
}

// Completion callback for the SRP RPCs (SRP_INIT and SRP_MORE).
//
// err is the RPC status; the server's reply is in sres.  Depending on
// the reply this either restarts the exchange after a rejected
// passphrase, feeds the reply to the SRP client and sends the next
// message, or finishes through succeed () / fail ().  Both of those
// delete this object, so every path returns immediately after them.
void
srpcon::srpcb (clnt_stat err)
{
  if (err) {
    fail (host << ": " << err);
    return;
  }
  if (sres.status != SFSAUTH_OK) {
    // Give up if no passphrase has been entered yet (so a bad
    // passphrase cannot be the cause) or if three retries have
    // already been used.
    if (!pwd || ntries++ >= 3) {
      fail ("Server aborted SRP protocol");
      return;
    }
    // Forget the rejected passphrase and run the exchange again.
    pwd = NULL;
    warnx ("Server rejected passphrase.\n");
    initsrp ();
    return;
  }

 reswitch:
  switch (srpp->next (sres.msg, sres.msg.addr ())) {
  case SRP_SETPWD:
    // The SRP client needs the passphrase; prompt for it and retry
    // the same step.  An empty passphrase aborts the attempt.
    pwd = getpwd (strbuf () << "Passphrase for " << srpp->getname () << ": ");
    srpp->setpwd (pwd);
    if (!pwd.len ()) {
      fail ("Aborted.");
      return;
    }
    goto reswitch;
  case SRP_NEXT:
    // Send the next message; the reply re-enters this function.
    c->call (SFSAUTHPROC_SRP_MORE, sres.msg.addr (), &sres, 
	     wrap (this, &srpcon::srpcb));
    break;
  case SRP_DONE:
    // Both out-parameters are optional and may be NULL; test the
    // pointer itself before writing through it.
    if (userp)
      *userp = user << "@" << srpp->host;
    if (pwdp)
      *pwdp = pwd;
    sc->hostid_valid = (srpp->host == sc->servinfo.host.hostname);
    succeed ();
    break;
  default:
    fail (host << ": server returned invalid SRP message");
    break;
  }
}

// Entry point: connect to the host named in user ([user]@host) and
// authenticate with SRP using the caller's srpp, which must be
// non-NULL.  cb is invoked when the attempt finishes.  userp and pwdp
// are stored as output locations for the connection state object.
// The state object is heap allocated here and frees itself when done.
void
sfs_connect_srp (str &user, srp_client *srpp, sfs_connect_cb cb,
		 str *userp, str *pwdp)
{
  assert (srpp);

  srpcon *con = New srpcon (cb);
  con->userp = userp;
  con->pwdp = pwdp;
  con->srpp = srpp;
  con->start (user);
}
