
Get the calling module that invoked your function

29 Jan 2007
How to get the calling module that invoked your function as an entry point.


Introduction

Sometimes it is necessary to know which module is calling your exported function. You may want to verify that the calling module is one you have certified, or your function may need to return a different result depending on the calling context.

Background

My article is based on Chavdar Dimitrov's article, whose code works well but only for native callers. His example did not handle callers written in .NET. With the old DLLs, the release build would also crash, probably because of tightened Windows security. I therefore brushed up his code so that it handles both situations, albeit with reduced functionality.

Digging into the code

The function is written in C++ so that it can easily be called from both managed and unmanaged code. I've provided DLLs and EXEs for the various calling scenarios that can occur. The .NET 2.0 runtime must be installed, because stackdumper.dll is a mixed-mode DLL in which certain files are compiled with the /clr option. External code calls GetCallingModulePath, a function exported from stackdumper.dll. For some reason, it is important that this function takes an argument.

MC++
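// x86 (32-bit) only: MSVC does not support inline assembly on x64.
const char* __stdcall GetCallingModulePath(int arg)
{
  long reg_ebp;       // frame pointer of the frame being examined
  long temp = 0;      // saved EBP read back from the stack
  SIZE_T cnt = 0;     // bytes copied by ReadProcessMemory
  __asm{
    mov eax, ebp      // start from this function's own frame
    mov reg_ebp, eax
  }
  ADDR callerAddr;
  unsigned i = 0;
  HANDLE h = GetCurrentProcess();

  module.clear();     // reset the global result string
  callerAddr = GetCallerAddr(reg_ebp);   // return address recorded in this frame
  if (callerAddr == 0)
    goto last;

  // resolve the return address to the module that contains it
  if (getFuncInfo(callerAddr, module) > 0)
  {
    // a .NET runtime module means the real caller is managed code
    BOOL bnet = IsDotNetRuntime((char *)module.c_str());
    if (bnet)
    {
      BOOL bres = GetDotNetCallerFileName(module);
      if (bres == TRUE)
        goto last;
    }
  }

  // step one frame up: [EBP] holds the caller's saved EBP
  if (!ReadProcessMemory(h, (LPCVOID)reg_ebp, (LPVOID)&temp, sizeof(long), &cnt))
    goto last;
  reg_ebp = temp;
  i++;

last:
.......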

If you want to understand how stack tracing works in native code, you should start with Dimitrov's article mentioned above.

Points of interest

The managed function GetDotNetCallerFileName skips all .NET runtime assemblies on the stack, as well as calls made from stackdumper.dll itself. It returns the assembly that actually called your function.

MC++
int GetDotNetCallerFileName(string& module)
{
    int res = FALSE;
    try
    {
        String^ sysdir = 
         System::Runtime::InteropServices::RuntimeEnvironment::GetRuntimeDirectory();
        Assembly^ thisAssembly = Assembly::GetExecutingAssembly();
        String^ strCallerPath = nullptr;

        // Walk the managed frames, innermost first. (GetCallingAssembly is static
        // and always names the immediate caller of this method, so calling it
        // repeatedly cannot climb the stack.)
        //skip all .net framework assemblies and calls from the same assembly
        System::Diagnostics::StackTrace^ trace = gcnew System::Diagnostics::StackTrace();
        for (int i = 0; i < trace->FrameCount; i++)
        {
            MethodBase^ method = trace->GetFrame(i)->GetMethod();
            if (method == nullptr)
                continue;
            Assembly^ callerAssembly = method->Module->Assembly;
            if (callerAssembly == thisAssembly)
                continue;                    // call from the same assembly
            String^ location = callerAssembly->Location;
            if (String::IsNullOrEmpty(location))
                continue;                    // assembly loaded from memory: no file
            String^ directoryName = Path::GetDirectoryName(location) + "\\";
            if (String::Compare(sysdir, directoryName, true) == 0)
                continue;                    // .NET Framework assembly
            strCallerPath = location;        // first foreign assembly: the real caller
            break;
        }
.........

Calling the code

For native callers, everything is straightforward, whether the call is direct or goes through another DLL:

MC++
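// Direct call from test.exe: the caller is test.exe itself
const char *szc = GetCallingModulePath(1);
printf("expected result: test.exe\nfinal result = %s\n\n", szc);

// Indirect call: NativeDllCall in Nativedll.dll calls GetCallingModulePath
szc = NativeDllCall(1);
printf("expected result: Nativedll.dll\nfinal result = %s\n", szc);

// Full module stack trace: pszc receives an array of cnt strings;
// free every string, then the array itself, with CoTaskMemFree
char** pszc = NULL, **original = NULL;
int cnt = GetModuleStackTraceFromNative(pszc);
original = pszc;
printf("\nstack trace---------\n");
while(cnt-- > 0)
{
  printf("Trace: = %s\n", *pszc);
  CoTaskMemFree(*pszc);
  pszc++;
}
CoTaskMemFree(original);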

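There is no need to call CoTaskMemFree after GetCallingModulePath, because the returned pointer refers to a global variable in stackdumper.dll. For the same reason, when calling from managed code (such as C#), declare the return type as IntPtr rather than string. For a string return value, the interop marshaler assumes that the native buffer was allocated with CoTaskMemAlloc and frees it with CoTaskMemFree, which would release memory the DLL still owns. Import the method as an extern, then invoke it as a static method and copy the result with Marshal.PtrToStringAnsi:

C#
[DllImport("stackdumper.dll", CharSet = CharSet.Ansi)]
static extern IntPtr GetCallingModulePath(int arg);

// the buffer stays owned by stackdumper.dll; copy it into a managed string
string caller = Marshal.PtrToStringAnsi(GetCallingModulePath(1));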

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


Written By
Software Developer (Senior)
United States
Decebal Mihailescu is a software engineer with an interest in .NET, C# and C++.

Comments and Discussions

 
Question: Code does not compile
theprogrammer, 15-Oct-14 21:53
Hi,
I'm trying to use the code of this article to get the public key token of the caller assembly from a native C++ DLL, but the project crashes when compiling with VS2013.

Could you post an updated example or provide feedback to fix the solution?

Thanks,
Maurizio
Answer: Re: Code does not compile
dmihailescu, 16-Oct-14 2:14
