
Get the calling module that invoked your function

29 Jan 2007, CPOL
How to get the calling module that invoked your function as an entry point.


Introduction

Sometimes it is necessary to know which module is calling your exported function. You may want to verify that the calling module is one you have certified, or your function may need to return a different result depending on the calling context.

Background

This article builds on Chavdar Dimitrov's article, whose approach works well, but for native code only. His example did not handle .NET callers, and his old DLLs crashed the release build, probably because of tightened Windows security. I therefore reworked the code to cover both cases, at the cost of some reduced functionality.

Digging into the code

The function is written in C++ so that it can be called easily from both managed and unmanaged code. I've provided DLLs and EXEs for the various calling scenarios that can occur. The .NET 2.0 runtime must be installed, because stackdumper.dll is a mixed-mode DLL in which some source files are compiled with the /clr option. External code calls GetCallingModulePath, an entry point exported by stackdumper.dll. For reasons that are not entirely clear, this function must take an argument.

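// x86 only: the x64 compiler does not support inline __asm.
const char* __stdcall GetCallingModulePath(int arg)
{
  // Capture this function's frame pointer. In an x86 EBP frame, [EBP] holds
  // the caller's saved EBP and [EBP + 4] holds the return address.
  long reg_ebd;
  __asm{
    mov eax, ebp
    mov reg_ebd, eax
  }
  ADDR callerAddr;
  unsigned i = 0;
  HANDLE h = GetCurrentProcess();
  module.clear();      // reset the global result (module.empty() only tests it)

  // Return address into the caller (GetCallerAddr is defined elsewhere in stackdumper.dll).
  callerAddr = GetCallerAddr(reg_ebd);
  if (callerAddr == 0)
      goto last;

  // Resolve the module that contains callerAddr. If it is the .NET runtime,
  // the real caller is managed code, so let the managed helper find it.
  if(getFuncInfo(callerAddr,module) > 0)
  {
      BOOL bnet = IsDotNetRuntime((char *)module.c_str());
      if(bnet)
      {
          BOOL bres = GetDotNetCallerFileName(module);
          if(bres == TRUE)
              goto last;
      }
  }

  {   // own scope, so the gotos above do not skip these initializations
      // Step one frame up by reading the saved EBP. ReadProcessMemory fails
      // on an unreadable address instead of raising an access violation.
      long temp = 0;
      SIZE_T cnt;
      long* p = (long*)reg_ebd;
      BOOL bres = ReadProcessMemory(h,(LPCVOID)p,(LPVOID)&temp,sizeof(long),&cnt);
      reg_ebd = temp;
      i++;
  }

last:
  // ... (rest of the function omitted)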
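The listing relies on IsDotNetRuntime, which is not shown. One plausible way to implement such a check, assumed here rather than taken from stackdumper.dll, is to compare the file name of the module that owns the return address against the CLR execution-engine DLL names (assumed list: mscorwks.dll for .NET 2.0, clr.dll for .NET 4.0 and later). FileNameOf, EqualsNoCase and IsDotNetRuntimeSketch are hypothetical names used only for this portable sketch:

```cpp
#include <cassert>
#include <cctype>
#include <string>

// Returns the file-name part of a Windows-style path ("C:\\dir\\x.dll" -> "x.dll").
static std::string FileNameOf(const std::string& path)
{
    const std::string::size_type pos = path.find_last_of("\\/");
    return pos == std::string::npos ? path : path.substr(pos + 1);
}

// ASCII case-insensitive equality; Windows file names are case-insensitive.
static bool EqualsNoCase(const std::string& a, const std::string& b)
{
    if (a.size() != b.size())
        return false;
    for (std::string::size_type i = 0; i < a.size(); ++i)
        if (std::tolower(static_cast<unsigned char>(a[i])) !=
            std::tolower(static_cast<unsigned char>(b[i])))
            return false;
    return true;
}

// Hypothetical stand-in for IsDotNetRuntime: true when the module is one of
// the assumed CLR execution-engine DLLs.
bool IsDotNetRuntimeSketch(const std::string& modulePath)
{
    static const char* const kRuntimeDlls[] = { "mscorwks.dll", "clr.dll" };
    const std::string name = FileNameOf(modulePath);
    for (const char* dll : kRuntimeDlls)
        if (EqualsNoCase(name, dll))
            return true;
    return false;
}
```

Comparing only the file name keeps the check independent of where the framework is installed.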

If you want to understand how stack tracing works in native code, you should start with Dimitrov's article mentioned above.
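The frame-pointer technique itself can be illustrated without Windows. In x86 code compiled with standard EBP frames, [EBP] holds the caller's saved EBP and [EBP + 4] the return address, so following the saved-EBP links yields one return address per frame. In this portable sketch, process memory is simulated by a map, WalkFrames and ReadDword are hypothetical names, and ReadDword plays the role that ReadProcessMemory plays in GetCallingModulePath (a read that fails instead of faulting):

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

// Simulated 32-bit memory: address -> 4-byte value.
typedef std::map<std::uint32_t, std::uint32_t> FakeMemory;

// Checked read: returns false for an unmapped address instead of crashing.
static bool ReadDword(const FakeMemory& mem, std::uint32_t addr, std::uint32_t& out)
{
    FakeMemory::const_iterator it = mem.find(addr);
    if (it == mem.end())
        return false;
    out = it->second;
    return true;
}

// Follows the saved-EBP chain starting at 'ebp' and collects up to
// 'maxFrames' return addresses, innermost frame first.
std::vector<std::uint32_t> WalkFrames(const FakeMemory& mem, std::uint32_t ebp,
                                      unsigned maxFrames)
{
    std::vector<std::uint32_t> returnAddrs;
    for (unsigned i = 0; i < maxFrames && ebp != 0; ++i)
    {
        std::uint32_t ret = 0, savedEbp = 0;
        if (!ReadDword(mem, ebp + 4, ret) || !ReadDword(mem, ebp, savedEbp))
            break;                 // unreadable frame: stop walking
        returnAddrs.push_back(ret);
        // The stack grows down, so older frames live at higher addresses;
        // a link that does not move up marks the end (or a corrupt chain).
        if (savedEbp <= ebp)
            break;
        ebp = savedEbp;
    }
    return returnAddrs;
}
```

The walk only works while every frame keeps a standard EBP frame; code built with frame-pointer omission (/Oy) breaks the chain.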

Points of interest

The managed function GetDotNetCallerFileName skips every .NET Framework assembly on the call stack, as well as calls from stackdumper.dll itself, and returns the path of the assembly that actually called your function.

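// Managed (/clr) helper: finds the first assembly on the managed call stack
// that is neither part of the .NET Framework nor this assembly.
int GetDotNetCallerFileName(string& module)
{
    int res = FALSE;
    try
    {
        // Framework directory; GetRuntimeDirectory returns it with a trailing backslash
        String^ sysdir = 
         System::Runtime::InteropServices::RuntimeEnvironment::GetRuntimeDirectory();
        Assembly^ self = Assembly::GetExecutingAssembly();
        Assembly^ callerAssembly = nullptr;
        String^ strCallerPath = nullptr;
        String^ directoryName = nullptr;

        // Skip all .NET Framework assemblies and calls from the same assembly.
        // Assembly::GetCallingAssembly() is static: calling it through an
        // instance always returns the same assembly, so it cannot climb the
        // stack. Walk the managed frames with StackTrace instead.
        System::Diagnostics::StackTrace^ trace =
            gcnew System::Diagnostics::StackTrace();
        for(int i = 0; i < trace->FrameCount && callerAssembly == nullptr; i++)
        {
            MethodBase^ method = trace->GetFrame(i)->GetMethod();
            if(method == nullptr)
                continue;
            Assembly^ candidate = method->Module->Assembly;
            if(candidate == self)
                continue;                    // call from this assembly
            String^ path = candidate->Location;
            if(String::IsNullOrEmpty(path))
                continue;                    // assembly loaded from memory
            directoryName = Path::GetDirectoryName( path ) + "\\";
            if(String::Compare(sysdir,directoryName,true) == 0)
                continue;                    // .NET Framework assembly
            callerAssembly = candidate;
            strCallerPath = path;
        }
        if(callerAssembly == nullptr)
            return FALSE;
.........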
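The skip rule applied to each assembly can be modeled with plain path strings. This portable sketch assumes the candidate paths are already ordered innermost caller first; FirstUserCaller, SameNoCase and DirectoryWithSlash are hypothetical names, and the case-insensitive comparison mirrors String::Compare(a, b, true) == 0 in the listing:

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

// ASCII case-insensitive equality, like String::Compare(a, b, true) == 0.
static bool SameNoCase(const std::string& a, const std::string& b)
{
    if (a.size() != b.size())
        return false;
    for (std::string::size_type i = 0; i < a.size(); ++i)
        if (std::tolower(static_cast<unsigned char>(a[i])) !=
            std::tolower(static_cast<unsigned char>(b[i])))
            return false;
    return true;
}

// Directory part including the trailing backslash, like
// Path::GetDirectoryName(path) + "\\" in the listing.
static std::string DirectoryWithSlash(const std::string& path)
{
    const std::string::size_type pos = path.find_last_of('\\');
    return pos == std::string::npos ? std::string() : path.substr(0, pos + 1);
}

// Returns the first path that is neither this assembly nor located in the
// runtime directory; returns an empty string when no such caller exists.
std::string FirstUserCaller(const std::vector<std::string>& callerPaths,
                            const std::string& runtimeDir,
                            const std::string& selfPath)
{
    for (std::vector<std::string>::size_type i = 0; i < callerPaths.size(); ++i)
    {
        const std::string& path = callerPaths[i];
        if (SameNoCase(path, selfPath))
            continue;                                   // call from the same assembly
        if (SameNoCase(DirectoryWithSlash(path), runtimeDir))
            continue;                                   // .NET Framework assembly
        return path;
    }
    return std::string();
}
```

Appending the backslash before comparing, as the listing does, keeps a sibling directory such as ...\v2.0.50727_old\ from matching the runtime directory.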

Calling the code

For native callers, both direct and indirect calls are straightforward:

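  // Direct call: the caller is the executable itself
  const char *szc = GetCallingModulePath(1);
  printf("expected result: test.exe\nfinal result = %s\n\n", szc);

  // Indirect call through Nativedll.dll
  szc = NativeDllCall(1);
  printf("expected result: Nativedll.dll\nfinal result = %s\n", szc);

  // Full module stack trace: pszc receives an array of cnt strings
  char** pszc = NULL, **original = NULL;
  int cnt = GetModuleStackTraceFromNative(pszc);
  original = pszc;          // keep the array start so it can be freed afterwards
  printf("\nstack trace---------\n");
  while(cnt-- > 0)
  {
    printf("Trace: = %s\n", *pszc);
    CoTaskMemFree(*pszc);   // free each string
    pszc++;
  }
  CoTaskMemFree(original);  // then free the array itself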
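The loop above frees every string and then the array with CoTaskMemFree, which suggests that GetModuleStackTraceFromNative allocates both with CoTaskMemAlloc and takes its argument by reference, since pszc is filled in. The portable sketch below illustrates that ownership contract under those assumptions; TaskAlloc and TaskFree are stand-ins for CoTaskMemAlloc and CoTaskMemFree, and MakeFakeTrace and ReleaseTrace are hypothetical names:

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Portable stand-ins for CoTaskMemAlloc/CoTaskMemFree so the sketch builds anywhere.
static void* TaskAlloc(std::size_t n) { return std::malloc(n); }
static void  TaskFree(void* p)        { std::free(p); }

// Hypothetical producer shaped like GetModuleStackTraceFromNative: allocates
// an array of 'count' separately allocated strings, stores it in 'out'
// (passed by reference), and returns the number of entries.
int MakeFakeTrace(char**& out, const char* const* names, int count)
{
    out = static_cast<char**>(TaskAlloc(static_cast<std::size_t>(count) * sizeof(char*)));
    if (out == NULL)
        return 0;
    for (int i = 0; i < count; ++i)
    {
        const std::size_t len = std::strlen(names[i]) + 1;   // include the terminator
        out[i] = static_cast<char*>(TaskAlloc(len));
        if (out[i] != NULL)
            std::memcpy(out[i], names[i], len);
    }
    return count;
}

// Consumer, in the same order as the listing: free each string, then the
// array. Returns the number of strings released.
int ReleaseTrace(char** trace, int count)
{
    int freed = 0;
    for (int i = 0; i < count; ++i, ++freed)
        TaskFree(trace[i]);
    TaskFree(trace);
    return freed;
}
```

Both sides must agree on the allocator: memory from CoTaskMemAlloc must be released with CoTaskMemFree, never with free or delete.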

There is no need to call CoTaskMemFree after GetCallingModulePath, because the returned pointer refers to a global variable in stackdumper.dll. To call the function from managed code such as C#, declare it as an external import:

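  [DllImport("stackdumper.dll", CharSet = CharSet.Ansi)]
  static extern IntPtr GetCallingModulePath(int arg);

and then invoke it as a static method. Declare the return type as IntPtr rather than string: when a P/Invoke signature returns string, the interop marshaler frees the native buffer with CoTaskMemFree after copying it, which must not happen to a pointer into stackdumper.dll's global variable. Copy the text with Marshal.PtrToStringAnsi instead:

  string caller = Marshal.PtrToStringAnsi(GetCallingModulePath(1));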
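Returning a pointer into a module-level buffer, as GetCallingModulePath does, can be sketched in portable C++. GetPathSketch and g_module are hypothetical names; the assumption, based on the listing's use of module.c_str(), is that the result lives in a global std::string owned by the DLL:

```cpp
#include <cassert>
#include <cstring>
#include <string>

// Module-level buffer owned by the library; callers receive a pointer into it
// and must not free it.
static std::string g_module;

// Hypothetical stand-in for GetCallingModulePath: stores the result in the
// global buffer and returns a pointer that stays valid until the next call.
const char* GetPathSketch(const char* computed)
{
    g_module = computed;
    return g_module.c_str();
}
```

The trade-off of this design is that each call may invalidate the pointer returned by the previous one, and concurrent calls from several threads race on the same buffer, so callers should copy the string immediately (Marshal.PtrToStringAnsi does exactly that on the managed side).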

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)


About the Author

dmihailescu
Software Developer (Senior)
United States
Decebal Mihailescu is a software engineer with interest in .Net, C# and C++.


Last Updated 29 Jan 2007
Article Copyright 2007 by dmihailescu