Advertisement
Guest User

Untitled

a guest
Dec 4th, 2019
186
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.61 KB | None | 0 0
  1. const std = @import("std");
  2.  
  3. pub fn Tensor(comptime T: type) type {
  4. return struct {
  5. shape: []const u64,
  6. };
  7. }
  8.  
  9. pub fn Constant(comptime T: type) type {
  10. return struct {
  11. value: T,
  12. };
  13. }
  14.  
  15. pub fn Operation(comptime T: type) type {
  16. return struct {
  17. left: Tensor(T),
  18. right: Tensor(T),
  19. };
  20. }
  21.  
  22. pub fn Graph(comptime T: type) type {
  23. return struct {
  24. constants: std.ArrayList(Constant(T)),
  25. operations: std.ArrayList(Operation(T)),
  26.  
  27. pub const elementType: type = T;
  28.  
  29. pub fn init(allocator: *std.mem.Allocator) Graph(T) {
  30. return .{
  31. .constants = std.ArrayList(Constant(T)).init(allocator),
  32. .operations = std.ArrayList(Operation(T)).init(allocator),
  33. };
  34. }
  35. };
  36. }
  37.  
  38. pub fn constant(graph: var, value: var) !Tensor(@typeOf(graph.*).elementType) {
  39. const T = @typeOf(graph.*).elementType;
  40. try graph.constants.append(.{ .value = value });
  41. return Tensor(T){ .shape = &[_]u64{} };
  42. }
  43.  
  44. pub fn add(graph: var, x: var, y: @typeOf(x)) !@typeOf(x) {
  45. try graph.operations.append(.{ .left = x, .right = y });
  46. return @typeOf(x){ .shape = &[_]u64{} };
  47. }
  48.  
  49. test "create graph" {
  50. var arena = std.heap.ArenaAllocator.init(std.heap.page_allocator);
  51. defer arena.deinit();
  52. const allocator = &arena.allocator;
  53.  
  54. var graph = Graph(f64).init(allocator);
  55. comptime std.testing.expectEqual(@typeOf(graph).elementType, f64);
  56. const x = try constant(&graph, 5);
  57. const y = try constant(&graph, 10);
  58. const z = try add(&graph, x, y);
  59. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement