mirror of
				https://github.com/nicbarker/clay.git
				synced 2025-11-03 16:16:18 +00:00 
			
		
		
		
	Add proper support for function arguments
This commit is contained in:
		
							parent
							
								
									7c65f31f46
								
							
						
					
					
						commit
						4f4605eff9
					
				| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from parser import ExtractedSymbolType
 | 
				
			||||||
from generators.base_generator import BaseGenerator
 | 
					from generators.base_generator import BaseGenerator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
| 
						 | 
					@ -61,23 +62,12 @@ TYPE_MAPPING = {
 | 
				
			||||||
    'int32_t': 'c.int32_t',
 | 
					    'int32_t': 'c.int32_t',
 | 
				
			||||||
    'uintptr_t': 'rawptr',
 | 
					    'uintptr_t': 'rawptr',
 | 
				
			||||||
    'void': 'void',
 | 
					    'void': 'void',
 | 
				
			||||||
 | 
					 | 
				
			||||||
    '*Clay_RectangleElementConfig': '^RectangleElementConfig',
 | 
					 | 
				
			||||||
    '*Clay_TextElementConfig': '^TextElementConfig',
 | 
					 | 
				
			||||||
    '*Clay_ImageElementConfig': '^ImageElementConfig',
 | 
					 | 
				
			||||||
    '*Clay_FloatingElementConfig': '^FloatingElementConfig',
 | 
					 | 
				
			||||||
    '*Clay_CustomElementConfig': '^CustomElementConfig',
 | 
					 | 
				
			||||||
    '*Clay_ScrollElementConfig': '^ScrollElementConfig',
 | 
					 | 
				
			||||||
    '*Clay_BorderElementConfig': '^BorderElementConfig',
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
STRUCT_TYPE_OVERRIDES = {
 | 
					STRUCT_TYPE_OVERRIDES = {
 | 
				
			||||||
    'Clay_Arena': {
 | 
					    'Clay_Arena': {
 | 
				
			||||||
        'nextAllocation': 'uintptr',
 | 
					        'nextAllocation': 'uintptr',
 | 
				
			||||||
        'capacity': 'uintptr',
 | 
					        'capacity': 'uintptr',
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    'Clay_ErrorHandler': {
 | 
					 | 
				
			||||||
        'errorHandlerFunction': 'proc "c" (errorData: ErrorData)',
 | 
					 | 
				
			||||||
    },
 | 
					 | 
				
			||||||
    'Clay_SizingAxis': {
 | 
					    'Clay_SizingAxis': {
 | 
				
			||||||
        'size': 'SizingConstraints',
 | 
					        'size': 'SizingConstraints',
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
| 
						 | 
					@ -108,7 +98,6 @@ FUNCTION_TYPE_OVERRIDES = {
 | 
				
			||||||
        'offset': '[^]u8',
 | 
					        'offset': '[^]u8',
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    'Clay_SetMeasureTextFunction': {
 | 
					    'Clay_SetMeasureTextFunction': {
 | 
				
			||||||
        'measureTextFunction': 'proc "c" (text: ^StringSlice, config: ^TextElementConfig, userData: uintptr) -> Dimensions',
 | 
					 | 
				
			||||||
        'userData': 'uintptr',
 | 
					        'userData': 'uintptr',
 | 
				
			||||||
    },
 | 
					    },
 | 
				
			||||||
    'Clay_RenderCommandArray_Get': {
 | 
					    'Clay_RenderCommandArray_Get': {
 | 
				
			||||||
| 
						 | 
					@ -155,7 +144,20 @@ class OdinGenerator(BaseGenerator):
 | 
				
			||||||
            return base_name
 | 
					            return base_name
 | 
				
			||||||
        raise ValueError(f'Unknown symbol: {symbol}')
 | 
					        raise ValueError(f'Unknown symbol: {symbol}')
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def resolve_binding_type(self, symbol: str, member: str | None, member_type: str | None, type_overrides: dict[str, dict[str, str]]) -> str | None:
 | 
					    def format_type(self, type: ExtractedSymbolType) -> str:
 | 
				
			||||||
 | 
					        if isinstance(type, str):
 | 
				
			||||||
 | 
					            return type
 | 
				
			||||||
 | 
					        
 | 
				
			||||||
 | 
					        parameter_strs = []
 | 
				
			||||||
 | 
					        for param_name, param_type in type['params']:
 | 
				
			||||||
 | 
					            parameter_strs.append(f"{param_name}: {self.format_type(param_type or 'unknown')}")
 | 
				
			||||||
 | 
					        return_type_str = ''
 | 
				
			||||||
 | 
					        if type['return_type'] is not None and type['return_type'] != 'void':
 | 
				
			||||||
 | 
					            return_type_str = ' -> ' + self.format_type(type['return_type'])
 | 
				
			||||||
 | 
					        return f"proc \"c\" ({', '.join(parameter_strs)}){return_type_str}"
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    def resolve_binding_type(self, symbol: str, member: str | None, member_type: ExtractedSymbolType | None, type_overrides: dict[str, dict[str, str]]) -> ExtractedSymbolType | None:
 | 
				
			||||||
 | 
					        if isinstance(member_type, str):
 | 
				
			||||||
            if member_type in SYMBOL_COMPLETE_OVERRIDES:
 | 
					            if member_type in SYMBOL_COMPLETE_OVERRIDES:
 | 
				
			||||||
                return SYMBOL_COMPLETE_OVERRIDES[member_type]
 | 
					                return SYMBOL_COMPLETE_OVERRIDES[member_type]
 | 
				
			||||||
            if symbol in type_overrides and member in type_overrides[symbol]:
 | 
					            if symbol in type_overrides and member in type_overrides[symbol]:
 | 
				
			||||||
| 
						 | 
					@ -169,6 +171,22 @@ class OdinGenerator(BaseGenerator):
 | 
				
			||||||
                if result:
 | 
					                if result:
 | 
				
			||||||
                    return f"^{result}"
 | 
					                    return f"^{result}"
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
 | 
					        if member_type is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        resolved_parameters = []
 | 
				
			||||||
 | 
					        for param_name, param_type in member_type['params']:
 | 
				
			||||||
 | 
					            resolved_param = self.resolve_binding_type(symbol, param_name, param_type, type_overrides)
 | 
				
			||||||
 | 
					            if resolved_param is None:
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					            resolved_parameters.append((param_name, resolved_param))
 | 
				
			||||||
 | 
					        resolved_return_type = self.resolve_binding_type(symbol, None, member_type['return_type'], type_overrides)
 | 
				
			||||||
 | 
					        if resolved_return_type is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					        return {
 | 
				
			||||||
 | 
					            "params": resolved_parameters,
 | 
				
			||||||
 | 
					            "return_type": resolved_return_type,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def generate_structs(self) -> None:
 | 
					    def generate_structs(self) -> None:
 | 
				
			||||||
        for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]):
 | 
					        for struct, struct_data in sorted(self.extracted_symbols.structs.items(), key=lambda x: x[0]):
 | 
				
			||||||
| 
						 | 
					@ -184,12 +202,15 @@ class OdinGenerator(BaseGenerator):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY:
 | 
					            if struct in STRUCT_OVERRIDE_AS_FIXED_ARRAY:
 | 
				
			||||||
                array_size = len(members)
 | 
					                array_size = len(members)
 | 
				
			||||||
                array_type = list(members.values())[0]['type']
 | 
					                first_elem = list(members.values())[0]
 | 
				
			||||||
 | 
					                array_type = None
 | 
				
			||||||
 | 
					                if 'type' in first_elem:
 | 
				
			||||||
 | 
					                   array_type = first_elem['type']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if array_type in TYPE_MAPPING:
 | 
					                if array_type in TYPE_MAPPING:
 | 
				
			||||||
                    array_binding_type = TYPE_MAPPING[array_type]
 | 
					                    array_binding_type = TYPE_MAPPING[array_type]
 | 
				
			||||||
                elif array_type and self.has_symbol(array_type):
 | 
					                elif array_type and self.has_symbol(self.format_type(array_type)):
 | 
				
			||||||
                    array_binding_type = self.get_symbol_name(array_type)
 | 
					                    array_binding_type = self.get_symbol_name(self.format_type(array_type))
 | 
				
			||||||
                else:
 | 
					                else:
 | 
				
			||||||
                    self._write('struct', f"// {struct} ({array_type}) - has no mapping")
 | 
					                    self._write('struct', f"// {struct} ({array_type}) - has no mapping")
 | 
				
			||||||
                    continue
 | 
					                    continue
 | 
				
			||||||
| 
						 | 
					@ -221,6 +242,7 @@ class OdinGenerator(BaseGenerator):
 | 
				
			||||||
                if member_binding_type is None:
 | 
					                if member_binding_type is None:
 | 
				
			||||||
                    self._write('struct', f"    // {binding_member_name} ({member_type}) - has no mapping")
 | 
					                    self._write('struct', f"    // {binding_member_name} ({member_type}) - has no mapping")
 | 
				
			||||||
                    continue
 | 
					                    continue
 | 
				
			||||||
 | 
					                member_binding_type = self.format_type(member_binding_type)
 | 
				
			||||||
                self._write('struct', f"    {binding_member_name}: {member_binding_type}, // {member} ({member_type})")
 | 
					                self._write('struct', f"    {binding_member_name}: {member_binding_type}, // {member} ({member_type})")
 | 
				
			||||||
            self._write('struct', "}")
 | 
					            self._write('struct', "}")
 | 
				
			||||||
            self._write('struct', '')
 | 
					            self._write('struct', '')
 | 
				
			||||||
| 
						 | 
					@ -272,6 +294,7 @@ class OdinGenerator(BaseGenerator):
 | 
				
			||||||
            if binding_return_type is None:
 | 
					            if binding_return_type is None:
 | 
				
			||||||
                self._write(write_to, f"    // {function} ({return_type}) - has no mapping")
 | 
					                self._write(write_to, f"    // {function} ({return_type}) - has no mapping")
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
 | 
					            binding_return_type = self.format_type(binding_return_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            skip = False
 | 
					            skip = False
 | 
				
			||||||
            binding_params = []
 | 
					            binding_params = []
 | 
				
			||||||
| 
						 | 
					@ -282,6 +305,8 @@ class OdinGenerator(BaseGenerator):
 | 
				
			||||||
                binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES)
 | 
					                binding_param_type = self.resolve_binding_type(function, param_name, param_type, FUNCTION_TYPE_OVERRIDES)
 | 
				
			||||||
                if binding_param_type is None:
 | 
					                if binding_param_type is None:
 | 
				
			||||||
                    skip = True
 | 
					                    skip = True
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    binding_param_type = self.format_type(binding_param_type)
 | 
				
			||||||
                binding_params.append(f"{binding_param_name}: {binding_param_type}")
 | 
					                binding_params.append(f"{binding_param_name}: {binding_param_type}")
 | 
				
			||||||
            if skip:
 | 
					            if skip:
 | 
				
			||||||
                self._write(write_to, f"    // {function} - has no mapping")
 | 
					                self._write(write_to, f"    // {function} - has no mapping")
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,5 +1,5 @@
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
from typing import Optional, TypedDict, NotRequired
 | 
					from typing import Optional, TypedDict, NotRequired, Union
 | 
				
			||||||
from pycparser import c_ast, parse_file, preprocess_file
 | 
					from pycparser import c_ast, parse_file, preprocess_file
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
| 
						 | 
					@ -9,22 +9,24 @@ import logging
 | 
				
			||||||
 | 
					
 | 
				
			||||||
logger = logging.getLogger(__name__)
 | 
					logger = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					ExtractedSymbolType = Union[str, "ExtractedFunction"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ExtractedStructAttributeUnion(TypedDict):
 | 
					class ExtractedStructAttributeUnion(TypedDict):
 | 
				
			||||||
    type: Optional[str]
 | 
					    type: Optional[ExtractedSymbolType]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ExtractedStructAttribute(TypedDict):
 | 
					class ExtractedStructAttribute(TypedDict):
 | 
				
			||||||
    type: Optional[str]
 | 
					    type: NotRequired[ExtractedSymbolType]
 | 
				
			||||||
    union: Optional[dict[str, Optional[str]]]
 | 
					    union: NotRequired[dict[str, Optional[ExtractedSymbolType]]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ExtractedStruct(TypedDict):
 | 
					class ExtractedStruct(TypedDict):
 | 
				
			||||||
    attrs: dict[str, ExtractedStructAttribute]
 | 
					    attrs: dict[str, ExtractedStructAttribute]
 | 
				
			||||||
    is_union: NotRequired[bool]
 | 
					    is_union: NotRequired[bool]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ExtractedEnum = dict[str, Optional[str]]
 | 
					ExtractedEnum = dict[str, Optional[str]]
 | 
				
			||||||
ExtractedFunctionParam = tuple[str, Optional[str]]
 | 
					ExtractedFunctionParam = tuple[str, Optional[ExtractedSymbolType]]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ExtractedFunction(TypedDict):
 | 
					class ExtractedFunction(TypedDict):
 | 
				
			||||||
    return_type: Optional[str]
 | 
					    return_type: Optional["ExtractedSymbolType"]
 | 
				
			||||||
    params: list[ExtractedFunctionParam]
 | 
					    params: list[ExtractedFunctionParam]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
| 
						 | 
					@ -33,11 +35,21 @@ class ExtractedSymbols:
 | 
				
			||||||
    enums: dict[str, ExtractedEnum]
 | 
					    enums: dict[str, ExtractedEnum]
 | 
				
			||||||
    functions: dict[str, ExtractedFunction]
 | 
					    functions: dict[str, ExtractedFunction]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[str]:
 | 
					def get_type_names(node: c_ast.Node, prefix: str="") -> Optional[ExtractedSymbolType]:
 | 
				
			||||||
    if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
 | 
					    if isinstance(node, c_ast.TypeDecl) and hasattr(node, 'quals') and node.quals:
 | 
				
			||||||
        prefix = " ".join(node.quals) + " " + prefix
 | 
					        prefix = " ".join(node.quals) + " " + prefix
 | 
				
			||||||
    if isinstance(node, c_ast.PtrDecl):
 | 
					    if isinstance(node, c_ast.PtrDecl):
 | 
				
			||||||
        prefix = "*" + prefix
 | 
					        prefix = "*" + prefix
 | 
				
			||||||
 | 
					    if isinstance(node, c_ast.FuncDecl):
 | 
				
			||||||
 | 
					        func: ExtractedFunction = {
 | 
				
			||||||
 | 
					            'return_type': get_type_names(node.type),
 | 
				
			||||||
 | 
					            'params': [],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        for param in node.args.params:
 | 
				
			||||||
 | 
					            if param.name is None:
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            func['params'].append((param.name, get_type_names(param)))
 | 
				
			||||||
 | 
					        return func
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if hasattr(node, 'names'):
 | 
					    if hasattr(node, 'names'):
 | 
				
			||||||
        return prefix + node.names[0] # type: ignore
 | 
					        return prefix + node.names[0] # type: ignore
 | 
				
			||||||
| 
						 | 
					@ -62,7 +74,7 @@ class Visitor(c_ast.NodeVisitor):
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        if hasattr(node_type, "declname"):
 | 
					        if hasattr(node_type, "declname"):
 | 
				
			||||||
            return_type = get_type_names(node_type.type)
 | 
					            return_type = get_type_names(node_type.type)
 | 
				
			||||||
            if return_type is not None and is_pointer:
 | 
					            if return_type is not None and isinstance(return_type, str) and is_pointer:
 | 
				
			||||||
                return_type = "*" + return_type
 | 
					                return_type = "*" + return_type
 | 
				
			||||||
            func: ExtractedFunction = {
 | 
					            func: ExtractedFunction = {
 | 
				
			||||||
                'return_type': return_type,
 | 
					                'return_type': return_type,
 | 
				
			||||||
| 
						 | 
					@ -91,6 +103,8 @@ class Visitor(c_ast.NodeVisitor):
 | 
				
			||||||
    def visit_Typedef(self, node: c_ast.Typedef):
 | 
					    def visit_Typedef(self, node: c_ast.Typedef):
 | 
				
			||||||
        # node.show()
 | 
					        # node.show()
 | 
				
			||||||
        if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
 | 
					        if hasattr(node.type, 'type') and hasattr(node.type.type, 'decls') and node.type.type.decls:
 | 
				
			||||||
 | 
					            if node.name == "Clay_ErrorHandler":
 | 
				
			||||||
 | 
					                logger.debug(node)
 | 
				
			||||||
            struct = {}
 | 
					            struct = {}
 | 
				
			||||||
            for decl in node.type.type.decls:
 | 
					            for decl in node.type.type.decls:
 | 
				
			||||||
                if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):
 | 
					                if hasattr(decl, 'type') and hasattr(decl.type, 'type') and isinstance(decl.type.type, c_ast.Union):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in a new issue