// DataPluggableProtocol.cpp : Implementation of CDataURL::Parse and CDataPluggableProtocol (data: URL protocol handler)
#include "stdafx.h"
#include "DataURL.h"
#include "DataPluggableProtocol.h"
// Parses a data: URL of the form
//   data:[<mediatype>][;<attribute>=<value>][;base64],<data>
// into this object's members. Any previous state is discarded via Reset().
//
// szURL - the complete URL including the "data:" prefix. It is narrowed to
//         an ANSI string and percent-unescaped in place before matching.
// Returns true if the URL matches the data: pattern and, for ";base64"
// payloads, the payload decodes successfully; false otherwise.
//
// On success:
//  - m_strMimeType / m_strAttribute / m_strValue receive the matched groups.
//  - ";base64" payloads are decoded into m_pvData / m_dwDataLength.
//  - "text/*" payloads are converted to UTF-16LE prefixed with a BOM
//    (0xFF 0xFE), using the code page IMultiLanguage2 reports for
//    GetCharset(). If that lookup or conversion fails, already-decoded
//    base64 bytes are left in m_pvData unchanged.
// NOTE(review): for payloads that are neither base64 nor text/*, nothing in
// this function stores the data in m_pvData; confirm how GetData() gets it.
bool CDataURL::Parse(LPCWSTR szURL)
{
    Reset();
    CAtlRegExp<CAtlRECharTraitsA> re;
    re.Parse("data:{(.*?/.*?)}?(;{.*?}={.*?})?{;(base64)?}?,{.*}", 0);
    CAtlREMatchContext<CAtlRECharTraitsA> ctxt;

    //Unescape the URL
    CAtlStringA strUrlUnescaped(szURL);
    DWORD cchUnescaped = strUrlUnescaped.GetLength() + 1;
    LPSTR szURLBuff = strUrlUnescaped.GetBuffer();
    HRESULT hr = UrlUnescapeA(szURLBuff, NULL, &cchUnescaped, URL_UNESCAPE_INPLACE);
    strUrlUnescaped.ReleaseBuffer();
    if (FAILED(hr))
        return false;

    //See if the URL matches the pattern for data urls
    if (!re.Match(strUrlUnescaped, &ctxt))
        return false;

    AssignMatchGroup(m_strMimeType, ctxt, 0);
    AssignMatchGroup(m_strAttribute, ctxt, 1);
    AssignMatchGroup(m_strValue, ctxt, 2);
    CAtlStringA strBase64;
    AssignMatchGroup(strBase64, ctxt, 3);
    // Group 3 is ";base64" (7 characters) when the base64 flag is present.
    bool bIsBase64 = strBase64.GetLength() == 7;
    CAtlStringA strData;
    AssignMatchGroup(strData, ctxt, 4);

    if (bIsBase64)
    {
        int nReqLen = Base64DecodeGetRequiredLength(strData.GetLength());
        m_pvData = new BYTE[nReqLen];
        int nDestLen = nReqLen;
        if (!Base64Decode(strData, strData.GetLength(), m_pvData, &nDestLen))
        {
            // On failure nDestLen may exceed the buffer; never keep a length
            // that would let a reader run past the allocation.
            delete [] m_pvData;
            m_pvData = NULL;
            m_dwDataLength = 0;
            return false;
        }
        m_dwDataLength = nDestLen;
    }

    if (m_strMimeType.Left(5).CompareNoCase("text/") == 0)
    {
        // A base64 payload is already decoded into m_pvData; copy the bytes
        // back into strData so they can be converted below. m_pvData is kept
        // until the conversion succeeds so that a failed charset lookup does
        // not discard the payload.
        if (bIsBase64)
            strData.SetString(reinterpret_cast<LPCSTR>(m_pvData), static_cast<int>(m_dwDataLength));

        //Get the code page
        CComPtr<IMultiLanguage2> spMLang;
        MIMECSETINFO mi;
        if (SUCCEEDED(spMLang.CoCreateInstance(CLSID_CMultiLanguage)) &&
            SUCCEEDED(spMLang->GetCharsetInfo(CComBSTR(GetCharset()), &mi)))
        {
            int nSrcLen = strData.GetLength();
            LPCSTR szSrc = static_cast<LPCSTR>(strData);
            // Try the Internet encoding first, then fall back to the code page.
            UINT uCodePage = mi.uiInternetEncoding;
            int nWideChar = MultiByteToWideChar(uCodePage, 0, szSrc, nSrcLen, NULL, 0);
            if (nWideChar == 0)
            {
                uCodePage = mi.uiCodePage;
                nWideChar = MultiByteToWideChar(uCodePage, 0, szSrc, nSrcLen, NULL, 0);
            }
            if (nWideChar != 0)
            {
                // Allocate as BYTE[] (m_pvData's element type) so that the
                // delete[] applied to m_pvData matches this new[]. One extra
                // WCHAR at the front holds the UTF-16LE byte-order mark.
                DWORD cbWide = static_cast<DWORD>((nWideChar + 1) * sizeof(WCHAR));
                BYTE* pbWide = new BYTE[cbWide];
                MultiByteToWideChar(uCodePage, 0, szSrc, nSrcLen,
                                    reinterpret_cast<LPWSTR>(pbWide) + 1, nWideChar);
                //If data is in Unicode it should have unicode lead bytes
                pbWide[0] = 0xFF;
                pbWide[1] = 0xFE;
                if (bIsBase64)
                    delete [] m_pvData;  // decoded bytes are superseded
                m_pvData = pbWide;
                m_dwDataLength = cbWide;
            }
        }
    }
    return true;
}
// CDataPluggableProtocol: asynchronous pluggable protocol handler for data: URLs
// IInternetProtocolRoot::Start - begins a bind to a data: URL.
//
// The URL is parsed (and its payload decoded) synchronously by
// CDataURL::Parse, so all data is available at once. Progress, data
// availability and the final result are all reported to pIProtSink before
// this method returns.
//
// Returns:
//  S_OK               - bind completed; or, with PI_PARSE_URL, the URL is valid
//  S_FALSE            - PI_PARSE_URL was set and the URL did not parse
//  INET_E_INVALID_URL - the URL did not parse (normal bind)
//  E_POINTER          - no sink supplied for a normal bind
STDMETHODIMP CDataPluggableProtocol::Start(
    LPCWSTR szUrl,
    IInternetProtocolSink *pIProtSink,
    IInternetBindInfo *pIBindInfo,
    DWORD grfSTI,
    DWORD dwReserved)
{
    const bool bParseOnly = (grfSTI & PI_PARSE_URL) != 0;
    if (!m_url.Parse(szUrl))
        return bParseOnly ? S_FALSE : INET_E_INVALID_URL;

    // PI_PARSE_URL only asks whether the URL is valid; transfer nothing.
    if (bParseOnly)
        return S_OK;

    if (pIProtSink == NULL)
        return E_POINTER;

    m_dwPos = 0;
    CAtlString strData(m_url.GetDataString());
    pIProtSink->ReportProgress(BINDSTATUS_FINDINGRESOURCE, strData);
    pIProtSink->ReportProgress(BINDSTATUS_CONNECTING, strData);
    pIProtSink->ReportProgress(BINDSTATUS_SENDINGREQUEST, strData);
    pIProtSink->ReportProgress(BINDSTATUS_VERIFIEDMIMETYPEAVAILABLE, CAtlString(m_url.GetMimeType()));
    pIProtSink->ReportData(BSCF_FIRSTDATANOTIFICATION, 0, m_url.GetDataLength());
    pIProtSink->ReportData(BSCF_LASTDATANOTIFICATION | BSCF_DATAFULLYAVAILABLE, m_url.GetDataLength(), m_url.GetDataLength());
    // The protocol contract expects ReportResult once the bind has finished;
    // everything is already reported, so complete it now.
    pIProtSink->ReportResult(S_OK, 0, NULL);
    return S_OK;
}
// IInternetProtocolRoot::Continue - Start() completes synchronously without
// calling IInternetProtocolSink::Switch, so there is no deferred state to
// process; the call is simply acknowledged.
STDMETHODIMP CDataPluggableProtocol::Continue(PROTOCOLDATA *pStateInfo)
{
    UNREFERENCED_PARAMETER(pStateInfo);
    return S_OK;
}
// IInternetProtocolRoot::Abort - Start() does all of its work before
// returning, so there is nothing in flight to cancel; always succeeds.
STDMETHODIMP CDataPluggableProtocol::Abort(HRESULT hrReason,DWORD dwOptions)
{
    UNREFERENCED_PARAMETER(hrReason);
    UNREFERENCED_PARAMETER(dwOptions);
    return S_OK;
}
// IInternetProtocolRoot::Terminate - accepts termination of the bind.
// Nothing is released here; the parsed URL stays in m_url.
STDMETHODIMP CDataPluggableProtocol::Terminate(DWORD dwOptions)
{
    UNREFERENCED_PARAMETER(dwOptions);
    return S_OK;
}
// IInternetProtocolRoot::Suspend - suspension is not supported.
STDMETHODIMP CDataPluggableProtocol::Suspend()
{
    const HRESULT hrUnsupported = E_NOTIMPL;
    return hrUnsupported;
}
// IInternetProtocolRoot::Resume - resumption is not supported (the handler
// never suspends).
STDMETHODIMP CDataPluggableProtocol::Resume()
{
    const HRESULT hrUnsupported = E_NOTIMPL;
    return hrUnsupported;
}
// IInternetProtocol::LockRequest - the payload is held in m_url for the
// lifetime of this object, so there is nothing extra to lock; only traces.
STDMETHODIMP CDataPluggableProtocol::LockRequest(DWORD dwOptions)
{
    UNREFERENCED_PARAMETER(dwOptions);
    ATLTRACE(_T("LockRequest\n"));
    return S_OK;
}
// IInternetProtocol::UnlockRequest - counterpart of LockRequest; only traces.
// The read position m_dwPos is deliberately left untouched here (resetting it
// was considered and disabled).
STDMETHODIMP CDataPluggableProtocol::UnlockRequest()
{
    ATLTRACE(_T("UnlockRequest\n"));
    return S_OK;
}
// IInternetProtocol::Read - copies up to cb bytes of the decoded payload,
// starting at the current read position m_dwPos, into pv, and advances
// m_dwPos by the amount copied.
//
// pv      - destination buffer of at least cb bytes.
// pcbRead - if non-NULL, receives the number of bytes actually copied
//           (0 when nothing was copied).
// Returns S_OK when data remains after this read, S_FALSE once the payload
// has been fully consumed, or E_POINTER if pv is NULL while data is pending.
STDMETHODIMP CDataPluggableProtocol::Read(void *pv, ULONG cb, ULONG *pcbRead)
{
    ATLTRACE(_T("READ - requested=%8d\n"), cb);
    ULONG cbRead = 0;
    HRESULT hr = S_FALSE;  // default: nothing left to read
    const DWORD cbTotal = m_url.GetDataLength();
    if (m_dwPos < cbTotal)
    {
        const DWORD cbAvail = cbTotal - m_dwPos;
        // Never copy more than the caller's buffer can hold; asking memcpy_s
        // for more than cb bytes fails and invokes the invalid-parameter handler.
        const ULONG cbCopy = (cbAvail < cb) ? cbAvail : cb;
        if (cbCopy != 0 && pv == NULL)
        {
            hr = E_POINTER;
        }
        else
        {
            if (cbCopy != 0)
                memcpy_s(pv, cb, m_url.GetData() + m_dwPos, cbCopy);
            m_dwPos += cbCopy;
            cbRead = cbCopy;
            hr = (m_dwPos < cbTotal) ? S_OK : S_FALSE;
        }
    }
    if (pcbRead != NULL)
        *pcbRead = cbRead;
    return hr;
}
// IInternetProtocol::Seek - random access into the payload is not supported;
// data is only delivered sequentially through Read().
STDMETHODIMP CDataPluggableProtocol::Seek(
    LARGE_INTEGER dlibMove,
    DWORD dwOrigin,
    ULARGE_INTEGER *plibNewPosition)
{
    UNREFERENCED_PARAMETER(dlibMove);
    UNREFERENCED_PARAMETER(dwOrigin);
    UNREFERENCED_PARAMETER(plibNewPosition);
    const HRESULT hrUnsupported = E_NOTIMPL;
    return hrUnsupported;
}
// IInternetProtocolInfo::CombineUrl - no custom combining; defers to the
// default handling by returning INET_E_DEFAULT_ACTION.
STDMETHODIMP CDataPluggableProtocol::CombineUrl(LPCWSTR pwzBaseUrl, LPCWSTR pwzRelativeUrl, DWORD dwCombineFlags,
    LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved)
{
    UNREFERENCED_PARAMETER(pwzBaseUrl);
    UNREFERENCED_PARAMETER(pwzRelativeUrl);
    UNREFERENCED_PARAMETER(dwCombineFlags);
    UNREFERENCED_PARAMETER(pwzResult);
    UNREFERENCED_PARAMETER(cchResult);
    UNREFERENCED_PARAMETER(pcchResult);
    UNREFERENCED_PARAMETER(dwReserved);
    return INET_E_DEFAULT_ACTION;
}
// IInternetProtocolInfo::CompareUrl - parses both URLs with CDataURL::Parse
// and compares the results with CDataURL's operator==.
// Returns E_POINTER if either URL pointer is NULL, S_OK if both parse and
// compare equal, and S_FALSE otherwise. dwCompareFlags is ignored.
STDMETHODIMP CDataPluggableProtocol::CompareUrl(LPCWSTR pwszUrl1, LPCWSTR pwszUrl2, DWORD dwCompareFlags)
{
    UNREFERENCED_PARAMETER(dwCompareFlags);
    ATLTRACE(_T("CompareUrl\n"));
    if (!pwszUrl1 || !pwszUrl2)
        return E_POINTER;

    CDataURL first, second;
    if (!first.Parse(pwszUrl1))
        return S_FALSE;
    if (!second.Parse(pwszUrl2))
        return S_FALSE;
    return (first == second) ? S_OK : S_FALSE;
}
// IInternetProtocolInfo::ParseUrl - no custom parse actions are implemented;
// every request is deferred to default handling via INET_E_DEFAULT_ACTION.
STDMETHODIMP CDataPluggableProtocol::ParseUrl(LPCWSTR pwzUrl, PARSEACTION parseAction, DWORD dwParseFlags,
    LPWSTR pwzResult, DWORD cchResult, DWORD *pcchResult, DWORD dwReserved)
{
    UNREFERENCED_PARAMETER(pwzUrl);
    UNREFERENCED_PARAMETER(parseAction);
    UNREFERENCED_PARAMETER(dwParseFlags);
    UNREFERENCED_PARAMETER(pwzResult);
    UNREFERENCED_PARAMETER(cchResult);
    UNREFERENCED_PARAMETER(pcchResult);
    UNREFERENCED_PARAMETER(dwReserved);
    return INET_E_DEFAULT_ACTION;
}
// IInternetProtocolInfo::QueryInfo - no custom query options are answered;
// every query is deferred to default handling via INET_E_DEFAULT_ACTION.
STDMETHODIMP CDataPluggableProtocol::QueryInfo( LPCWSTR pwzUrl, QUERYOPTION QueryOption, DWORD dwQueryFlags,
    LPVOID pBuffer, DWORD cbBuffer, DWORD *pcbBuf, DWORD dwReserved)
{
    UNREFERENCED_PARAMETER(pwzUrl);
    UNREFERENCED_PARAMETER(QueryOption);
    UNREFERENCED_PARAMETER(dwQueryFlags);
    UNREFERENCED_PARAMETER(pBuffer);
    UNREFERENCED_PARAMETER(cbBuffer);
    UNREFERENCED_PARAMETER(pcbBuf);
    UNREFERENCED_PARAMETER(dwReserved);
    return INET_E_DEFAULT_ACTION;
}