Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # -*- coding: utf-8 -*-
- '''
- edgelistとlabellistを0始まりのIDに変換する
- '''
- import os
- import sys
- #from snlocest.largedict import LargeDict
def parse_args(argv=None):
    '''
    Parse the command-line arguments of this script.

    Args:
        argv: Optional list of argument strings (without the program name).
            When None, argparse falls back to ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with three attributes:
        ``mode`` (str), ``tablepath`` (str) and ``inputfiles`` (list of str,
        at least one element).
    '''
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('mode', help='処理のモード')
    parser.add_argument('--tablepath', required=True)
    parser.add_argument('inputfiles', nargs='+')
    return parser.parse_args(argv)
def load_table(filepath):
    '''
    Load a user-ID conversion table from a tab-separated file.

    Every non-blank line must consist of exactly two tab-separated fields:
    the original user ID followed by its integer index.

    Args:
        filepath: Path of the table file to read.

    Returns:
        A tuple ``(user2id, idcnt)`` where ``user2id`` maps each original
        user ID (str) to its index (int) and ``idcnt`` is the largest index
        plus one, or 0 when the table has no entries.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields or
            its second field is not an integer.
    '''
    user2id = dict()
    with open(filepath, 'r') as fd:
        for line in fd:
            line = line.rstrip()
            if not line:
                # Blank lines (e.g. a trailing empty line) carry no entry.
                continue
            user_id, idx = line.split('\t')
            user2id[user_id] = int(idx)
    # default=-1 makes an empty table yield idcnt == 0 instead of letting
    # max() raise ValueError on an empty sequence.
    idcnt = max(user2id.values(), default=-1) + 1
    return user2id, idcnt
def _build_table(inputfiles):
    '''Map every ID in the first two columns of the input files to a 0-based index.

    Indices are handed out consecutively in order of first appearance.
    Raises IndexError if a line has fewer than two tab-separated columns.
    '''
    user2id = dict()
    for filepath in inputfiles:
        with open(filepath, 'r') as inputfile:
            for line in inputfile:
                tokens = line.rstrip().split('\t')
                # Explicit indexing (not slicing) keeps the IndexError for
                # lines that lack a second column.
                for user_id in (tokens[0], tokens[1]):
                    if user_id not in user2id:
                        # len() equals the next unused index because indices
                        # are assigned consecutively from 0.
                        user2id[user_id] = len(user2id)
    return user2id


def _convert(inputfiles, user2id, id_columns):
    '''Print each input row with its first ``id_columns`` columns replaced by table indices.

    Remaining columns are passed through unchanged; output is tab-separated
    on stdout. Raises KeyError for an ID missing from ``user2id`` and
    IndexError for a row with fewer than ``id_columns`` columns.
    '''
    for filepath in inputfiles:
        with open(filepath, 'r') as fd:
            for line in fd:
                row = line.rstrip().split('\t')
                converted = [user2id[row[i]] for i in range(id_columns)]
                print(*converted, *row[id_columns:], sep='\t')


def main(args=None):
    '''
    Run the mode selected on the command line.

    Modes:
        table:    build a conversion table from the input files and write it
                  to ``tablepath`` (refuses to overwrite an existing path).
        edgelist: print the input files with their first two columns
                  converted through the table at ``tablepath``.
        label:    print the input files with their first column converted
                  through the table at ``tablepath``.

    Args:
        args: Optional pre-parsed namespace with ``mode``, ``tablepath`` and
            ``inputfiles`` attributes. When None, ``parse_args()`` is called.

    Exits with status 1 if the table path already exists in ``table`` mode
    or if the mode is not one of the three above.
    '''
    if args is None:
        args = parse_args()

    if args.mode == 'table':
        # Generate convert table from inputfiles and save to tablepath
        if os.path.exists(args.tablepath):
            print('Table path already exists', file=sys.stderr)
            sys.exit(1)

        user2id = _build_table(args.inputfiles)

        # save table
        with open(args.tablepath, 'w') as fd:
            for k, v in user2id.items():
                print(k, v, sep='\t', file=fd)
    elif args.mode == 'edgelist':
        # Convert edgelist using table: source and destination columns are IDs
        user2id, _ = load_table(args.tablepath)
        _convert(args.inputfiles, user2id, 2)
    elif args.mode == 'label':
        # Convert label file: only the first column is an ID
        user2id, _ = load_table(args.tablepath)
        _convert(args.inputfiles, user2id, 1)
    else:
        print('Invalid mode. Choose "table" or "edgelist" or "label"', file=sys.stderr)
        # Report the failure through the exit status, like the other error path.
        sys.exit(1)


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement