SHARE
TWEET

Untitled

a guest Dec 8th, 2019 77 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. const std = @import("std");
  2.  
  3. pub fn Constant(comptime T: type) type {
  4.     return struct {
  5.         value: T,
  6.     };
  7. }
  8.  
  9. pub fn Operation(comptime T: type) type {
  10.     return struct {
  11.         left: Tensor(T),
  12.         right: Tensor(T),
  13.     };
  14. }
  15.  
  16. pub fn Tensor(comptime T: type) type {
  17.     return union(enum) {
  18.         constant: u64,
  19.         operation: u64,
  20.     };
  21. }
  22.  
  23. pub fn Graph(comptime T: type) type {
  24.     return struct {
  25.         constants: std.ArrayList(Constant(T)),
  26.         operations: std.ArrayList(Operation(T)),
  27.  
  28.         pub const elementType: type = T;
  29.  
  30.         pub fn init(allocator: *std.mem.Allocator) Graph(T) {
  31.             return .{
  32.                 .constants = std.ArrayList(Constant(T)).init(allocator),
  33.                 .operations = std.ArrayList(Operation(T)).init(allocator),
  34.             };
  35.         }
  36.     };
  37. }
  38.  
  39. pub fn constant(graph: var, value: var) !Tensor(@typeOf(graph.*).elementType) {
  40.     const T = @typeOf(graph.*).elementType;
  41.     const c = try graph.constants.addOne();
  42.     c.* = .{ .value = value };
  43.     return Tensor(T){ .constant = graph.constants.count() - 1 };
  44. }
  45.  
  46. test "constant" {
  47.     var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  48.     defer arena.deinit();
  49.     const allocator = &arena.allocator;
  50.  
  51.     var graph = Graph(f64).init(allocator);
  52.     const x = try constant(&graph, 5);
  53.     const y = try constant(&graph, 10);
  54.     std.testing.expectEqual(graph.constants.at(x.constant).value, 5);
  55.     std.testing.expectEqual(graph.constants.at(y.constant).value, 10);
  56. }
  57.  
  58. pub fn add(graph: var, x: var, y: @typeOf(x)) !@typeOf(x) {
  59.     const T = @typeOf(graph.*).elementType;
  60.     const o = try graph.operations.addOne();
  61.     o.* = .{ .left = x, .right = y };
  62.     return Tensor(T){ .operation = graph.operations.count() - 1 };
  63. }
  64.  
  65. test "add" {
  66.     var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  67.     defer arena.deinit();
  68.     const allocator = &arena.allocator;
  69.  
  70.     var graph = Graph(f64).init(allocator);
  71.     const x = try constant(&graph, 5);
  72.     const y = try constant(&graph, 10);
  73.     const z = try add(&graph, x, y);
  74.     std.testing.expectEqual(graph.constants.at(x.constant), graph.constants.at(graph.operations.at(z.operation).left.constant));
  75.     std.testing.expectEqual(graph.constants.at(y.constant), graph.constants.at(graph.operations.at(z.operation).right.constant));
  76. }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top