Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- const std = @import("std");
/// Generic leaf-node type for a graph whose element type is `T`.
///
/// The returned struct carries exactly one field, `value`, and no methods.
pub fn Constant(comptime T: type) type {
    return struct {
        /// The payload stored in this leaf.
        value: T,
    };
}
/// Generic binary-operation node type for a graph whose element type is `T`.
///
/// The node only records its two operands as `Tensor(T)` handles. It carries
/// no operator tag, so the kind of operation is not stored in the node itself.
pub fn Operation(comptime T: type) type {
    return struct {
        /// First (left-hand) operand.
        left: Tensor(T),
        /// Second (right-hand) operand.
        right: Tensor(T),
    };
}
/// Generic handle to a node of a graph whose element type is `T`.
///
/// This is a tagged union that refers to its node through a const pointer, not
/// by value. Whatever a handle points at must stay alive, at the same address,
/// for as long as the handle is used.
pub fn Tensor(comptime T: type) type {
    return union(enum) {
        /// Refers to a leaf node.
        constant: *const Constant(T),
        /// Refers to a binary-operation node.
        operation: *const Operation(T),
    };
}
/// Container for the nodes of a computation graph whose element type is `T`.
///
/// Node values are appended to `constants` and `operations`. Both lists
/// allocate from the allocator passed to `init`. Call `deinit` to release their
/// storage; this is unnecessary when that allocator is an arena that is torn
/// down as a whole.
pub fn Graph(comptime T: type) type {
    return struct {
        /// Constant values appended to the graph, in append order.
        constants: std.ArrayList(Constant(T)),
        /// Operations appended to the graph, in append order.
        operations: std.ArrayList(Operation(T)),

        /// The graph's element type. Generic helpers that receive a
        /// `*Graph(T)` use it to recover `T`.
        pub const elementType: type = T;

        /// Creates an empty graph whose node lists allocate from `allocator`.
        pub fn init(allocator: *std.mem.Allocator) Graph(T) {
            return .{
                .constants = std.ArrayList(Constant(T)).init(allocator),
                .operations = std.ArrayList(Operation(T)).init(allocator),
            };
        }

        /// Frees the backing storage of `constants` and `operations`.
        /// The graph must not be used after this call.
        pub fn deinit(self: *Graph(T)) void {
            self.constants.deinit();
            self.operations.deinit();
        }
    };
}
/// Adds a constant leaf holding `value` to `graph` and returns a handle to it.
///
/// `graph` is presumably a `*Graph(T)`. Anything with a `constants` list and an
/// `elementType` declaration is accepted. `value` must coerce to `T`.
///
/// The value is appended to `graph.constants`. The returned handle points at a
/// separate node allocated from that list's allocator:
/// - It cannot point at a local, because that would dangle once this function
///   returns.
/// - It cannot point into `graph.constants`, because the list may relocate its
///   items when it grows.
///
/// The node is not tracked for freeing and lives as long as the allocator's
/// memory (for example, until an arena allocator is deinitialized).
/// NOTE(review): with a non-arena allocator these nodes leak.
///
/// Returns an error if allocating the node or appending to the list fails.
pub fn constant(graph: var, value: var) !Tensor(@typeOf(graph.*).elementType) {
    const T = @typeOf(graph.*).elementType;
    const allocator = graph.constants.allocator;

    // Heap-allocate so the handle's address stays valid after we return.
    const node = try allocator.create(Constant(T));
    errdefer allocator.destroy(node);
    node.* = Constant(T){ .value = value };

    try graph.constants.append(node.*);
    return Tensor(T){ .constant = node };
}
test "constant" {
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena.deinit();
    const allocator = &arena.allocator;
    var graph = Graph(f64).init(allocator);
    const x = try constant(&graph, 5);
    const y = try constant(&graph, 10);
    // expectEqual takes (expected, actual), and `actual` must coerce to the type
    // of `expected`. The expected values are therefore explicitly typed f64.
    const expected_x: f64 = 5;
    const expected_y: f64 = 10;
    std.testing.expectEqual(expected_x, x.constant.value);
    std.testing.expectEqual(expected_y, y.constant.value);
}
/// Adds a binary-operation node with operands `x` and `y` to `graph` and
/// returns a handle to it.
///
/// `graph` is presumably a `*Graph(T)`. `x` and `y` must share the same type,
/// and the resulting `Tensor(T)` must coerce to that type.
///
/// The operation is appended to `graph.operations`. The returned handle points
/// at a separate node allocated from that list's allocator:
/// - It cannot point at a local, because that would dangle once this function
///   returns.
/// - It cannot point into `graph.operations`, because the list may relocate its
///   items when it grows.
///
/// The node is not tracked for freeing and lives as long as the allocator's
/// memory (for example, until an arena allocator is deinitialized).
/// NOTE(review): with a non-arena allocator these nodes leak.
///
/// Returns an error if allocating the node or appending to the list fails.
pub fn add(graph: var, x: var, y: @typeOf(x)) !@typeOf(x) {
    const T = @typeOf(graph.*).elementType;
    const allocator = graph.operations.allocator;

    // Heap-allocate so the handle's address stays valid after we return.
    const node = try allocator.create(Operation(T));
    errdefer allocator.destroy(node);
    node.* = Operation(T){ .left = x, .right = y };

    try graph.operations.append(node.*);
    return Tensor(T){ .operation = node };
}
test "add" {
    // The graph allocates from this arena, which is freed in one go at scope exit.
    var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
    defer arena.deinit();

    var graph = Graph(f64).init(&arena.allocator);
    const lhs = try constant(&graph, 5);
    const rhs = try constant(&graph, 10);
    const sum = try add(&graph, lhs, rhs);

    // The operation node must reference exactly the leaves it was built from.
    const node = sum.operation;
    std.testing.expectEqual(lhs.constant, node.left.constant);
    std.testing.expectEqual(rhs.constant, node.right.constant);
}
Add Comment
Please, Sign In to add comment