from datetime import date, datetime, timedelta from time import sleep from sqlalchemy import Column, Double, Integer, String, create_engine from sqlalchemy.orm import DeclarativeBase, Session from tushare import pro_api TUSHARE_API_KEY = '64ebff4fa679167600b905ee45dd88e76f3963c0ff39157f3f085f0e' class Base(DeclarativeBase): pass class Stock(Base): __tablename__ = 'stock' code = Column(String, primary_key=True, comment="代码") name = Column(String, comment="名称") fullname = Column(String, comment="全名") market = Column(String, comment="市场") exchange = Column(String, comment="交易所") industry = Column(String, comment="行业") list_date = Column(String, comment="上市日期") class Daily(Base): __tablename__ = 'daily' code = Column(String, primary_key=True) trade_date = Column(String, primary_key=True) open = Column(Double) close = Column(Double) high = Column(Double) low = Column(Double) previous_close = Column(Double) turnover = Column(Double) volume = Column(Integer) price_change_amount = Column(Double) factor = Column(Double) def main(): print("开始更新数据") engine = create_engine(f"sqlite:////Users/lanyuanxiaoyao/Documents/leopard_data/leopard.sqlite") try: Stock.metadata.create_all(engine, checkfirst=True) Daily.metadata.create_all(engine, checkfirst=True) pro = pro_api(TUSHARE_API_KEY) # with engine.connect() as connection: # stocks = pro.stock_basic(list_status="L", market="主板", fields="ts_code,name,fullname,market,exchange,industry,list_date") # for row in stocks.itertuples(): # stmt = insert(Stock).values( # code=row.ts_code, # name=row.name, # fullname=row.fullname, # market=row.market, # exchange=row.exchange, # industry=row.industry, # list_date=row.list_date, # ) # stmt = stmt.on_conflict_do_update( # index_elements=["code"], # set_={ # "name": stmt.excluded.name, # "fullname": stmt.excluded.fullname, # "market": stmt.excluded.market, # "exchange": stmt.excluded.exchange, # "industry": stmt.excluded.industry, # "list_date": stmt.excluded.list_date, # }, # ) # print(stmt) # connection.execute(stmt) # connection.commit() # # print("清理行情数据") # connection.execute(text("delete from daily where code not in (select distinct code from stock)")) # connection.commit() # # print("清理财务数据") # connection.execute(text("delete from finance_indicator where code not in (select distinct code from stock)")) # connection.commit() with Session(engine) as session: stock_codes = [row[0] for row in session.query(Stock.code).all()] latest_date = session.query(Daily.trade_date).order_by(Daily.trade_date.desc()).first() if latest_date is None: latest_date = '1990-12-19' else: latest_date = latest_date.trade_date latest_date = datetime.strptime(latest_date, '%Y-%m-%d').date() current_date = date.today() - timedelta(days=1) delta = (current_date - latest_date).days print(f"最新数据日期:{latest_date},当前日期:{current_date},待更新天数:{delta}") if delta > 0: update_dates = [] for i in range(delta): latest_date = latest_date + timedelta(days=1) update_dates.append(latest_date.strftime('%Y%m%d')) for target_date in update_dates: print(f"正在采集:{target_date}") dailies = pro.daily(trade_date=target_date) dailies.set_index("ts_code", inplace=True) factors = pro.adj_factor(trade_date=target_date) factors.set_index("ts_code", inplace=True) results = dailies.join(factors, lsuffix="_daily", rsuffix="_factor", how="left") rows = [] for row in results.itertuples(): if row.Index in stock_codes: rows.append( Daily( code=row.Index, trade_date=datetime.strptime(target_date, '%Y%m%d').strftime("%Y-%m-%d"), open=row.open, close=row.close, high=row.high, low=row.low, previous_close=row.pre_close, turnover=row.amount, volume=row.vol, price_change_amount=row.pct_chg, factor=row.adj_factor, ) ) session.add_all(rows) session.commit() sleep(1) finally: engine.dispose() if __name__ == '__main__': main()