From 1260c86f603b07a468ad3ca19b575a118ea3f41b Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Mon, 29 Apr 2024 19:33:18 +0200 Subject: [PATCH] generate proxying wrappers Adds InstanceProxy, DeviceProxy, QueueProxy, CommandBufferProxy. These act similar to the wrapper types, except that they store and implicitly pass their respective handles. --- src/vulkan/render.zig | 233 +++++++++++++++++++++++++++++++++++------- 1 file changed, 195 insertions(+), 38 deletions(-) diff --git a/src/vulkan/render.zig b/src/vulkan/render.zig index 8190621..00d974a 100644 --- a/src/vulkan/render.zig +++ b/src/vulkan/render.zig @@ -144,6 +144,40 @@ const foreign_types = std.StaticStringMap([]const u8).initComptime(.{ .{ "IDirectFBSurface", "opaque {}" }, }); +const CommandDispatchType = enum { + base, + instance, + device, + + fn name(self: CommandDispatchType) []const u8 { + return switch (self) { + .base => "Base", + .instance => "Instance", + .device => "Device", + }; + } + + fn nameLower(self: CommandDispatchType) []const u8 { + return switch (self) { + .base => "base", + .instance => "instance", + .device => "device", + }; + } +}; + +const dispatchable_handles = std.StaticStringMap(CommandDispatchType).initComptime(.{ + .{ "VkDevice", .device }, + .{ "VkCommandBuffer", .device }, + .{ "VkQueue", .device }, + .{ "VkInstance", .instance }, +}); + +const dispatch_override_functions = std.StaticStringMap(CommandDispatchType).initComptime(.{ + .{ "vkGetInstanceProcAddr", .base }, + .{ "vkGetDeviceProcAddr", .instance }, +}); + fn eqlIgnoreCase(lhs: []const u8, rhs: []const u8) bool { if (lhs.len != rhs.len) { return false; @@ -194,6 +228,7 @@ fn Renderer(comptime WriterType: type) type { bitflags, mut_buffer_len, buffer_len, + dispatch_handle, other, }; @@ -206,12 +241,6 @@ fn Renderer(comptime WriterType: type) type { }, }; - const CommandDispatchType = enum { - base, - instance, - device, - }; - writer: WriterType, allocator: Allocator, registry: *const reg.Registry, @@ -417,6 +446,10 @@ fn Renderer(comptime WriterType: type) type { } }, .name => |name| { + if (dispatchable_handles.get(name) != null) { + return .dispatch_handle; + } + if ((try self.extractBitflagName(name)) != null or self.isFlags(name)) { return .bitflags; } @@ -432,35 +465,20 @@ fn Renderer(comptime WriterType: type) type { } fn classifyCommandDispatch(name: []const u8, command: reg.Command) CommandDispatchType { - const device_handles = std.StaticStringMap(void).initComptime(.{ - .{ "VkDevice", {} }, - .{ "VkCommandBuffer", {} }, - .{ "VkQueue", {} }, - }); - - const override_functions = std.StaticStringMap(CommandDispatchType).initComptime(.{ - .{ "vkGetInstanceProcAddr", .base }, - .{ "vkCreateInstance", .base }, - .{ "vkEnumerateInstanceLayerProperties", .base }, - .{ "vkEnumerateInstanceExtensionProperties", .base }, - .{ "vkEnumerateInstanceVersion", .base }, - .{ "vkGetDeviceProcAddr", .instance }, - }); - - if (override_functions.get(name)) |dispatch_type| { + if (dispatch_override_functions.get(name)) |dispatch_type| { return dispatch_type; } switch (command.params[0].param_type) { .name => |first_param_type_name| { - if (device_handles.get(first_param_type_name)) |_| { - return .device; + if (dispatchable_handles.get(first_param_type_name)) |dispatch_type| { + return dispatch_type; } }, else => {}, } - return .instance; + return .base; } fn render(self: *Self) !void { @@ -478,6 +496,7 @@ fn Renderer(comptime WriterType: type) type { try self.renderFeatureInfo(); try self.renderExtensionInfo(); try self.renderWrappers(); + try self.renderProxies(); } fn renderApiConstant(self: *Self, api_constant: reg.ApiConstant) !void { @@ -1109,6 +1128,8 @@ fn Renderer(comptime WriterType: type) type { \\ ); // The commands in an extension are not pre-sorted based on if they are instance or device functions. + var base_commands = std.BufSet.init(self.allocator); + defer base_commands.deinit(); var instance_commands = std.BufSet.init(self.allocator); defer instance_commands.deinit(); var device_commands = std.BufSet.init(self.allocator); @@ -1131,8 +1152,9 @@ fn Renderer(comptime WriterType: type) type { }; const class = classifyCommandDispatch(command_name, command); switch (class) { - // Vulkan extensions cannot add base functions. - .base => return error.InvalidRegistry, + .base => { + try base_commands.insert(command_name); + }, .instance => { try instance_commands.insert(command_name); }, @@ -1143,6 +1165,10 @@ fn Renderer(comptime WriterType: type) type { } } // and write them out + try self.writer.writeAll(".base_commands = "); + try self.renderCommandFlags(&base_commands); + base_commands.hash_map.clearRetainingCapacity(); + try self.writer.writeAll(".instance_commands = "); try self.renderCommandFlags(&instance_commands); instance_commands.hash_map.clearRetainingCapacity(); @@ -1223,11 +1249,8 @@ fn Renderer(comptime WriterType: type) type { } fn renderWrappersOfDispatchType(self: *Self, dispatch_type: CommandDispatchType) !void { - const name, const name_lower = switch (dispatch_type) { - .base => .{ "Base", "base" }, - .instance => .{ "Instance", "instance" }, - .device => .{ "Device", "device" }, - }; + const name = dispatch_type.name(); + const name_lower = dispatch_type.nameLower(); try self.writer.print( \\pub const {0s}CommandFlags = packed struct {{ @@ -1264,7 +1287,7 @@ fn Renderer(comptime WriterType: type) type { }; if (classifyCommandDispatch(decl.name, command) == dispatch_type) { - try self.writer.writeAll((" " ** 8) ++ "."); + try self.writer.writeByte('.'); try self.writeIdentifierWithCase(.camel, trimVkNamespace(decl.name)); try self.writer.writeAll(" => "); try self.renderCommandPtrName(decl.name); @@ -1288,7 +1311,7 @@ fn Renderer(comptime WriterType: type) type { }; if (classifyCommandDispatch(decl.name, command) == dispatch_type) { - try self.writer.writeAll((" " ** 8) ++ "."); + try self.writer.writeByte('.'); try self.writeIdentifierWithCase(.camel, trimVkNamespace(decl.name)); try self.writer.print( \\ => "{s}", @@ -1416,6 +1439,123 @@ fn Renderer(comptime WriterType: type) type { , .{ .params = params, .first_arg = loader_first_arg }); } + fn renderProxies(self: *Self) !void { + try self.renderProxy(.instance, "VkInstance", null); + try self.renderProxy(.device, "VkDevice", null); + try self.renderProxy(.device, "VkCommandBuffer", "VkDevice"); + try self.renderProxy(.device, "VkQueue", "VkDevice"); + } + + fn renderProxy( + self: *Self, + dispatch_type: CommandDispatchType, + dispatch_handle: []const u8, + maybe_parent_dispatch_handle: ?[]const u8, + ) !void { + const loader_name = dispatch_type.name(); + + try self.writer.print( + \\pub fn {0s}Proxy(comptime apis: []const ApiInfo) type {{ + \\ return struct {{ + \\ const Self = @This(); + \\ pub const Wrapper = {1s}Wrapper(apis); + \\ + \\ handle: {0s}, + \\ wrapper: *const Wrapper, + \\ + \\ pub fn init(handle: {0s}, wrapper: *const Wrapper) error{{CommandLoadFailure}}!Self {{ + \\ return .{{ + \\ .handle = handle, + \\ .wrapper = wrapper, + \\ }}; + \\ }} + , .{ trimVkNamespace(dispatch_handle), loader_name }); + + for (self.registry.decls) |decl| { + const decl_type = self.resolveAlias(decl.decl_type) catch continue; + const command = switch (decl_type) { + .command => |cmd| cmd, + else => continue, + }; + + if (classifyCommandDispatch(decl.name, command) != dispatch_type) { + continue; + } + + switch (command.params[0].param_type) { + .name => |name| { + if (!mem.eql(u8, name, dispatch_handle)) { + // Also render queue/cmdBuf functions in the proxy of the device, for conveniece + if (maybe_parent_dispatch_handle) |parent_dispatch_handle| { + if (!mem.eql(u8, name, parent_dispatch_handle)) { + continue; + } + } else { + continue; + } + } + }, + else => continue, // Not a dispatchable handle + } + + try self.renderProxyCommand(decl.name, command); + } + + try self.writer.writeAll( + \\ }; + \\} + ); + } + + fn renderProxyCommand(self: *Self, name: []const u8, command: reg.Command) !void { + const returns_vk_result = command.return_type.* == .name and mem.eql(u8, command.return_type.name, "VkResult"); + const returns = try self.extractReturns(command); + + if (returns_vk_result) { + try self.writer.writeAll("pub const "); + try self.renderErrorSetName(name); + try self.writer.writeAll(" = Wrapper."); + try self.renderErrorSetName(name); + try self.writer.writeAll(";\n"); + } + + if (returns.len > 1) { + try self.writer.writeAll("pub const "); + try self.renderReturnStructName(name); + try self.writer.writeAll(" = Wrapper."); + try self.renderReturnStructName(name); + try self.writer.writeAll(";\n"); + } + + try self.renderWrapperPrototype(name, command, returns, .proxy); + + try self.writer.writeAll( + \\{ + \\return self.wrapper. + ); + try self.writeIdentifierWithCase(.camel, trimVkNamespace(name)); + try self.writer.writeByte('('); + + for (command.params) |param| { + switch (try self.classifyParam(param)) { + .out_pointer => continue, + .dispatch_handle => { + try self.writer.writeAll("self.handle"); + }, + else => { + try self.writeIdentifierWithCase(.snake, param.name); + }, + } + try self.writer.writeAll(", "); + } + + try self.writer.writeAll( + \\); + \\} + \\ + ); + } + fn derefName(name: []const u8) []const u8 { var it = id_render.SegmentIterator.init(name); return if (mem.eql(u8, it.next().?, "p")) @@ -1424,14 +1564,31 @@ fn Renderer(comptime WriterType: type) type { name; } - fn renderWrapperPrototype(self: *Self, name: []const u8, command: reg.Command, returns: []const ReturnValue) !void { + const WrapperKind = enum { + wrapper, + proxy, + }; + + fn renderWrapperPrototype( + self: *Self, + name: []const u8, + command: reg.Command, + returns: []const ReturnValue, + kind: WrapperKind, + ) !void { try self.writer.writeAll("pub fn "); try self.writeIdentifierWithCase(.camel, trimVkNamespace(name)); try self.writer.writeAll("(self: Self, "); for (command.params) |param| { + const class = try self.classifyParam(param); + // Skip the dispatch type for proxying wrappers + if (kind == .proxy and class == .dispatch_handle) { + continue; + } + // This parameter is returned instead. - if ((try self.classifyParam(param)) == .out_pointer) { + if (class == .out_pointer) { continue; } @@ -1479,7 +1636,7 @@ fn Renderer(comptime WriterType: type) type { try self.writeIdentifierWithCase(.snake, derefName(param.name)); } }, - .bitflags, .in_pointer, .in_out_pointer, .buffer_len, .mut_buffer_len, .other => { + else => { try self.writeIdentifierWithCase(.snake, param.name); }, } @@ -1570,7 +1727,7 @@ fn Renderer(comptime WriterType: type) type { try self.writer.writeAll(";\n"); } - try self.renderWrapperPrototype(name, command, returns); + try self.renderWrapperPrototype(name, command, returns, .wrapper); if (returns.len == 1 and returns[0].origin == .inner_return_value) { try self.writer.writeAll("{\n\n");