Advertisement
Guest User

Untitled

a guest
Dec 8th, 2019
116
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 2.36 KB | None | 0 0
  1. const std = @import("std");
  2.  
  3. pub fn Constant(comptime T: type) type {
  4. return struct {
  5. value: T,
  6. };
  7. }
  8.  
  9. pub fn Operation(comptime T: type) type {
  10. return struct {
  11. left: Tensor(T),
  12. right: Tensor(T),
  13. };
  14. }
  15.  
  16. pub fn Tensor(comptime T: type) type {
  17. return union(enum) {
  18. constant: u64,
  19. operation: u64,
  20. };
  21. }
  22.  
  23. pub fn Graph(comptime T: type) type {
  24. return struct {
  25. constants: std.ArrayList(Constant(T)),
  26. operations: std.ArrayList(Operation(T)),
  27.  
  28. pub const elementType: type = T;
  29.  
  30. pub fn init(allocator: *std.mem.Allocator) Graph(T) {
  31. return .{
  32. .constants = std.ArrayList(Constant(T)).init(allocator),
  33. .operations = std.ArrayList(Operation(T)).init(allocator),
  34. };
  35. }
  36. };
  37. }
  38.  
  39. pub fn constant(graph: var, value: var) !Tensor(@typeOf(graph.*).elementType) {
  40. const T = @typeOf(graph.*).elementType;
  41. const c = try graph.constants.addOne();
  42. c.* = .{ .value = value };
  43. return Tensor(T){ .constant = graph.constants.count() - 1 };
  44. }
  45.  
  46. test "constant" {
  47. var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  48. defer arena.deinit();
  49. const allocator = &arena.allocator;
  50.  
  51. var graph = Graph(f64).init(allocator);
  52. const x = try constant(&graph, 5);
  53. const y = try constant(&graph, 10);
  54. std.testing.expectEqual(graph.constants.at(x.constant).value, 5);
  55. std.testing.expectEqual(graph.constants.at(y.constant).value, 10);
  56. }
  57.  
  58. pub fn add(graph: var, x: var, y: @typeOf(x)) !@typeOf(x) {
  59. const T = @typeOf(graph.*).elementType;
  60. const o = try graph.operations.addOne();
  61. o.* = .{ .left = x, .right = y };
  62. return Tensor(T){ .operation = graph.operations.count() - 1 };
  63. }
  64.  
  65. test "add" {
  66. var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  67. defer arena.deinit();
  68. const allocator = &arena.allocator;
  69.  
  70. var graph = Graph(f64).init(allocator);
  71. const x = try constant(&graph, 5);
  72. const y = try constant(&graph, 10);
  73. const z = try add(&graph, x, y);
  74. std.testing.expectEqual(graph.constants.at(x.constant), graph.constants.at(graph.operations.at(z.operation).left.constant));
  75. std.testing.expectEqual(graph.constants.at(y.constant), graph.constants.at(graph.operations.at(z.operation).right.constant));
  76. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement