加载中…
个人资料
风之力量007
风之力量007
  • 博客等级:
  • 博客积分:0
  • 博客访问:17,663
  • 关注人气:8
  • 获赠金笔:0支
  • 赠出金笔:0支
  • 荣誉徽章:
相关博文
推荐博文
谁看过这篇博文
加载中…
正文 字体大小:

内存数据库简单实现 - python版

(2009-09-14 16:53:01)
标签:

杂谈

以下是Python代码实现(添加了事务回滚功能):

#!/usr/bin/python

import sys, struct, threading, time

## sync for threading
class ObjectSync:
    """A named re-entrant lock that tracks how deeply it is currently held.

    ``refcount`` goes up after every successful ``Lock()`` and down right
    before every ``Unlock()``.  So while a thread owns the lock, the count
    equals that thread's re-entry depth.  ``PairGuard`` inspects it to decide
    when the lock can be forgotten by its ``ObjectSyncFactory``.
    """

    def __init__(self, name):
        # Registry name ("<table>,<key>") used by the factory, or "" for
        # locks that are not factory-managed.
        self.keyname = name
        self.synclock = threading.RLock()
        self.refcount = 0

    def Lock(self):
        """Block until the lock is owned, then bump the hold count."""
        self.synclock.acquire()
        self.refcount += 1

    def Unlock(self):
        """Drop one level of the hold count, then release the lock."""
        # Decrement while still owning the lock so only the holder mutates it.
        self.refcount -= 1
        self.synclock.release()
       
class ObjectSyncFactory:
    """Hands out one shared ``ObjectSync`` per (table, key) pair.

    Locks are created on first request and kept in ``rowlocks`` under the name
    ``"<table>,<key>"`` (the key is passed through ``str()``).  ``PairGuard``
    asks the factory, via the name-mangled ``__RemoveLock``, to forget a lock
    once nobody holds it.
    """

    def __init__(self):
        # Protects every read and write of ``rowlocks``.
        self.globalLock = ObjectSync("")
        self.rowlocks = {}

    def __RemoveLock(self, sem, keyname):
        """Forget ``sem`` (registered under ``keyname``) if it is still idle.

        The entry is deleted rather than overwritten with ``None``, so the
        registry does not keep a dead slot for every key ever locked.  It is
        only removed if it is still the very lock ``sem`` and ``sem`` is not
        held at the time of the check.  This stops a stale request from
        evicting a lock another thread has just re-acquired, or a newer lock
        created for the same key.

        NOTE(review): a thread that got ``sem`` from ``GetLock`` but has not
        yet acquired it is not counted by ``refcount``; it can still end up
        holding a lock that is no longer registered.  Fixing that needs the
        user count to be maintained under ``globalLock`` in ``GetLock``.
        """
        self.globalLock.Lock()
        try:
            if self.rowlocks.get(keyname) is sem and sem.refcount == 0:
                del self.rowlocks[keyname]
        finally:
            self.globalLock.Unlock()

    def GetLock(self, tablename, key):
        """Return the shared lock for ``key`` in ``tablename``, creating it if needed.

        Keys whose ``str()`` forms are equal share a single lock.
        """
        keyname = tablename + "," + str(key)
        self.globalLock.Lock()
        try:
            lock = self.rowlocks.get(keyname)
            if lock is None:
                lock = ObjectSync(keyname)
                self.rowlocks[keyname] = lock
        finally:
            # Never leave the registry lock held, even if construction fails.
            self.globalLock.Unlock()
        return lock

class PairGuard:
    """Scope guard that holds ``sem`` from construction until released.

    The lock is acquired in the constructor.  It is released by any of:
    ``Release()``, leaving a ``with PairGuard(...):`` block, or (as before)
    the guard being garbage-collected.  Release is idempotent, so an explicit
    release followed by collection unlocks exactly once.

    Relying on ``__del__`` alone releases promptly only on CPython (reference
    counting), and only when the last reference to the guard disappears.  An
    unbound ``PairGuard(...)`` expression statement is therefore released
    immediately there.
    """

    def __init__(self, factory, sem):
        self.syncfactory = factory
        self.host = sem
        # Flipped to True only after Lock() succeeds, so a failed acquire is
        # never "released" later by __del__.
        self._held = False
        self.host.Lock()
        self._held = True

    def Release(self):
        """Unlock the guarded lock once; subsequent calls do nothing."""
        # getattr: __del__ may run on an object whose __init__ failed early.
        if not getattr(self, "_held", False):
            return
        self._held = False
        self.host.Unlock()
        # Once nobody holds the lock any more, let the factory drop it.
        if self.host.refcount == 0:
            self.syncfactory._ObjectSyncFactory__RemoveLock(self.host, self.host.keyname)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.Release()
        return False  # never swallow the exception

    def __del__(self):
        # Fallback for callers that rely on the guard going out of scope.
        self.Release()

## Database table
class MemTable:
    """One in-memory table: a dict of key -> value guarded by ``tableLock``.

    Deleting a key does not remove it; its value is replaced by ``None``
    (a tombstone).  So ``GetValue`` returns ``None`` for a deleted key, and
    ``GetRowCount`` still counts it.

    The name-mangled ``__AddValue`` / ``__DelValue`` / ``__DeleteAll`` skip
    ``tableLock``.  They are presumably meant for callers that serialise
    access themselves; ``MemDB`` exposes ``__AddValue``, which
    ``Transaction.Commit`` uses while holding its own per-table locks.
    """

    def __init__(self):
        self.rows = {}
        # Serialises the locked mutators below.
        self.tableLock = ObjectSync("")

    def GetRowCount(self):
        """Return the number of stored keys, including ``None`` tombstones."""
        return len(self.rows)

    def DeleteAll(self):
        """Remove every row while holding ``tableLock``."""
        self.tableLock.Lock()
        try:
            self.rows = {}
        finally:
            self.tableLock.Unlock()

    def __DeleteAll(self):
        """Remove every row without taking ``tableLock``."""
        self.rows = {}

    def GetAllValue(self):
        """Return the live row dict itself (not a copy)."""
        return self.rows

    #########################
    def GetValue(self, key):
        """Return the value stored for ``key``.

        Raises ``KeyError`` if ``key`` was never added; a deleted key
        returns ``None``.
        """
        return self.rows[key]

    def AddValue(self, key, value):
        """Insert ``key`` if absent, otherwise overwrite it, holding ``tableLock``."""
        self.tableLock.Lock()
        try:
            self.rows[key] = value
        finally:
            # An unhashable key raises TypeError here; without the finally the
            # table lock would stay held by this thread forever.
            self.tableLock.Unlock()

    def __AddValue(self, key, value):
        """Insert or overwrite ``key`` without taking ``tableLock``."""
        self.rows[key] = value

    def DelValue(self, key):
        """Mark ``key`` deleted by storing ``None`` under it (locked)."""
        self.AddValue(key, None)

    def __DelValue(self, key):
        """Mark ``key`` deleted by storing ``None`` under it (unlocked)."""
        self._MemTable__AddValue(key, None)

##
class MemDB:
    """A collection of named ``MemTable`` objects with per-row locking.

    ``GetValue`` / ``AddValue`` / ``DelValue`` hold the factory lock for
    ``(tablename, key)`` while they run.  The name-mangled ``__GetValue`` /
    ``__AddValue`` / ``__DelValue`` skip that row lock; ``__AddValue`` and
    ``__DelValue`` also skip the table's own lock.  ``Transaction.Commit``
    writes through ``__AddValue``.
    """

    def __init__(self):
        self.tables = {}  # tablename -> MemTable
        self.syncFactory = ObjectSyncFactory()

    def CreateTable(self, tablename):  # not thread-safe
        """Create ``tablename``, replacing any existing table of that name."""
        self.tables[tablename] = MemTable()

    def DropTable(self, tablename):  # not thread-safe
        """Remove ``tablename``; dropping an unknown table is a no-op.

        The entry is removed, not set to ``None``.  Later access to a dropped
        table therefore raises ``KeyError`` like any unknown table, instead of
        an ``AttributeError`` on ``None``.
        """
        self.tables.pop(tablename, None)

    def _RowGuard(self, tablename, key):
        """Acquire the row lock for (tablename, key); return the PairGuard owning it."""
        return PairGuard(self.syncFactory, self.syncFactory.GetLock(tablename, key))

    # PairGuard unlocks when it is garbage-collected.  Each locked operation
    # therefore keeps its guard in a local for the whole operation and drops
    # the last reference in ``finally``.  An unbound ``PairGuard(...)`` would
    # release the row lock before the operation even started.

    def GetValue(self, tablename, key):
        """Return ``key``'s value in ``tablename`` under its row lock.

        Raises ``KeyError`` for an unknown table or a never-added key.
        """
        mt = self.tables[tablename]
        guard = self._RowGuard(tablename, key)
        try:
            return mt.GetValue(key)
        finally:
            del guard

    def AddValue(self, tablename, key, value):
        """Insert or overwrite ``key`` in ``tablename`` under its row lock."""
        mt = self.tables[tablename]
        guard = self._RowGuard(tablename, key)
        try:
            mt.AddValue(key, value)
        finally:
            del guard

    def DelValue(self, tablename, key):
        """Tombstone ``key`` (store ``None``) in ``tablename`` under its row lock."""
        mt = self.tables[tablename]
        guard = self._RowGuard(tablename, key)
        try:
            mt.DelValue(key)
        finally:
            del guard

    def __GetValue(self, tablename, key):
        """Like ``GetValue`` but without the row lock."""
        mt = self.tables[tablename]
        return mt.GetValue(key)

    def __AddValue(self, tablename, key, value):
        """Insert or overwrite ``key`` without the row lock or the table lock."""
        mt = self.tables[tablename]
        mt._MemTable__AddValue(key, value)

    def __DelValue(self, tablename, key):
        """Tombstone ``key`` without the row lock or the table lock."""
        mt = self.tables[tablename]
        mt._MemTable__DelValue(key)
       
class Transaction:
    """Buffers writes made through a ``MemDBConnect`` until ``Commit()``.

    Writes are only recorded by ``LogPoint``; nothing reaches the database
    before ``Commit()``, so ``Rollback()`` just discards the buffer.

    NOTE(review): the per-table commit lock is the factory lock named
    ``"<table>,table"``, which is also the row lock of any key whose ``str()``
    is ``'table'``.  It does not exclude non-transactional ``MemDB`` writes,
    which take the row lock plus the ``MemTable`` lock, so a commit is not
    isolated from them.  Confirm whether that is intended.
    """

    def __init__(self, conn):
        self.dbconn = conn
        # Pending writes in the order they were made: (tablename, key, value);
        # a value of None records a delete.
        self.logs = []

    def _End(self):
        # Close the connection's transaction only if it is still this one, so
        # a stale Transaction object cannot end a newer transaction.
        if self.dbconn.trans is self:
            self.dbconn._MemDBConnect__EndTransaction()

    def Commit(self):
        """Apply all pending writes in order, then end the transaction.

        Raises ``KeyError`` if a logged table is missing or was dropped.  This
        happens before any lock is taken or anything is written, and the
        transaction stays open so the caller can ``Rollback()``.

        The ``"<table>,table"`` lock of every touched table is taken in sorted
        name order (so concurrent commits cannot deadlock).  The locks are
        released afterwards even if a write raises.
        """
        memdb = self.dbconn.memdb
        factory = memdb.syncFactory

        tablenames = set(entry[0] for entry in self.logs)
        for name in tablenames:
            if memdb.tables.get(name) is None:
                raise KeyError(name)

        syncs = sorted([factory.GetLock(name, 'table') for name in tablenames],
                       key=lambda sync: sync.keyname)

        guards = []
        try:
            for sync in syncs:
                guards.append(PairGuard(factory, sync))
            for tablename, key, value in self.logs:
                memdb._MemDB__AddValue(tablename, key, value)
            del self.logs[:]
        finally:
            # Release in reverse acquisition order; dropping the last
            # reference to a PairGuard unlocks it.
            while guards:
                guards.pop()

        self._End()

    def Rollback(self):
        """Discard all pending writes and end the transaction.

        The log is cleared, so a later ``Commit()`` on this object applies
        nothing.
        """
        del self.logs[:]
        self._End()

    def LogPoint(self, tablename, key, value):
        """Record a pending write (``value`` None means delete)."""
        self.logs.append((tablename, key, value))
       
class MemDBConnect:
    """A client handle on a ``MemDB`` that can batch writes in a transaction.

    Outside a transaction, every call goes straight to the row-locking
    ``MemDB`` API.  Inside one:
    - writes are logged to the ``Transaction`` and applied on ``Commit``;
    - reads use ``MemDB``'s unlocked accessor and see the data currently in
      the database, not the pending writes.
    """

    def __init__(self, db):
        self.memdb = db
        self.bTransaction = False  # True while a transaction is open
        self.trans = None          # the open Transaction, if any

    def BeginTransaction(self):
        """Open and return a new transaction.

        A transaction already open on this connection is replaced; later
        writes on this connection are logged to the new one.
        """
        self.bTransaction = True
        self.trans = Transaction(self)
        return self.trans

    def __EndTransaction(self):
        # Invoked by Transaction.Commit / Rollback through the mangled name.
        self.bTransaction = False
        self.trans = None

    def GetValue(self, tablename, key):
        """Read ``key`` from ``tablename`` (without the row lock in a transaction)."""
        if not self.bTransaction:
            return self.memdb.GetValue(tablename, key)
        return self.memdb._MemDB__GetValue(tablename, key)

    def AddValue(self, tablename, key, value):
        """Write ``value`` immediately, or log it to the open transaction."""
        if not self.bTransaction:
            self.memdb.AddValue(tablename, key, value)
            return
        self.trans.LogPoint(tablename, key, value)

    def DelValue(self, tablename, key):
        """Delete ``key`` immediately, or log a ``None`` write to the open transaction."""
        if not self.bTransaction:
            self.memdb.DelValue(tablename, key)
            return
        self.trans.LogPoint(tablename, key, None)
       
if __name__ == '__main__':  # smoke test and rough timing of the in-memory db
    # print(...) with a single argument behaves the same on Python 2 and 3.
    start = time.time()
    print(time.localtime())

    db = MemDB()
    tname = "table1"
    db.CreateTable(tname)

    # Direct (row-locked) writes; writing an existing key overwrites it.
    db.AddValue(tname, 11, "sdfsd")
    db.AddValue(tname, 11, "dddddd")
    db.AddValue(tname, 12, "dsfdsfd")
    print(db.GetValue(tname, 11))
    print(db.GetValue(tname, 12))

    # Buffered writes: nothing is visible until Commit() applies them in
    # order, so key 12 ends up with the last value logged for it.
    conn = MemDBConnect(db)
    t = conn.BeginTransaction()
    for i in range(100000):
        conn.AddValue(tname, i, "sdfsd")
    conn.AddValue(tname, 12, "sfdas")
    conn.AddValue(tname, 12, "ddddd")
    print(db.GetValue(tname, 12))  # still the pre-transaction value
    t.Commit()
    print(db.GetValue(tname, 12))  # now the last value logged in the transaction

    print(time.localtime())
    print("elapsed: %.3f s" % (time.time() - start))

0

阅读 评论 收藏 转载 喜欢 打印举报/Report
  • 评论加载中,请稍候...
发评论

    发评论

    以上网友发言只代表其个人观点,不代表新浪网的观点或立场。

      

    新浪BLOG意见反馈留言板 电话:4000520066 提示音后按1键(按当地市话标准计费) 欢迎批评指正

    新浪简介 | About Sina | 广告服务 | 联系我们 | 招聘信息 | 网站律师 | SINA English | 会员注册 | 产品答疑

    新浪公司 版权所有