#include "QueryMng.h"
#include "IRPUtils.h"
#include <algorithm>
#include <string>
namespace utils
{
// Registers `dispatcher` for the IOCTL `controlCode`.
//
// A SectionWriteGuard on mapLock_ is held for the whole call (presumably an
// exclusive lock that is released when the guard goes out of scope).
//
// controlCode - key of the dispatch map entry to add the dispatcher to
// dispatcher  - pointer stored in the entry's list; it is not dereferenced here
//
// Throws std::exception when `dispatcher` is already present in the list
// stored for `controlCode`.
void QueryMng::Subscribe(ULONG controlCode,IQueryDispatcher* dispatcher)
{
    utils::SectionWriteGuard guard(&mapLock_);

    DispatchMap::iterator entry = dispatchMap_.find(controlCode);
    if( entry != dispatchMap_.end() )
    {
        // The control code is already known: append, but refuse duplicates.
        DispatchList& subscribers = entry->second;
        DispatchList::const_iterator found =
            std::find(subscribers.begin(),subscribers.end(),dispatcher);
        if( found != subscribers.end() )
            throw std::exception("This dispatcher already subscribed");
        subscribers.push_back(dispatcher);
        return;
    }

    // First dispatcher for this control code: create its list and store it.
    DispatchList subscribers;
    subscribers.push_back(dispatcher);
    dispatchMap_.insert( std::make_pair(controlCode,subscribers) );
}
// Removes `dispatcher` from the list registered for the IOCTL `controlCode`.
//
// A SectionWriteGuard on mapLock_ is held for the whole call (presumably an
// exclusive lock that is released when the guard goes out of scope).
//
// controlCode - key of the dispatch map entry to remove the dispatcher from
// dispatcher  - pointer to remove; it is only compared, never dereferenced
//
// The call is a no-op when nothing is registered for `controlCode` or when
// `dispatcher` is not in that list.
void QueryMng::UnSubscribe(ULONG controlCode,IQueryDispatcher* dispatcher)
{
    utils::SectionWriteGuard guard(&mapLock_);

    DispatchMap::iterator it = dispatchMap_.find(controlCode);
    if( it == dispatchMap_.end() )
        return; // nothing was ever subscribed for this control code

    // Drop every occurrence of the pointer, so that a duplicate left behind
    // by earlier registrations cannot keep a dangling dispatcher in the list.
    DispatchList& list = it->second;
    list.erase(std::remove(list.begin(),list.end(),dispatcher),list.end());

    // CallDispatchers reports "no dispatcher" only when the control code has
    // no map entry, so an emptied list must not be left in the map.
    if( list.empty() )
        dispatchMap_.erase(it);
}
// Handles a device-control IRP by forwarding it to the dispatchers that are
// subscribed to its IOCTL code, then completes the IRP.
//
// irp - the request to process; its current stack location supplies the
//       control code and the buffer lengths, and AssociatedIrp.SystemBuffer
//       is used as the shared input/output buffer.
//
// Returns the value of utils::CompleteIrp:
//  - completed with STATUS_INVALID_PARAMETER and 0 bytes when the control
//    code is not METHOD_BUFFERED or when either buffer length is 0;
//  - completed with STATUS_SUCCESS otherwise. If a dispatcher throws
//    std::exception, the exception text is written to the buffer as a
//    NUL-terminated narrow string (truncated to fit the output length) and
//    the request is still completed with STATUS_SUCCESS.
NTSTATUS QueryMng::ProcessIrp(PIRP irp)
{
    ULONG bytesTxd = 0; // number of bytes reported on completion
    PIO_STACK_LOCATION irpStack = IoGetCurrentIrpStackLocation(irp);

    // The two low bits of an IOCTL code encode the transfer method.
    const ULONG controlCode = irpStack->Parameters.DeviceIoControl.IoControlCode;
    const ULONG method = controlCode & 0x03;
    if( method != METHOD_BUFFERED )
        return utils::CompleteIrp(irp,STATUS_INVALID_PARAMETER,bytesTxd);

    const ULONG inputLength =
        irpStack->Parameters.DeviceIoControl.InputBufferLength;
    const ULONG outputLength =
        irpStack->Parameters.DeviceIoControl.OutputBufferLength;

    // Both an input and an output buffer are required.
    if( outputLength < 1 || inputLength < 1 )
        return utils::CompleteIrp(irp,STATUS_INVALID_PARAMETER,bytesTxd);

    WCHAR* buff = static_cast<PWCHAR>(irp->AssociatedIrp.SystemBuffer);
    try
    {
        // NOTE(review): bytesTxd is whatever the dispatchers report; it is not
        // checked against outputLength here - confirm dispatchers honour it.
        CallDispatchers(controlCode,buff,inputLength,outputLength,&bytesTxd);
    }
    catch (const std::exception& ex)
    {
        // Return the error text to the caller through the output buffer.
        // outputLength >= 1 was checked above, so one byte is always left
        // for the terminator.
        const std::string message(ex.what());
        const ULONG capacity = outputLength - 1;
        const ULONG toWrite = message.size() < capacity
            ? static_cast<ULONG>(message.size())
            : capacity;

        char* mbStr = reinterpret_cast<char*>(buff);
        memcpy(mbStr,message.c_str(),toWrite);
        mbStr[toWrite] = '\0';
        bytesTxd = toWrite + 1; // text plus the terminating '\0'
    }
    return utils::CompleteIrp(irp,STATUS_SUCCESS,bytesTxd);
}
// Invokes Dispatch on every dispatcher stored for `controlCode`, in list order.
//
// A SectionReadGuard on mapLock_ is held for the whole call (presumably a
// shared lock that is released when the guard goes out of scope).
//
// controlCode   - IOCTL code used to look up the dispatcher list
// buf           - buffer handed unchanged to each dispatcher
// inputBufSize  - input length handed unchanged to each dispatcher
// outputBufSize - output length handed unchanged to each dispatcher
// bytesTxd      - out-parameter handed unchanged to each dispatcher
//
// Throws std::exception when the map has no entry for `controlCode`.
// Exceptions thrown by a dispatcher are not caught here.
void QueryMng::CallDispatchers(ULONG controlCode,
                               WCHAR* buf,
                               ULONG inputBufSize,
                               ULONG outputBufSize,
                               ULONG* bytesTxd)
{
    utils::SectionReadGuard guard(&mapLock_);

    DispatchMap::const_iterator entry = dispatchMap_.find(controlCode);
    if( entry == dispatchMap_.end() )
        throw std::exception(__FUNCTION__"No one dispatcher for this control code");

    const DispatchList& subscribers = entry->second;
    DispatchList::const_iterator current = subscribers.begin();
    while( current != subscribers.end() )
    {
        (*current)->Dispatch(controlCode,buf,inputBufSize,outputBufSize,bytesTxd);
        ++current;
    }
}
}