From e297865e93abd0603dc82d6762350dc31068fc37 Mon Sep 17 00:00:00 2001 From: David Allemang Date: Wed, 10 Jul 2024 12:25:39 -0400 Subject: [PATCH] better hook type --- src/nu/hooks.zig | 174 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 src/nu/hooks.zig diff --git a/src/nu/hooks.zig b/src/nu/hooks.zig new file mode 100644 index 0000000..79debaf --- /dev/null +++ b/src/nu/hooks.zig @@ -0,0 +1,174 @@ +const std = @import("std"); + +pub fn Hook(ftype: type) type { + const F: std.builtin.Type.Fn = @typeInfo(ftype).Fn; + const Result: type = F.return_type.?; + + return struct { + const Self = @This(); + + handlers: std.AutoArrayHashMap(*const ftype, void), + + pub fn init(alloc: std.mem.Allocator) Self { + return Self{ + .handlers = std.AutoArrayHashMap(*const ftype, void).init(alloc), + }; + } + + pub fn deinit(self: *Self) void { + self.handlers.deinit(); + } + + pub fn register(self: *Self, f: ftype) !void { + try self.handlers.putNoClobber(f, {}); + } + + pub fn unregister(self: *Self, f: ftype) void { + _ = self.handlers.orderedRemove(f); + } + + fn invoke_alloc_results(self: Self, alloc: std.mem.Allocator, args: anytype) ![]Result { + const results = try alloc.alloc(Result, self.handlers.count()); + for (self.handlers.keys(), results) |handler, *result| { + result.* = @call(.auto, handler, args); + } + return results; + } + + fn invoke_void(self: Self, args: anytype) void { + for (self.handlers.keys()) |handler| { + @call(.auto, handler, args); + } + } + + pub const invoke = switch (@typeInfo(Result)) { + .Void => invoke_void, + else => invoke_alloc_results, + }; + }; +} + +test "void hooks" { + const hooks = struct { + pub fn set_one(f: *usize) void { + f.* |= 0b01; + } + + pub fn set_two(f: *usize) void { + f.* |= 0b10; + } + }; + + var set_flags = Hook(fn (*usize) void).init(std.testing.allocator); + defer set_flags.deinit(); + + var flag: usize = undefined; + + flag = 0; + set_flags.invoke(.{&flag}); + try std.testing.expect(flag == 0b00); + + try set_flags.register(hooks.set_one); + + flag = 0; + set_flags.invoke(.{&flag}); + try std.testing.expect(flag == 0b01); + + try set_flags.register(hooks.set_two); + + flag = 0; + set_flags.invoke(.{&flag}); + try std.testing.expect(flag == 0b11); + + set_flags.unregister(hooks.set_one); + + flag = 0; + set_flags.invoke(.{&flag}); + try std.testing.expect(flag == 0b10); + + set_flags.unregister(hooks.set_two); + + flag = 0; + set_flags.invoke(.{&flag}); + try std.testing.expect(flag == 0b00); +} + +test "collect hooks" { + const hooks = struct { + pub fn double(f: usize) usize { + return f * 2; + } + + pub fn square(f: usize) usize { + return f * f; + } + }; + + var collect = Hook(fn (usize) usize).init(std.testing.allocator); + defer collect.deinit(); + + { + const result = try collect.invoke(std.testing.allocator, .{3}); + defer std.testing.allocator.free(result); + try std.testing.expectEqualSlices(usize, &.{}, result); + } + + try collect.register(hooks.double); + + { + const result = try collect.invoke(std.testing.allocator, .{4}); + defer std.testing.allocator.free(result); + try std.testing.expectEqualSlices(usize, &.{8}, result); + } + + try collect.register(hooks.square); + + { + const result = try collect.invoke(std.testing.allocator, .{5}); + defer std.testing.allocator.free(result); + try std.testing.expectEqualSlices(usize, &.{ 10, 25 }, result); + } + + collect.unregister(hooks.double); + + { + const result = try collect.invoke(std.testing.allocator, .{6}); + defer std.testing.allocator.free(result); + try std.testing.expectEqualSlices(usize, &.{36}, result); + } + + collect.unregister(hooks.square); + + { + const result = try collect.invoke(std.testing.allocator, .{7}); + defer std.testing.allocator.free(result); + try std.testing.expectEqualSlices(usize, &.{}, result); + } +} + +test "error_hooks" { + const CollectError = error{Fail}; + const Collect = Hook(fn (usize) CollectError!usize); + var collect = Collect.init(std.testing.allocator); + defer collect.deinit(); + + const hooks = struct { + pub fn halve(f: usize) !usize { + if (f % 2 == 0) return f / 2; + return CollectError.Fail; + } + + pub fn third(f: usize) !usize { + if (f % 3 == 0) return f / 3; + return CollectError.Fail; + } + }; + + try collect.register(hooks.halve); + try collect.register(hooks.third); + + const result = try collect.invoke(std.testing.allocator, .{4}); + defer std.testing.allocator.free(result); + try std.testing.expectEqual(2, try result[0]); + try std.testing.expectError(CollectError.Fail, result[1]); +}