Add proper support for function arguments

This commit is contained in:
Harrison Lambeth 2025-01-26 14:43:51 -07:00
parent 7c65f31f46
commit 4f4605eff9
2 changed files with 76 additions and 37 deletions

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, TypedDict, NotRequired
from typing import Optional, TypedDict, NotRequired, Union
from pycparser import c_ast, parse_file, preprocess_file
from pathlib import Path
import os
@ -9,22 +9,24 @@ import logging
logger = logging.getLogger(__name__)
ExtractedSymbolType = Union[str, "ExtractedFunction"]
class ExtractedStructAttributeUnion(TypedDict):
type: Optional[str]
type: Optional[ExtractedSymbolType]
class ExtractedStructAttribute(TypedDict):
type: Optional[str]
union: Optional[dict[str, Optional[str]]]
type: NotRequired[ExtractedSymbolType]
union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
class ExtractedStruct(TypedDict):
attrs: dict[str, ExtractedStructAttribute]
is_union: NotRequired[bool]
ExtractedEnum = dict[str, Optional[str]]
ExtractedFunctionParam = tuple[str, Optional[str]]
ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
class ExtractedFunction(TypedDict):
return_type: Optional[str]
return_type: Optional["ExtractedSymbolType"]
params: list[ExtractedFunctionParam]
@dataclass
@ -33,11 +35,21 @@ class ExtractedSymbols:
enums: dict[str, ExtractedEnum]
functions: dict[str, ExtractedFunction]
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[str]:
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
prefix = " ".join(node.quals) + " " + prefix
if isinstance(node, c_ast.PtrDecl):
prefix = "*" + prefix
if isinstance(node, c_ast.FuncDecl):
func: ExtractedFunction = {
'return_type': get_type_names(node.type),
'params': [],
}
for param in node.args.params:
if param.name is None:
continue
func['params'].append((param.name, get_type_names(param)))
return func
if hasattr(node, 'names'):
return prefix + node.names[0] # type: ignore
@ -62,7 +74,7 @@ class Visitor(c_ast.NodeVisitor):
if hasattr(node_type, "declname"):
return_type = get_type_names(node_type.type)
if return_type is not None and is_pointer:
if return_type is not None and isinstance(return_type, str) and is_pointer:
return_type = "*" + return_type
func: ExtractedFunction = {
'return_type': return_type,
@ -91,6 +103,8 @@ class Visitor(c_ast.NodeVisitor):
def visit_Typedef(self, node: c_ast.Typedef):
# node.show()
if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
if node.name == "Clay_ErrorHandler":
logger.debug(node)
struct = {}
for decl in node.type.type.decls:
if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):