/*
 *  Copyright 2001-2005 Internet2
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* SAMLPOSTProfile.cpp - implements basics of the SAML browser profiles (Browser/POST and Browser/Artifact) and an in-memory replay cache

   Scott Cantor
   8/12/02

   $History:$
*/

#include <xercesc/util/XMLChar.hpp>
#include <xercesc/util/Base64.hpp>

#include "internal.h"

#include <ctime>
#include <sstream>
using namespace saml;
using namespace std;
using namespace log4cpp;

class MemoryReplayCache : public IReplayCache
{
public:
    MemoryReplayCache(const DOMElement*) {}
    ~MemoryReplayCache() {}
    bool check(const XMLCh* str, time_t expires) {auto_ptr_XMLCh temp(str); return check(temp.get(),expires);}
    bool check(const char* str, time_t expires);
    void thread_init() {}
    void thread_end() {}

private:
    set<string> m_replayCache;
    multimap<time_t,string> m_replayExpMap;
};

// Plugin factory for the in-memory replay cache; the new object is returned to the caller.
IPlugIn* MemoryReplayCacheFactory(const DOMElement* e)
{
    MemoryReplayCache* cache=new MemoryReplayCache(e);
    return cache;
}

// Convenience overload: instantiates the default replay cache provider type.
IReplayCache* IReplayCache::getInstance(const DOMElement* e)
{
    return IReplayCache::getInstance(DEFAULT_REPLAYCACHE_PROVIDER, e);
}

// Builds a plugin of the named type through the plugin manager and returns it as an
// IReplayCache. Throws UnsupportedExtensionException (after deleting the plugin) when
// the result does not implement IReplayCache, including when it is NULL.
IReplayCache* IReplayCache::getInstance(const char* type, const DOMElement* e)
{
    IPlugIn* plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(type,e);
    IReplayCache* cache=dynamic_cast<IReplayCache*>(plugin);
    if (!cache) {
        delete plugin;
        throw UnsupportedExtensionException("factory returned plugin type other than IReplayCache");
    }
    return cache;
}

/*
 * Records str as seen until 'expires' and reports whether it was new.
 *
 * Returns true the first time a key is presented, false if the key is already
 * remembered (a replay) or if any exception occurs. Before checking, every entry
 * whose expiration is at or before (now - clock skew) is purged. All container
 * access happens while holding the global SAML lock.
 */
bool MemoryReplayCache::check(const char* str, time_t expires)
{
    typedef multimap<time_t,string> expiry_map;
    SAMLConfig& config=SAMLConfig::getConfig();

    // Cutoff for garbage collection, widened by the configured clock skew.
    time_t cutoff=time(NULL)-config.clock_skew_secs;
    bool fresh=false;

    config.saml_lock();
    try
    {
        // Drop expired keys from both indexes.
        expiry_map::iterator last=m_replayExpMap.upper_bound(cutoff);
        expiry_map::iterator cur=m_replayExpMap.begin();
        while (cur!=last) {
            m_replayCache.erase(cur->second);
            m_replayExpMap.erase(cur++);
        }

        // A failed insert means the key is already known, i.e. a replay.
        if (m_replayCache.insert(str).second) {
            m_replayExpMap.insert(expiry_map::value_type(expires,str));
            fresh=true;
        }
    }
    catch(...)
    {
        config.saml_unlock();
        Category::getInstance(SAML_LOGCAT".MemoryReplayCache").error("check() caught an exception");
        return false;
    }
    config.saml_unlock();
    return fresh;
}

// Default SAMLBrowserProfile implementation handling Browser/POST and Browser/Artifact.
class BrowserProfile : virtual public SAMLBrowserProfile
{
public:
    // The configuration element is accepted for plugin-factory compatibility but not used.
    BrowserProfile(const DOMElement*) {}
    ~BrowserProfile() {}
    
    BrowserProfileResponse receive(
        const char* packet,
        const XMLCh* recipient,
        int supportedProfiles,
        IReplayCache* replayCache=NULL,
        ArtifactMapper* callback=NULL,
        int minorVersion=1
        ) const;

private:
    
    // Splits a form-encoded string ("a=1&b=2") into decoded name/value pairs.
    // Each value is a malloc'd buffer owned by this object and freed by the destructor.
    class CgiParse
    {
    public:
        CgiParse(const char* data, unsigned int len);
        ~CgiParse();

        typedef multimap<string,char*>::const_iterator walker;
        // All values stored under name, as a [first,second) iterator range.
        pair<walker,walker> get_values(const char* name) const;
        // The value for name if it occurs exactly once, otherwise NULL.
        char* get_value(const char* name) const;
        
    private:
        // Non-copyable: a copy would share the owned value buffers and free them twice.
        // Declared but intentionally never defined (pre-C++11 idiom for "= delete").
        CgiParse(const CgiParse&);
        CgiParse& operator=(const CgiParse&);

        char * fmakeword(char stop, unsigned int *cl, const char** ppch);
        char * makeword(char *line, char stop);
        void plustospace(char *str);
        char x2c(char *what);
        void url_decode(char *url);
        
        multimap<string,char*> kvp_map;
    };
};

// Plugin factory for the default browser profile; the new object is returned to the caller.
IPlugIn* BrowserProfileFactory(const DOMElement* e)
{
    BrowserProfile* profile=new BrowserProfile(e);
    return profile;
}

// Convenience overload: instantiates the default browser profile provider type.
SAMLBrowserProfile* SAMLBrowserProfile::getInstance(const DOMElement* e)
{
    return SAMLBrowserProfile::getInstance(DEFAULT_BROWSERPROFILE_PROVIDER, e);
}

// Builds a plugin of the named type through the plugin manager and returns it as a
// SAMLBrowserProfile. Throws UnsupportedExtensionException (after deleting the plugin)
// when the result does not implement SAMLBrowserProfile, including when it is NULL.
SAMLBrowserProfile* SAMLBrowserProfile::getInstance(const char* type, const DOMElement* e)
{
    IPlugIn* plugin=SAMLConfig::getConfig().getPlugMgr().newPlugin(type,e);
    SAMLBrowserProfile* profile=dynamic_cast<SAMLBrowserProfile*>(plugin);
    if (!profile) {
        delete plugin;
        throw UnsupportedExtensionException("factory returned plugin type other than SAMLBrowserProfile");
    }
    return profile;
}

/*
 * Processes an incoming Browser/POST or Browser/Artifact profile packet.
 *
 * packet            - form-encoded parameters (body or query string); must not be NULL
 * recipient         - URL the POSTed response must be addressed to (checked for POST only)
 * supportedProfiles - mask of SAMLBrowserProfile::Post / ::Artifact accepted by this endpoint
 * replayCache       - optional; if NULL, replay detection is skipped with a warning
 * artifactMapper    - required for the artifact profile; resolves artifacts to a response
 * minorVersion      - SAML minor version used to parse the response / build the request
 *
 * Returns the response, the first time-valid assertion holding an authentication
 * statement whose subject confirmation method matches the profile actually used,
 * that statement, and the TARGET parameter if present. On success the caller owns
 * the response; on failure it is deleted here before the exception propagates.
 *
 * Throws UnsupportedProfileException when the profile used is not enabled,
 * FatalProfileException for malformed or unacceptable input,
 * ExpiredAssertionException for stale data, and ReplayedAssertionException when
 * the replay cache rejects an artifact or assertion.
 */
SAMLBrowserProfile::BrowserProfileResponse BrowserProfile::receive(
    const char* packet,
    const XMLCh* recipient,
    int supportedProfiles,
    IReplayCache* replayCache,
    ArtifactMapper* artifactMapper,
    int minorVersion
    ) const
{
#ifdef _DEBUG
    saml::NDC ndc("receive");
#endif
    Category& log=Category::getInstance(SAML_LOGCAT".BrowserProfile");
    SAMLConfig& config=SAMLConfig::getConfig();
    time_t now=time(NULL);
    SAMLResponse* response = NULL;
    SAMLAssertion* assertion = NULL;
    SAMLAuthenticationStatement* authnStatement = NULL;

    // A NULL packet would otherwise crash in strlen() below.
    if (!packet)
        throw FatalProfileException("no browser profile packet supplied");

    // First we process the input packet. Technically artifact requires GET, but
    // it doesn't really matter to the implementation of the profile how it comes in.
    CgiParse parser(packet,strlen(packet));
    char* bits=parser.get_value("SAMLResponse");
    if (bits) {
        // This is a POST profile packet.
        if (!(supportedProfiles & SAMLBrowserProfile::Post)) {
            auto_ptr_char recip(recipient);
            log.error("attempted use of Browser/POST profile at unsupported endpoint (%s)", recip.get());
            throw UnsupportedProfileException("blocked use of Browser/POST profile");
        }
        // From here on, supportedProfiles records the single profile in use.
        supportedProfiles=SAMLBrowserProfile::Post;
    
        // The value is writable (the parser owns it), so we collapse any whitespace
        // inline to save a buffer copy.
        unsigned int x,y;
        for (x=0,y=0; bits[y]; y++)
            if (!XMLChar1_0::isWhitespace(bits[y]))
                bits[x++]=bits[y];
        bits[x]=0;
        
        // Decode the base64 into SAML.
        XMLByte* decoded=Base64::decode(reinterpret_cast<XMLByte*>(bits),&x);
        if (!decoded)
            throw FatalProfileException("unable to decode base64 in POST profile response");
        log.debug("decoded SAML response:\n%s",decoded);
        stringstream str(reinterpret_cast<char*>(decoded));
        XMLString::release(&decoded);
        response=new SAMLResponse(str,minorVersion);
        
        try {
            // Check security bits in the outer wrapper (Recipient and IssueInstant).
            const XMLCh* recipient2=response->getRecipient();
            if (!recipient || !*recipient || !recipient2 || !*recipient2 || XMLString::compareString(recipient,recipient2))
                throw FatalProfileException("detected recipient mismatch in POST profile response");
        
            if (response->getIssueInstant()->getEpoch() < now-(2*config.clock_skew_secs))
                throw ExpiredAssertionException("detected expired POST profile response");
            
            // We don't verify the signature, but at least check for one.
            if (!response->isSigned())
                throw FatalProfileException("detected unsigned POST profile response");
        }
        catch (SAMLException& e) {
            // Save off issuer...
            Iterator<SAMLAssertion*> assertions=response->getAssertions();
            if (assertions.hasNext()) {
                auto_ptr_char issuer(assertions.next()->getIssuer());
                e.addProperty("issuer",issuer.get());
            }
            delete response;
            throw;
        }
#ifndef _DEBUG
        // Only in non-debug builds; in _DEBUG builds other exceptions propagate
        // without this cleanup.
        catch (...) {
            delete response;
            throw;
        }
#endif
    }
    else {
        // Check for artifacts.
        pair<CgiParse::walker,CgiParse::walker> raw_artifacts = parser.get_values("SAMLart");
        if (raw_artifacts.first==raw_artifacts.second)
            throw FatalProfileException("Unable to locate SAMLResponse or SAMLart parameters");

        if ((supportedProfiles & SAMLBrowserProfile::Artifact) == 0) {
            auto_ptr_char recip(recipient);
            log.error("attempted use of Browser/Artifact profile at unsupported endpoint (%s)", recip.get());
            throw UnsupportedProfileException("blocked use of Browser/Artifact profile");
        }
        else if (!artifactMapper) {
            throw FatalProfileException("Support of artifact profile requires ArtifactMapper interface object");
        }
        // From here on, supportedProfiles records the single profile in use.
        supportedProfiles=SAMLBrowserProfile::Artifact;
        
        // Import the artifacts.
        vector<SAMLArtifact*> artifacts;
        for (; raw_artifacts.first!=raw_artifacts.second; raw_artifacts.first++) {
            try {
                log.debug("processing encoded artifact (%s)", raw_artifacts.first->second);
                
                // If a replay cache was provided, check for replay.
                if (replayCache) {
                    string key("A_");
                    key+=raw_artifacts.first->second;
                    if (!replayCache->check(key.c_str(),now + 2*config.clock_skew_secs))
                        throw ReplayedAssertionException(
                            string("rejecting replayed artifact (") + raw_artifacts.first->second + ")"
                            );
                }
                else
                    log.warn("replay cache was not provided, this is a potential security risk!");
                artifacts.push_back(SAMLArtifact::parse(raw_artifacts.first->second));
            }
            catch (...) {
                log.error("invalid artifact (%s)", raw_artifacts.first->second);
                for (Iterator<SAMLArtifact*> arts=artifacts; arts.hasNext();)
                    delete arts.next();
                throw;
            }
        }
        
        auto_ptr<SAMLRequest> request(new SAMLRequest(artifacts));
        request->setMinorVersion(minorVersion);
        // The request object (housed in the smart ptr) owns the artifacts now.
        artifacts.clear();
        
        // That's actually the hard part. The rest of the work is mostly done by the caller.
        // An exception might get tossed here, of course.
        response = artifactMapper->resolve(request.get());

        // Everything below dereferences the response, so refuse a NULL result up front.
        if (!response)
            throw FatalProfileException("ArtifactMapper did not return a SAML response");
    }
    
    // At this point, we have a seemingly valid response, either via POST or from an artifact callback.
    // This is messy. We have to basically guess as to where the authentication statement is, by finding
    // one with an appropriate subject confirmation method. We go for the first match inside a valid assertion.
    try {
        bool bExpired=false;
        for (Iterator<SAMLAssertion*> assertions=response->getAssertions(); !assertion && assertions.hasNext();) {
            bExpired=false;
            SAMLAssertion* a=assertions.next();
    
            // The assertion must be bounded front and back.
            const SAMLDateTime* notBefore=a->getNotBefore();
            const SAMLDateTime* notOnOrAfter=a->getNotOnOrAfter();
            if (!notBefore || !notOnOrAfter) {
                log.debug("skipping assertion without time conditions...");
                continue;
            }
    
            if (now+config.clock_skew_secs < notBefore->getEpoch()) {
                bExpired=true;
                log.debug("skipping assertion that's not yet valid...");
                continue;
            }
    
            if (notOnOrAfter->getEpoch() <= now-config.clock_skew_secs) {
                bExpired=true;
                log.debug("skipping expired assertion...");
                continue;
            }
    
            // Look for an authentication statement.
            for (Iterator<SAMLStatement*> statements=a->getStatements(); !authnStatement && statements.hasNext();) {
                SAMLStatement* s=statements.next();
                SAMLAuthenticationStatement* as=dynamic_cast<SAMLAuthenticationStatement*>(s);
                if (!as)
                    continue;
    
                const SAMLSubject* subject=as->getSubject();
                for (Iterator<const XMLCh*> methods=subject->getConfirmationMethods(); methods.hasNext();) {
                    const XMLCh* m=methods.next();
                    // The confirmation method must belong to the profile actually used:
                    // bearer for POST, artifact for Artifact. Previously the artifact methods
                    // were accepted regardless of profile, so an assertion confirmed for
                    // artifact pickup would also be accepted when pushed through POST.
                    if ((supportedProfiles==SAMLBrowserProfile::Post &&
                            !XMLString::compareString(m,SAMLSubject::CONF_BEARER)) ||
                        (supportedProfiles==SAMLBrowserProfile::Artifact &&
                            (!XMLString::compareString(m,SAMLSubject::CONF_ARTIFACT) ||
                             !XMLString::compareString(m,SAMLSubject::CONF_ARTIFACT01)))) {
                        authnStatement=as;
                        assertion=a;
                        break;
                    }
                }
            }
        }
        if (!authnStatement) {
            // With a single assertion that failed only its time window, report expiry specifically.
            if (bExpired && response->getAssertions().size()==1)
                throw ExpiredAssertionException("unable to accept assertion because of clock skew");
            throw FatalProfileException("Unable to locate a valid authentication statement");
        }
        else if (supportedProfiles==SAMLBrowserProfile::Post) {
            // Check for assertion replay. With artifact, the back-channel acts as a replay guard.
            if (replayCache) {
                auto_ptr_char id(assertion->getId());
                string key(id.get());
                key="P_" + key;
                if (!replayCache->check(key.c_str(),assertion->getNotOnOrAfter()->getEpoch()))
                    throw ReplayedAssertionException(string("Rejecting replayed assertion ID (") + id.get() + ")");
            }
            else
                log.warn("replay cache was not provided, this is a serious security risk!");
        }
    }
    catch (SAMLException& ex) {
        Iterator<SAMLAssertion*> assertions=response->getAssertions();
        if (assertions.hasNext()) {
            auto_ptr_char issuer(assertions.next()->getIssuer());
            ex.addProperty("issuer",issuer.get());
        }
        delete response;
        throw;
    }
#ifndef _DEBUG
    // Only in non-debug builds; in _DEBUG builds other exceptions propagate
    // without this cleanup.
    catch (...) {
        delete response;
        throw;
    }
#endif
    
    BrowserProfileResponse profileResponse;
    profileResponse.profile = static_cast<profiles_t>(supportedProfiles);
    profileResponse.response = response;
    profileResponse.assertion = assertion;
    profileResponse.authnStatement = authnStatement;
    
    // Extract TARGET parameter, if any. Might be required in SAML, but this is more forgiving.
    bits=parser.get_value("TARGET");
    if (bits)
        profileResponse.TARGET=bits;

    return profileResponse;
}

/*************************************************************************
 * CGI Parser implementation
 */

// Parses the first len bytes of data as '&'-separated name=value pairs.
// Each chunk is '+'-to-space converted and %-decoded before it is split at its
// first '='. The decoded value buffer is stored (and owned) in kvp_map.
BrowserProfile::CgiParse::CgiParse(const char* data, unsigned int len)
{
    const char* cursor = data;
    unsigned int remaining = len;

    for (; remaining && cursor; ) {
        char* value = fmakeword('&', &remaining, &cursor);
        plustospace(value);
        url_decode(value);
        // makeword returns the text before '=' and leaves value holding the text after it.
        char* name = makeword(value, '=');
        kvp_map.insert(multimap<string,char*>::value_type(name, value));
        free(name);
    }
}

// Releases every malloc'd value buffer held in the map.
BrowserProfile::CgiParse::~CgiParse()
{
    multimap<string,char*>::iterator it = kvp_map.begin();
    while (it != kvp_map.end()) {
        free(it->second);
        ++it;
    }
}

char*
BrowserProfile::CgiParse::get_value(const char* name) const
{
    if (kvp_map.count(name)!=1)
        return NULL;
    return kvp_map.lower_bound(name)->second;
}

// Returns the [first,second) range of every value stored under name (empty if none).
pair<BrowserProfile::CgiParse::walker,BrowserProfile::CgiParse::walker>
BrowserProfile::CgiParse::get_values(const char* name) const
{
    const string key(name);
    return kvp_map.equal_range(key);
}

/* Parsing routines modified from NCSA source. */
// Splits line at the first 'stop' character. Returns a malloc'd copy of the text
// before it; line is rewritten in place to hold only the text after it (the stop
// character itself is dropped). If stop is absent, the whole line is returned and
// line becomes empty.
char *
BrowserProfile::CgiParse::makeword(char *line, char stop)
{
    // Worst case the entire line is the word.
    char* word = (char*) malloc(sizeof(char) * (strlen(line) + 1));

    char* src = line;
    char* dst = word;
    while (*src && *src != stop)
        *dst++ = *src++;
    *dst = '\0';

    // Step over the separator, if any, then slide the remainder to the front of line.
    if (*src)
        ++src;
    char* out = line;
    while (*src)
        *out++ = *src++;
    *out = '\0';

    return word;
}

// Extracts the next word from the input cursor *ppch, which has *cl bytes remaining.
// The word ends at the first 'stop' character or at the end of the remaining input.
// Returns a malloc'd, NUL-terminated copy of the word (without the stop character),
// and advances *ppch / decrements *cl past the word and the stop character, if one
// was found.
//
// The NCSA original also ended a word when a byte compared equal to EOF. That check
// is meaningless for in-memory data. Where char is signed, a raw 0xFF byte matched
// it, so such a byte silently acted as an extra parameter separator. It is gone.
char *
BrowserProfile::CgiParse::fmakeword(char stop, unsigned int *cl, const char** ppch)
{
    const char* start = *ppch;

    // Bounded by *cl, so we never look past the caller's buffer.
    unsigned int len = 0;
    while (len < *cl && start[len] != stop)
        ++len;

    char* word = (char*) malloc(sizeof(char) * (len + 1));
    for (unsigned int i = 0; i < len; ++i)
        word[i] = start[i];
    word[len] = '\0';

    // Consume the word plus the separator, when one terminated it.
    unsigned int consumed = (len < *cl) ? len + 1 : len;
    *ppch += consumed;
    *cl -= consumed;
    return word;
}

// Form encoding represents spaces as '+'; rewrite them as spaces in place.
void
BrowserProfile::CgiParse::plustospace(char *str)
{
    for (char* p = str; *p; ++p) {
        if (*p == '+')
            *p = ' ';
    }
}

// Converts the two hex digit characters at what[0] and what[1] into the byte they
// encode. Letters are folded to upper case by clearing bit 0x20 before subtracting
// 'A'. Input is not validated; non-hex characters produce an unspecified byte.
char
BrowserProfile::CgiParse::x2c(char *what)
{
    int high = (what[0] >= 'A') ? ((what[0] & 0xdf) - 'A') + 10 : (what[0] - '0');
    int low  = (what[1] >= 'A') ? ((what[1] & 0xdf) - 'A') + 10 : (what[1] - '0');
    return (char)(high * 16 + low);
}

// True when c is an ASCII hexadecimal digit (0-9, a-f, A-F).
static bool cgi_is_hex_digit(char c)
{
    return (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F');
}

// Decodes %XX escapes in place (the result is never longer than the input).
// Only a complete escape, '%' followed by two hex digits, is decoded. Any other '%'
// is copied literally, including one at or near the end of the string. The previous
// code always consumed two characters after '%'. On input such as "abc%" it read past
// the terminator and then continued reading and writing beyond the buffer. This data
// comes straight from the HTTP request, so that was attacker-controlled.
void
BrowserProfile::CgiParse::url_decode(char *url)
{
    unsigned int x, y;
    for (x = 0, y = 0; url[y]; ++x, ++y) {
        url[x] = url[y];
        // && short-circuits, so url[y+2] is only read when url[y+1] is a (non-NUL)
        // hex digit, keeping every access within the string.
        if (url[y] == '%' && cgi_is_hex_digit(url[y+1]) && cgi_is_hex_digit(url[y+2])) {
            url[x] = x2c(&url[y+1]);
            y += 2;
        }
    }
    url[x] = '\0';
}

// NUL-terminated XMLCh spelling of the Browser/Artifact profile URI shown in the
// trailing comment, built from Xerces character constants.
const XMLCh SAMLBrowserProfile::BROWSER_ARTIFACT[] = // urn:oasis:names:tc:SAML:1.0:profiles:artifact-01
{ chLatin_u, chLatin_r, chLatin_n, chColon, chLatin_o, chLatin_a, chLatin_s, chLatin_i, chLatin_s, chColon,
  chLatin_n, chLatin_a, chLatin_m, chLatin_e, chLatin_s, chColon, chLatin_t, chLatin_c, chColon,
  chLatin_S, chLatin_A, chLatin_M, chLatin_L, chColon, chDigit_1, chPeriod, chDigit_0, chColon,
  chLatin_p, chLatin_r, chLatin_o, chLatin_f, chLatin_i, chLatin_l, chLatin_e, chLatin_s, chColon,
  chLatin_a, chLatin_r, chLatin_t, chLatin_i, chLatin_f, chLatin_a, chLatin_c, chLatin_t, chDash, chDigit_0, chDigit_1, chNull
};

// NUL-terminated XMLCh spelling of the Browser/POST profile URI shown in the
// trailing comment, built from Xerces character constants.
const XMLCh SAMLBrowserProfile::BROWSER_POST[] = // urn:oasis:names:tc:SAML:1.0:profiles:browser-post
{ chLatin_u, chLatin_r, chLatin_n, chColon, chLatin_o, chLatin_a, chLatin_s, chLatin_i, chLatin_s, chColon,
  chLatin_n, chLatin_a, chLatin_m, chLatin_e, chLatin_s, chColon, chLatin_t, chLatin_c, chColon,
  chLatin_S, chLatin_A, chLatin_M, chLatin_L, chColon, chDigit_1, chPeriod, chDigit_0, chColon,
  chLatin_p, chLatin_r, chLatin_o, chLatin_f, chLatin_i, chLatin_l, chLatin_e, chLatin_s, chColon,
  chLatin_b, chLatin_r, chLatin_o, chLatin_w, chLatin_s, chLatin_e, chLatin_r, chDash,
    chLatin_p, chLatin_o, chLatin_s, chLatin_t, chNull
};
