Upload 6 files
Browse files- agent.py +368 -0
- config_generator.py +272 -0
- crypto.py +105 -0
- environment.py +214 -0
- reward.py +54 -0
- schemas.py +182 -0
agent.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import json
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch .nn as nn
|
| 10 |
+
import torch .nn .functional as F
|
| 11 |
+
from torch .distributions import Categorical ,Normal
|
| 12 |
+
from typing import Optional ,Tuple ,List
|
| 13 |
+
|
| 14 |
+
from server .rl .environment import (
|
| 15 |
+
DISCRETE_NVEC ,
|
| 16 |
+
N_CONTINUOUS ,
|
| 17 |
+
TOTAL_OBS_DIM ,
|
| 18 |
+
AlphaBypassEnv ,
|
| 19 |
+
)
|
| 20 |
+
from server .rl .reward import reward_to_label
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class PolicyNetwork (nn .Module ):
|
| 28 |
+
|
| 29 |
+
def __init__ (
|
| 30 |
+
self ,
|
| 31 |
+
obs_dim :int =TOTAL_OBS_DIM ,
|
| 32 |
+
hidden :int =512 ,
|
| 33 |
+
discrete_nvec :List [int ]=DISCRETE_NVEC ,
|
| 34 |
+
n_continuous :int =N_CONTINUOUS ,
|
| 35 |
+
):
|
| 36 |
+
super ().__init__ ()
|
| 37 |
+
self .discrete_nvec =discrete_nvec
|
| 38 |
+
self .n_continuous =n_continuous
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
self .trunk =nn .Sequential (
|
| 42 |
+
nn .Linear (obs_dim ,hidden ),
|
| 43 |
+
nn .LayerNorm (hidden ),
|
| 44 |
+
nn .ReLU (),
|
| 45 |
+
nn .Linear (hidden ,hidden ),
|
| 46 |
+
nn .LayerNorm (hidden ),
|
| 47 |
+
nn .ReLU (),
|
| 48 |
+
nn .Linear (hidden ,hidden ),
|
| 49 |
+
nn .LayerNorm (hidden ),
|
| 50 |
+
nn .ReLU (),
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
self .discrete_heads =nn .ModuleList ([
|
| 55 |
+
nn .Linear (hidden ,n )for n in discrete_nvec
|
| 56 |
+
])
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
self .cont_mu =nn .Linear (hidden ,n_continuous )
|
| 60 |
+
self .cont_log_std =nn .Parameter (torch .zeros (n_continuous ))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
self .value_head =nn .Sequential (
|
| 64 |
+
nn .Linear (hidden ,256 ),
|
| 65 |
+
nn .ReLU (),
|
| 66 |
+
nn .Linear (256 ,1 ),
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
def forward (self ,obs :torch .Tensor ):
|
| 70 |
+
h =self .trunk (obs )
|
| 71 |
+
logits =[head (h )for head in self .discrete_heads ]
|
| 72 |
+
mu =torch .sigmoid (self .cont_mu (h ))
|
| 73 |
+
log_std =self .cont_log_std .clamp (-4 ,0 )
|
| 74 |
+
value =self .value_head (h ).squeeze (-1 )
|
| 75 |
+
return logits ,mu ,log_std ,value
|
| 76 |
+
|
| 77 |
+
def get_action_and_log_prob (
|
| 78 |
+
self ,
|
| 79 |
+
obs :torch .Tensor ,
|
| 80 |
+
action_masks :Optional [List [Optional [torch .Tensor ]]]=None ,
|
| 81 |
+
)->Tuple [np .ndarray ,np .ndarray ,torch .Tensor ,torch .Tensor ]:
|
| 82 |
+
|
| 83 |
+
logits ,mu ,log_std ,value =self .forward (obs )
|
| 84 |
+
|
| 85 |
+
discrete_actions =[]
|
| 86 |
+
log_probs_discrete =[]
|
| 87 |
+
|
| 88 |
+
for i ,(lg ,n )in enumerate (zip (logits ,self .discrete_nvec )):
|
| 89 |
+
if action_masks and action_masks [i ]is not None :
|
| 90 |
+
|
| 91 |
+
mask =action_masks [i ].to (lg .device )
|
| 92 |
+
lg =lg .masked_fill (~mask ,float ("-inf"))
|
| 93 |
+
dist =Categorical (logits =lg )
|
| 94 |
+
a =dist .sample ()
|
| 95 |
+
discrete_actions .append (a .item ())
|
| 96 |
+
log_probs_discrete .append (dist .log_prob (a ))
|
| 97 |
+
|
| 98 |
+
log_prob_discrete =torch .stack (log_probs_discrete ).sum ()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
std =log_std .exp ()
|
| 102 |
+
dist_cont =Normal (mu ,std )
|
| 103 |
+
cont_sample =dist_cont .sample ()
|
| 104 |
+
cont_action =cont_sample .clamp (0.0 ,1.0 )
|
| 105 |
+
log_prob_cont =dist_cont .log_prob (cont_sample ).sum ()
|
| 106 |
+
|
| 107 |
+
total_log_prob =log_prob_discrete +log_prob_cont
|
| 108 |
+
|
| 109 |
+
return (
|
| 110 |
+
np .array (discrete_actions ,dtype =np .int32 ),
|
| 111 |
+
cont_action .detach ().cpu ().numpy (),
|
| 112 |
+
total_log_prob ,
|
| 113 |
+
value ,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
def evaluate_actions (
|
| 117 |
+
self ,
|
| 118 |
+
obs :torch .Tensor ,
|
| 119 |
+
discrete_actions :torch .Tensor ,
|
| 120 |
+
cont_actions :torch .Tensor ,
|
| 121 |
+
)->Tuple [torch .Tensor ,torch .Tensor ,torch .Tensor ]:
|
| 122 |
+
logits ,mu ,log_std ,value =self .forward (obs )
|
| 123 |
+
|
| 124 |
+
log_prob_d =torch .zeros (obs .shape [0 ],device =obs .device )
|
| 125 |
+
entropy_d =torch .zeros (obs .shape [0 ],device =obs .device )
|
| 126 |
+
for i ,lg in enumerate (logits ):
|
| 127 |
+
dist =Categorical (logits =lg )
|
| 128 |
+
log_prob_d +=dist .log_prob (discrete_actions [:,i ])
|
| 129 |
+
entropy_d +=dist .entropy ()
|
| 130 |
+
|
| 131 |
+
std =log_std .exp ()
|
| 132 |
+
dist_c =Normal (mu ,std )
|
| 133 |
+
log_prob_c =dist_c .log_prob (cont_actions ).sum (-1 )
|
| 134 |
+
entropy_c =dist_c .entropy ().sum (-1 )
|
| 135 |
+
|
| 136 |
+
return log_prob_d +log_prob_c ,(entropy_d +entropy_c )/2 ,value
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class RolloutBuffer :
|
| 144 |
+
def __init__ (self ):
|
| 145 |
+
self .clear ()
|
| 146 |
+
|
| 147 |
+
def clear (self ):
|
| 148 |
+
self .obs :List [np .ndarray ]=[]
|
| 149 |
+
self .discrete_actions :List [np .ndarray ]=[]
|
| 150 |
+
self .cont_actions :List [np .ndarray ]=[]
|
| 151 |
+
self .rewards :List [float ]=[]
|
| 152 |
+
self .log_probs :List [torch .Tensor ]=[]
|
| 153 |
+
self .values :List [torch .Tensor ]=[]
|
| 154 |
+
self .dones :List [bool ]=[]
|
| 155 |
+
|
| 156 |
+
def add (self ,obs ,d_action ,c_action ,reward ,log_prob ,value ,done ):
|
| 157 |
+
self .obs .append (obs )
|
| 158 |
+
self .discrete_actions .append (d_action )
|
| 159 |
+
self .cont_actions .append (c_action )
|
| 160 |
+
self .rewards .append (reward )
|
| 161 |
+
self .log_probs .append (log_prob )
|
| 162 |
+
self .values .append (value )
|
| 163 |
+
self .dones .append (done )
|
| 164 |
+
|
| 165 |
+
def compute_returns (self ,last_value :float ,gamma :float =0.99 ,gae_lambda :float =0.95 ):
|
| 166 |
+
advantages =[]
|
| 167 |
+
gae =0.0
|
| 168 |
+
values =[v .item ()for v in self .values ]+[last_value ]
|
| 169 |
+
|
| 170 |
+
for t in reversed (range (len (self .rewards ))):
|
| 171 |
+
delta =self .rewards [t ]+gamma *values [t +1 ]*(1 -self .dones [t ])-values [t ]
|
| 172 |
+
gae =delta +gamma *gae_lambda *(1 -self .dones [t ])*gae
|
| 173 |
+
advantages .insert (0 ,gae )
|
| 174 |
+
|
| 175 |
+
returns =[a +v .item ()for a ,v in zip (advantages ,self .values )]
|
| 176 |
+
return advantages ,returns
|
| 177 |
+
|
| 178 |
+
def to_tensors (self ,device :torch .device ):
|
| 179 |
+
obs =torch .FloatTensor (np .stack (self .obs )).to (device )
|
| 180 |
+
d_act =torch .LongTensor (np .stack (self .discrete_actions )).to (device )
|
| 181 |
+
c_act =torch .FloatTensor (np .stack (self .cont_actions )).to (device )
|
| 182 |
+
return obs ,d_act ,c_act
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class PPOTrainer :
|
| 190 |
+
def __init__ (
|
| 191 |
+
self ,
|
| 192 |
+
env :AlphaBypassEnv ,
|
| 193 |
+
device_str :str ="cuda",
|
| 194 |
+
lr :float =3e-4 ,
|
| 195 |
+
gamma :float =0.99 ,
|
| 196 |
+
gae_lambda :float =0.95 ,
|
| 197 |
+
clip_eps :float =0.2 ,
|
| 198 |
+
entropy_coef :float =0.01 ,
|
| 199 |
+
vf_coef :float =0.5 ,
|
| 200 |
+
max_grad_norm :float =0.5 ,
|
| 201 |
+
update_epochs :int =4 ,
|
| 202 |
+
steps_per_update :int =8 ,
|
| 203 |
+
checkpoint_dir :str ="checkpoints",
|
| 204 |
+
checkpoint_every :int =100 ,
|
| 205 |
+
):
|
| 206 |
+
self .env =env
|
| 207 |
+
self .device =torch .device (device_str if torch .cuda .is_available ()else "cpu")
|
| 208 |
+
print (f"[PPO] device: {self .device }")
|
| 209 |
+
|
| 210 |
+
self .policy =PolicyNetwork ().to (self .device )
|
| 211 |
+
self .optimizer =torch .optim .Adam (self .policy .parameters (),lr =lr )
|
| 212 |
+
self .scheduler =torch .optim .lr_scheduler .ExponentialLR (self .optimizer ,gamma =0.999 )
|
| 213 |
+
|
| 214 |
+
self .gamma =gamma
|
| 215 |
+
self .gae_lambda =gae_lambda
|
| 216 |
+
self .clip_eps =clip_eps
|
| 217 |
+
self .entropy_coef =entropy_coef
|
| 218 |
+
self .vf_coef =vf_coef
|
| 219 |
+
self .max_grad_norm =max_grad_norm
|
| 220 |
+
self .update_epochs =update_epochs
|
| 221 |
+
self .steps_per_update =steps_per_update
|
| 222 |
+
self .checkpoint_dir =checkpoint_dir
|
| 223 |
+
self .checkpoint_every =checkpoint_every
|
| 224 |
+
|
| 225 |
+
os .makedirs (checkpoint_dir ,exist_ok =True )
|
| 226 |
+
|
| 227 |
+
self .total_episodes =0
|
| 228 |
+
self .best_reward =-float ("inf")
|
| 229 |
+
self .reward_history :List [float ]=[]
|
| 230 |
+
|
| 231 |
+
def _build_action_masks (self ,obs_tensor :torch .Tensor )->List [Optional [torch .Tensor ]]:
|
| 232 |
+
|
| 233 |
+
return [None ]*len (DISCRETE_NVEC )
|
| 234 |
+
|
| 235 |
+
def collect_rollout (self )->RolloutBuffer :
|
| 236 |
+
buffer =RolloutBuffer ()
|
| 237 |
+
obs =self .env ._build_obs ()
|
| 238 |
+
|
| 239 |
+
for _ in range (self .steps_per_update ):
|
| 240 |
+
obs_t =torch .FloatTensor (obs ).unsqueeze (0 ).to (self .device )
|
| 241 |
+
|
| 242 |
+
with torch .no_grad ():
|
| 243 |
+
masks =self ._build_action_masks (obs_t )
|
| 244 |
+
d_action ,c_action ,log_prob ,value =self .policy .get_action_and_log_prob (
|
| 245 |
+
obs_t .squeeze (0 ),masks
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
next_obs ,reward ,done ,info =self .env .step (d_action ,c_action )
|
| 249 |
+
|
| 250 |
+
self .total_episodes +=1
|
| 251 |
+
self .reward_history .append (reward )
|
| 252 |
+
|
| 253 |
+
print (
|
| 254 |
+
f"[Ep {self .total_episodes :04d}] "
|
| 255 |
+
f"reward={reward :+.4f} {reward_to_label (reward )} | "
|
| 256 |
+
f"transport={info ['transport']:5s} dest={info ['dest']:30s} | "
|
| 257 |
+
f"stable={info ['stability']:.2f} "
|
| 258 |
+
f"speed={info ['throughput_mbps']:.2f}Mbps"
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
buffer .add (obs ,d_action ,c_action ,reward ,log_prob ,value ,done )
|
| 262 |
+
obs =next_obs
|
| 263 |
+
|
| 264 |
+
if done :
|
| 265 |
+
obs =self .env .reset ()
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
if self .total_episodes %self .checkpoint_every ==0 :
|
| 269 |
+
self .save_checkpoint ()
|
| 270 |
+
|
| 271 |
+
return buffer
|
| 272 |
+
|
| 273 |
+
def update (self ,buffer :RolloutBuffer ):
|
| 274 |
+
print (f"\n[PPO] ── Update #{self .total_episodes //self .steps_per_update } ──────────────────────────")
|
| 275 |
+
print (f"[PPO] Buffer: {len (buffer .rewards )} episodes | "
|
| 276 |
+
f"mean_reward={sum (buffer .rewards )/len (buffer .rewards ):+.4f} | "
|
| 277 |
+
f"positive={sum (1 for r in buffer .rewards if r >0 )}/{len (buffer .rewards )}")
|
| 278 |
+
obs_t ,d_act_t ,c_act_t =buffer .to_tensors (self .device )
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
with torch .no_grad ():
|
| 282 |
+
last_obs =torch .FloatTensor (self .env ._build_obs ()).to (self .device )
|
| 283 |
+
_ ,_ ,_ ,last_val =self .policy .forward (last_obs .unsqueeze (0 ))
|
| 284 |
+
last_value =last_val .item ()
|
| 285 |
+
|
| 286 |
+
advantages ,returns =buffer .compute_returns (last_value ,self .gamma ,self .gae_lambda )
|
| 287 |
+
adv_t =torch .FloatTensor (advantages ).to (self .device )
|
| 288 |
+
ret_t =torch .FloatTensor (returns ).to (self .device )
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
adv_t =(adv_t -adv_t .mean ())/(adv_t .std ()+1e-8 )
|
| 292 |
+
|
| 293 |
+
old_log_probs =torch .stack (buffer .log_probs ).to (self .device ).detach ()
|
| 294 |
+
|
| 295 |
+
for _ in range (self .update_epochs ):
|
| 296 |
+
log_probs ,entropy ,values =self .policy .evaluate_actions (obs_t ,d_act_t ,c_act_t )
|
| 297 |
+
|
| 298 |
+
ratio =(log_probs -old_log_probs ).exp ()
|
| 299 |
+
surr1 =ratio *adv_t
|
| 300 |
+
surr2 =ratio .clamp (1 -self .clip_eps ,1 +self .clip_eps )*adv_t
|
| 301 |
+
|
| 302 |
+
policy_loss =-torch .min (surr1 ,surr2 ).mean ()
|
| 303 |
+
value_loss =F .mse_loss (values ,ret_t )
|
| 304 |
+
entropy_loss =-entropy .mean ()
|
| 305 |
+
|
| 306 |
+
loss =policy_loss +self .vf_coef *value_loss +self .entropy_coef *entropy_loss
|
| 307 |
+
|
| 308 |
+
self .optimizer .zero_grad ()
|
| 309 |
+
loss .backward ()
|
| 310 |
+
nn .utils .clip_grad_norm_ (self .policy .parameters (),self .max_grad_norm )
|
| 311 |
+
self .optimizer .step ()
|
| 312 |
+
|
| 313 |
+
self .scheduler .step ()
|
| 314 |
+
|
| 315 |
+
def train (self ,total_episodes :int =10000 ):
|
| 316 |
+
|
| 317 |
+
print (f"\n{'='*60 }")
|
| 318 |
+
print (f" AlphaBypass — PPO Training")
|
| 319 |
+
print (f" Target: {total_episodes } episodes")
|
| 320 |
+
print (f" Device: {self .device }")
|
| 321 |
+
print (f"{'='*60 }\n")
|
| 322 |
+
|
| 323 |
+
obs =self .env .reset ()
|
| 324 |
+
|
| 325 |
+
while self .total_episodes <total_episodes :
|
| 326 |
+
buffer =self .collect_rollout ()
|
| 327 |
+
self .update (buffer )
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
if len (self .reward_history )>=20 :
|
| 331 |
+
recent =self .reward_history [-20 :]
|
| 332 |
+
print (
|
| 333 |
+
f"\n[Stats] last 20 episodes: "
|
| 334 |
+
f"mean={np .mean (recent ):+.4f} "
|
| 335 |
+
f"max={np .max (recent ):+.4f} "
|
| 336 |
+
f"min={np .min (recent ):+.4f}\n"
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
def save_checkpoint (self ,tag :str =""):
|
| 340 |
+
path =os .path .join (
|
| 341 |
+
self .checkpoint_dir ,
|
| 342 |
+
f"checkpoint_ep{self .total_episodes :05d}{tag }.pt"
|
| 343 |
+
)
|
| 344 |
+
torch .save ({
|
| 345 |
+
"episode":self .total_episodes ,
|
| 346 |
+
"policy_state":self .policy .state_dict (),
|
| 347 |
+
"optimizer_state":self .optimizer .state_dict (),
|
| 348 |
+
"reward_history":self .reward_history ,
|
| 349 |
+
"best_reward":self .best_reward ,
|
| 350 |
+
},path )
|
| 351 |
+
print (f"[Checkpoint] saved → {path }")
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
r =np .mean (self .reward_history [-10 :])if len (self .reward_history )>=10 else -999
|
| 355 |
+
if r >self .best_reward :
|
| 356 |
+
self .best_reward =r
|
| 357 |
+
best_path =os .path .join (self .checkpoint_dir ,"best.pt")
|
| 358 |
+
torch .save (torch .load (path ),best_path )
|
| 359 |
+
print (f"[Checkpoint] 🏆 new best ({r :+.4f}) → {best_path }")
|
| 360 |
+
|
| 361 |
+
def load_checkpoint (self ,path :str ):
|
| 362 |
+
ck =torch .load (path ,map_location =self .device )
|
| 363 |
+
self .policy .load_state_dict (ck ["policy_state"])
|
| 364 |
+
self .optimizer .load_state_dict (ck ["optimizer_state"])
|
| 365 |
+
self .total_episodes =ck ["episode"]
|
| 366 |
+
self .reward_history =ck .get ("reward_history",[])
|
| 367 |
+
self .best_reward =ck .get ("best_reward",-float ("inf"))
|
| 368 |
+
print (f"[Checkpoint] loaded from ep {self .total_episodes }")
|
config_generator.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import json
|
| 3 |
+
import uuid
|
| 4 |
+
import secrets
|
| 5 |
+
import string
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from shared .schemas import VlessConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def generate_uuid ()->str :
|
| 15 |
+
return str (uuid .uuid4 ())
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def generate_short_id (length :int =8 )->str :
|
| 19 |
+
|
| 20 |
+
return secrets .token_hex (length //2 )
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def random_service_name (seed :int )->str :
|
| 24 |
+
|
| 25 |
+
rng_chars =string .ascii_lowercase +string .digits
|
| 26 |
+
r =__import__ ("random").Random (seed )
|
| 27 |
+
length =r .randint (6 ,20 )
|
| 28 |
+
return "".join (r .choice (rng_chars )for _ in range (length ))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def build_server_config (
|
| 36 |
+
cfg :VlessConfig ,
|
| 37 |
+
vless_uuid :str ,
|
| 38 |
+
private_key :str ,
|
| 39 |
+
public_key :str ,
|
| 40 |
+
listen_ip :str ="0.0.0.0",
|
| 41 |
+
)->dict :
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
stream =_build_stream_settings_server (cfg )
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
inbound ={
|
| 48 |
+
"tag":"vless-in",
|
| 49 |
+
"listen":listen_ip ,
|
| 50 |
+
"port":cfg .proxy_port ,
|
| 51 |
+
"protocol":"vless",
|
| 52 |
+
"settings":{
|
| 53 |
+
"clients":[
|
| 54 |
+
{
|
| 55 |
+
"id":vless_uuid ,
|
| 56 |
+
"flow":_pick_flow (cfg ),
|
| 57 |
+
}
|
| 58 |
+
],
|
| 59 |
+
"decryption":"none",
|
| 60 |
+
},
|
| 61 |
+
"streamSettings":stream ,
|
| 62 |
+
"sniffing":{
|
| 63 |
+
"enabled":True ,
|
| 64 |
+
"destOverride":["http","tls","quic"],
|
| 65 |
+
},
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
inbound ["streamSettings"]["realitySettings"]={
|
| 70 |
+
"show":False ,
|
| 71 |
+
"dest":f"{cfg .dest_domain }:443",
|
| 72 |
+
"xver":0 ,
|
| 73 |
+
"serverNames":[cfg .dest_domain ],
|
| 74 |
+
"privateKey":private_key ,
|
| 75 |
+
"shortIds":[cfg .short_id ],
|
| 76 |
+
"spiderX":cfg .spider_x ,
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
config ={
|
| 80 |
+
"log":{"loglevel":"warning"},
|
| 81 |
+
"inbounds":[inbound ],
|
| 82 |
+
"outbounds":[
|
| 83 |
+
{"tag":"direct","protocol":"freedom"},
|
| 84 |
+
{"tag":"block","protocol":"blackhole"},
|
| 85 |
+
],
|
| 86 |
+
"routing":{
|
| 87 |
+
"rules":[
|
| 88 |
+
{"type":"field","ip":["geoip:private"],"outboundTag":"block"},
|
| 89 |
+
]
|
| 90 |
+
},
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
return config
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def _pick_flow (cfg :VlessConfig )->str :
|
| 97 |
+
if cfg .transport_type =="tcp":
|
| 98 |
+
return "xtls-rprx-vision"
|
| 99 |
+
return ""
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _build_stream_settings_server (cfg :VlessConfig )->dict :
|
| 103 |
+
base ={
|
| 104 |
+
"network":cfg .transport_type ,
|
| 105 |
+
"security":"reality",
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
if cfg .transport_type =="grpc":
|
| 109 |
+
base ["grpcSettings"]={
|
| 110 |
+
"serviceName":cfg .grpc_service_name ,
|
| 111 |
+
"multiMode":False ,
|
| 112 |
+
}
|
| 113 |
+
elif cfg .transport_type =="xhttp":
|
| 114 |
+
base ["xhttpSettings"]={
|
| 115 |
+
"mode":cfg .xhttp_mode ,
|
| 116 |
+
"path":cfg .spider_x or "/",
|
| 117 |
+
"host":cfg .dest_domain ,
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
return base
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def build_client_config (
|
| 128 |
+
cfg :VlessConfig ,
|
| 129 |
+
vless_uuid :str ,
|
| 130 |
+
server_ip :str ,
|
| 131 |
+
public_key :str ,
|
| 132 |
+
socks_port :int =10808 ,
|
| 133 |
+
http_port :int =10809 ,
|
| 134 |
+
)->dict :
|
| 135 |
+
|
| 136 |
+
stream =_build_stream_settings_client (cfg ,public_key )
|
| 137 |
+
|
| 138 |
+
outbound ={
|
| 139 |
+
"tag":"proxy",
|
| 140 |
+
"protocol":"vless",
|
| 141 |
+
"settings":{
|
| 142 |
+
"vnext":[
|
| 143 |
+
{
|
| 144 |
+
"address":server_ip ,
|
| 145 |
+
"port":cfg .proxy_port ,
|
| 146 |
+
"users":[
|
| 147 |
+
{
|
| 148 |
+
"id":vless_uuid ,
|
| 149 |
+
"encryption":"none",
|
| 150 |
+
"flow":_pick_flow (cfg ),
|
| 151 |
+
}
|
| 152 |
+
],
|
| 153 |
+
}
|
| 154 |
+
]
|
| 155 |
+
},
|
| 156 |
+
"streamSettings":stream ,
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
if cfg .mux_concurrency >0 :
|
| 161 |
+
outbound ["mux"]={
|
| 162 |
+
"enabled":True ,
|
| 163 |
+
"concurrency":cfg .mux_concurrency ,
|
| 164 |
+
"xudpConcurrency":cfg .mux_concurrency ,
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if cfg .transport_type =="tcp"and cfg .fragment_strategy !="none":
|
| 169 |
+
outbound ["streamSettings"]["sockopt"]={
|
| 170 |
+
"dialerProxy":"fragment",
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
config ={
|
| 174 |
+
"log":{"loglevel":"warning"},
|
| 175 |
+
"inbounds":[
|
| 176 |
+
{
|
| 177 |
+
"tag":"socks",
|
| 178 |
+
"listen":"127.0.0.1",
|
| 179 |
+
"port":socks_port ,
|
| 180 |
+
"protocol":"socks",
|
| 181 |
+
"settings":{"auth":"noauth","udp":True },
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"tag":"http",
|
| 185 |
+
"listen":"127.0.0.1",
|
| 186 |
+
"port":http_port ,
|
| 187 |
+
"protocol":"http",
|
| 188 |
+
},
|
| 189 |
+
],
|
| 190 |
+
"outbounds":[outbound ,{"tag":"direct","protocol":"freedom"}],
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if cfg .transport_type =="tcp"and cfg .fragment_strategy !="none":
|
| 195 |
+
config ["outbounds"].append (_build_fragment_outbound (cfg ))
|
| 196 |
+
|
| 197 |
+
return config
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _build_stream_settings_client (cfg :VlessConfig ,public_key :str )->dict :
|
| 201 |
+
base :dict ={
|
| 202 |
+
"network":cfg .transport_type ,
|
| 203 |
+
"security":"reality",
|
| 204 |
+
"realitySettings":{
|
| 205 |
+
"fingerprint":cfg .fingerprint ,
|
| 206 |
+
"serverName":cfg .dest_domain ,
|
| 207 |
+
"publicKey":public_key ,
|
| 208 |
+
"shortId":cfg .short_id ,
|
| 209 |
+
"spiderX":cfg .spider_x ,
|
| 210 |
+
},
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
if cfg .alpn :
|
| 214 |
+
base ["realitySettings"]["alpn"]=cfg .alpn
|
| 215 |
+
|
| 216 |
+
if cfg .transport_type =="grpc":
|
| 217 |
+
base ["grpcSettings"]={
|
| 218 |
+
"serviceName":cfg .grpc_service_name ,
|
| 219 |
+
}
|
| 220 |
+
elif cfg .transport_type =="xhttp":
|
| 221 |
+
headers ={"Host":cfg .dest_domain }
|
| 222 |
+
headers .update (cfg .extra_headers )
|
| 223 |
+
base ["xhttpSettings"]={
|
| 224 |
+
"mode":cfg .xhttp_mode ,
|
| 225 |
+
"path":cfg .spider_x or "/",
|
| 226 |
+
"headers":headers ,
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
return base
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _build_fragment_outbound (cfg :VlessConfig )->dict :
|
| 233 |
+
return {
|
| 234 |
+
"tag":"fragment",
|
| 235 |
+
"protocol":"freedom",
|
| 236 |
+
"settings":{
|
| 237 |
+
"domainStrategy":"AsIs",
|
| 238 |
+
},
|
| 239 |
+
"streamSettings":{
|
| 240 |
+
"sockopt":{
|
| 241 |
+
"dialerProxy":"",
|
| 242 |
+
"tcpKeepAliveInterval":0 ,
|
| 243 |
+
}
|
| 244 |
+
},
|
| 245 |
+
"fragment":{
|
| 246 |
+
"packets":cfg .fragment_strategy ,
|
| 247 |
+
"length":f"{cfg .fragment_length_min }-{cfg .fragment_length_max }",
|
| 248 |
+
"interval":f"{cfg .fragment_interval_min }-{cfg .fragment_interval_max }",
|
| 249 |
+
},
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def generate_reality_keys ()->tuple [str ,str ]:
|
| 258 |
+
|
| 259 |
+
import subprocess
|
| 260 |
+
result =subprocess .run (
|
| 261 |
+
["xray","x25519"],
|
| 262 |
+
capture_output =True ,text =True ,timeout =10
|
| 263 |
+
)
|
| 264 |
+
lines =result .stdout .strip ().splitlines ()
|
| 265 |
+
priv =lines [0 ].split (": ")[1 ].strip ()
|
| 266 |
+
pub =lines [1 ].split (": ")[1 ].strip ()
|
| 267 |
+
return priv ,pub
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def save_config (config :dict ,path :str ):
|
| 271 |
+
with open (path ,"w")as f :
|
| 272 |
+
json .dump (config ,f ,indent =2 )
|
crypto.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import hmac
|
| 3 |
+
import hashlib
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import base64
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def sign_payload (payload :dict ,secret :str )->Tuple [str ,str ]:
|
| 15 |
+
|
| 16 |
+
payload ["_ts"]=int (time .time ())
|
| 17 |
+
body =json .dumps (payload ,separators =(",",":"),sort_keys =True )
|
| 18 |
+
sig =hmac .new (
|
| 19 |
+
secret .encode (),
|
| 20 |
+
body .encode (),
|
| 21 |
+
hashlib .sha256
|
| 22 |
+
).hexdigest ()
|
| 23 |
+
return body ,sig
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def verify_payload (body :str ,sig :str ,secret :str ,max_age_seconds :int =300 )->dict :
|
| 27 |
+
|
| 28 |
+
expected =hmac .new (
|
| 29 |
+
secret .encode (),
|
| 30 |
+
body .encode (),
|
| 31 |
+
hashlib .sha256
|
| 32 |
+
).hexdigest ()
|
| 33 |
+
|
| 34 |
+
if not hmac .compare_digest (expected ,sig ):
|
| 35 |
+
raise ValueError ("Invalid HMAC signature")
|
| 36 |
+
|
| 37 |
+
data =json .loads (body )
|
| 38 |
+
ts =data .get ("_ts",0 )
|
| 39 |
+
if abs (time .time ()-ts )>max_age_seconds :
|
| 40 |
+
raise ValueError (f"Stale request: {abs (time .time ()-ts ):.0f}s old")
|
| 41 |
+
|
| 42 |
+
return data
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def generate_self_signed_cert (cert_path :str ,key_path :str ,cn :str ="localhost"):
|
| 50 |
+
|
| 51 |
+
from cryptography import x509
|
| 52 |
+
from cryptography .x509 .oid import NameOID
|
| 53 |
+
from cryptography .hazmat .primitives import hashes ,serialization
|
| 54 |
+
from cryptography .hazmat .primitives .asymmetric import rsa
|
| 55 |
+
from cryptography .hazmat .backends import default_backend
|
| 56 |
+
import datetime
|
| 57 |
+
|
| 58 |
+
key =rsa .generate_private_key (
|
| 59 |
+
public_exponent =65537 ,
|
| 60 |
+
key_size =2048 ,
|
| 61 |
+
backend =default_backend ()
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
subject =issuer =x509 .Name ([
|
| 65 |
+
x509 .NameAttribute (NameOID .COMMON_NAME ,cn ),
|
| 66 |
+
])
|
| 67 |
+
|
| 68 |
+
cert =(
|
| 69 |
+
x509 .CertificateBuilder ()
|
| 70 |
+
.subject_name (subject )
|
| 71 |
+
.issuer_name (issuer )
|
| 72 |
+
.public_key (key .public_key ())
|
| 73 |
+
.serial_number (x509 .random_serial_number ())
|
| 74 |
+
.not_valid_before (datetime .datetime .utcnow ())
|
| 75 |
+
.not_valid_after (datetime .datetime .utcnow ()+datetime .timedelta (days =3650 ))
|
| 76 |
+
.add_extension (
|
| 77 |
+
x509 .SubjectAlternativeName ([x509 .DNSName (cn )]),
|
| 78 |
+
critical =False ,
|
| 79 |
+
)
|
| 80 |
+
.sign (key ,hashes .SHA256 (),default_backend ())
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
with open (cert_path ,"wb")as f :
|
| 84 |
+
f .write (cert .public_bytes (serialization .Encoding .PEM ))
|
| 85 |
+
|
| 86 |
+
with open (key_path ,"wb")as f :
|
| 87 |
+
f .write (key .private_bytes (
|
| 88 |
+
serialization .Encoding .PEM ,
|
| 89 |
+
serialization .PrivateFormat .TraditionalOpenSSL ,
|
| 90 |
+
serialization .NoEncryption ()
|
| 91 |
+
))
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def load_or_create_secret (path :str ="shared_secret.key")->str :
|
| 99 |
+
if os .path .exists (path ):
|
| 100 |
+
with open (path ,"r")as f :
|
| 101 |
+
return f .read ().strip ()
|
| 102 |
+
secret =base64 .urlsafe_b64encode (os .urandom (32 )).decode ()
|
| 103 |
+
with open (path ,"w")as f :
|
| 104 |
+
f .write (secret )
|
| 105 |
+
return secret
|
environment.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Optional ,Tuple
|
| 9 |
+
|
| 10 |
+
from shared .schemas import (
|
| 11 |
+
VlessConfig ,
|
| 12 |
+
EpisodeMetrics ,
|
| 13 |
+
TRANSPORT_TYPES ,
|
| 14 |
+
CANDIDATE_PORTS ,
|
| 15 |
+
SNI_DOMAINS ,
|
| 16 |
+
FINGERPRINTS ,
|
| 17 |
+
ALPN_OPTIONS ,
|
| 18 |
+
FRAGMENT_STRATEGIES ,
|
| 19 |
+
MUX_CONCURRENCY_VALUES ,
|
| 20 |
+
SHORT_ID_LENGTHS ,
|
| 21 |
+
XHTTP_MODES ,
|
| 22 |
+
)
|
| 23 |
+
from server .rl .reward import compute_reward
|
| 24 |
+
|
| 25 |
+
HISTORY_LEN =10
|
| 26 |
+
OBS_PER_EPISODE =7
|
| 27 |
+
|
| 28 |
+
TOTAL_OBS_DIM =HISTORY_LEN *OBS_PER_EPISODE +5
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
DISCRETE_NVEC =[
|
| 35 |
+
len (TRANSPORT_TYPES ),
|
| 36 |
+
len (CANDIDATE_PORTS ),
|
| 37 |
+
len (SNI_DOMAINS ),
|
| 38 |
+
len (FINGERPRINTS ),
|
| 39 |
+
len (ALPN_OPTIONS ),
|
| 40 |
+
len (FRAGMENT_STRATEGIES ),
|
| 41 |
+
len (MUX_CONCURRENCY_VALUES ),
|
| 42 |
+
len (SHORT_ID_LENGTHS ),
|
| 43 |
+
len (XHTTP_MODES ),
|
| 44 |
+
100 ,
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
N_CONTINUOUS =5
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def decode_action (discrete :np .ndarray ,continuous :np .ndarray )->VlessConfig :
|
| 51 |
+
import secrets as _sec
|
| 52 |
+
|
| 53 |
+
transport =TRANSPORT_TYPES [int (discrete [0 ])]
|
| 54 |
+
port =CANDIDATE_PORTS [int (discrete [1 ])]
|
| 55 |
+
dest =SNI_DOMAINS [int (discrete [2 ])]
|
| 56 |
+
fingerprint =FINGERPRINTS [int (discrete [3 ])]
|
| 57 |
+
alpn =ALPN_OPTIONS [int (discrete [4 ])]
|
| 58 |
+
frag_strat =FRAGMENT_STRATEGIES [int (discrete [5 ])]
|
| 59 |
+
mux_conc =MUX_CONCURRENCY_VALUES [int (discrete [6 ])]
|
| 60 |
+
sid_len =SHORT_ID_LENGTHS [int (discrete [7 ])]
|
| 61 |
+
xhttp_mode =XHTTP_MODES [int (discrete [8 ])]
|
| 62 |
+
grpc_seed =int (discrete [9 ])
|
| 63 |
+
|
| 64 |
+
c =continuous
|
| 65 |
+
frag_len_min =int (10 +c [0 ]*190 )
|
| 66 |
+
frag_len_max =frag_len_min +int (c [1 ]*100 )
|
| 67 |
+
frag_interval_min =int (c [2 ]*50 )
|
| 68 |
+
frag_interval_max =frag_interval_min +5
|
| 69 |
+
padding_min =int (c [3 ]*500 )
|
| 70 |
+
padding_max =padding_min +int (c [4 ]*500 )
|
| 71 |
+
|
| 72 |
+
from server .config_generator import random_service_name ,generate_short_id
|
| 73 |
+
grpc_name =random_service_name (grpc_seed )
|
| 74 |
+
short_id =generate_short_id (sid_len )
|
| 75 |
+
|
| 76 |
+
if transport !="tcp":
|
| 77 |
+
frag_strat ="none"
|
| 78 |
+
|
| 79 |
+
padding_enabled =(transport =="tcp")and (padding_min >0 )
|
| 80 |
+
|
| 81 |
+
return VlessConfig (
|
| 82 |
+
transport_type =transport ,
|
| 83 |
+
proxy_port =port ,
|
| 84 |
+
dest_domain =dest ,
|
| 85 |
+
short_id =short_id ,
|
| 86 |
+
spider_x ="/",
|
| 87 |
+
fingerprint =fingerprint ,
|
| 88 |
+
alpn =alpn ,
|
| 89 |
+
grpc_service_name =grpc_name ,
|
| 90 |
+
xhttp_mode =xhttp_mode ,
|
| 91 |
+
fragment_strategy =frag_strat ,
|
| 92 |
+
fragment_length_min =frag_len_min ,
|
| 93 |
+
fragment_length_max =frag_len_max ,
|
| 94 |
+
fragment_interval_min =frag_interval_min ,
|
| 95 |
+
fragment_interval_max =frag_interval_max ,
|
| 96 |
+
padding_enabled =padding_enabled ,
|
| 97 |
+
padding_min =padding_min ,
|
| 98 |
+
padding_max =padding_max ,
|
| 99 |
+
mux_concurrency =mux_conc ,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def metrics_to_obs_vector (m :EpisodeMetrics )->np .ndarray :
|
| 104 |
+
return np .array ([
|
| 105 |
+
float (m .connected ),
|
| 106 |
+
min (m .stability_ratio ,1.0 ),
|
| 107 |
+
min (m .throughput_ratio ,1.0 ),
|
| 108 |
+
max (0.0 ,1.0 -m .avg_ping_ms /1000.0 ),
|
| 109 |
+
1.0 -min (m .packet_loss_ratio ,1.0 ),
|
| 110 |
+
max (0.0 ,1.0 -m .connect_time_ms /5000.0 ),
|
| 111 |
+
max (0.0 ,1.0 -m .reconnect_count /5.0 ),
|
| 112 |
+
],dtype =np .float32 )
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class AlphaBypassEnv :
|
| 116 |
+
def __init__ (
|
| 117 |
+
self ,
|
| 118 |
+
bridge ,
|
| 119 |
+
episode_duration :int =90 ,
|
| 120 |
+
baseline_mbps :float =1.0 ,
|
| 121 |
+
max_steps :int =0 ,
|
| 122 |
+
fail_streak_warn :int =10 ,
|
| 123 |
+
):
|
| 124 |
+
self .bridge =bridge
|
| 125 |
+
self .episode_duration =episode_duration
|
| 126 |
+
self .baseline_mbps =baseline_mbps
|
| 127 |
+
self .max_steps =max_steps
|
| 128 |
+
self .fail_streak_warn =fail_streak_warn
|
| 129 |
+
|
| 130 |
+
self .history :list [EpisodeMetrics ]=[]
|
| 131 |
+
self .step_count :int =0
|
| 132 |
+
self ._fail_streak :int =0
|
| 133 |
+
|
| 134 |
+
self .obs_dim =TOTAL_OBS_DIM
|
| 135 |
+
self .discrete_nvec =DISCRETE_NVEC
|
| 136 |
+
self .n_continuous =N_CONTINUOUS
|
| 137 |
+
|
| 138 |
+
def _build_obs (self )->np .ndarray :
|
| 139 |
+
obs =np .zeros (self .obs_dim ,dtype =np .float32 )
|
| 140 |
+
|
| 141 |
+
relevant =self .history [-HISTORY_LEN :]
|
| 142 |
+
for i ,m in enumerate (reversed (relevant )):
|
| 143 |
+
start =i *OBS_PER_EPISODE
|
| 144 |
+
obs [start :start +OBS_PER_EPISODE ]=metrics_to_obs_vector (m )
|
| 145 |
+
|
| 146 |
+
base =HISTORY_LEN *OBS_PER_EPISODE
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
t =time .localtime ()
|
| 150 |
+
hour =t .tm_hour +t .tm_min /60.0
|
| 151 |
+
obs [base ]=min (self .step_count /1000.0 ,1.0 )
|
| 152 |
+
obs [base +1 ]=math .sin (2 *math .pi *hour /24 )
|
| 153 |
+
obs [base +2 ]=math .cos (2 *math .pi *hour /24 )
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
dow =t .tm_wday
|
| 157 |
+
obs [base +3 ]=math .sin (2 *math .pi *dow /7 )
|
| 158 |
+
obs [base +4 ]=math .cos (2 *math .pi *dow /7 )
|
| 159 |
+
|
| 160 |
+
return obs
|
| 161 |
+
|
| 162 |
+
def reset (self )->np .ndarray :
|
| 163 |
+
self .history =[]
|
| 164 |
+
self .step_count =0
|
| 165 |
+
self ._fail_streak =0
|
| 166 |
+
return self ._build_obs ()
|
| 167 |
+
|
| 168 |
+
def step (
|
| 169 |
+
self ,
|
| 170 |
+
discrete_action :np .ndarray ,
|
| 171 |
+
continuous_action :np .ndarray ,
|
| 172 |
+
)->Tuple [np .ndarray ,float ,bool ,dict ]:
|
| 173 |
+
cfg =decode_action (discrete_action ,continuous_action )
|
| 174 |
+
episode_id =str (uuid .uuid4 ())[:8 ]
|
| 175 |
+
|
| 176 |
+
metrics =self .bridge .run_episode (
|
| 177 |
+
cfg =cfg ,
|
| 178 |
+
episode_id =episode_id ,
|
| 179 |
+
duration =self .episode_duration ,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
reward =compute_reward (metrics ,self .baseline_mbps )
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
self .bridge .report_reward (episode_id ,reward )
|
| 186 |
+
self .history .append (metrics )
|
| 187 |
+
self .step_count +=1
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if not metrics .connected :
|
| 191 |
+
self ._fail_streak +=1
|
| 192 |
+
if self ._fail_streak ==self .fail_streak_warn :
|
| 193 |
+
print (
|
| 194 |
+
f"\n⚠️ [Degradation] {self ._fail_streak } FAIL подряд! "
|
| 195 |
+
f"Возможно РКН изменил политику или проблема с сетью."
|
| 196 |
+
)
|
| 197 |
+
else :
|
| 198 |
+
self ._fail_streak =0
|
| 199 |
+
|
| 200 |
+
done =(self .max_steps >0 and self .step_count >=self .max_steps )
|
| 201 |
+
obs =self ._build_obs ()
|
| 202 |
+
|
| 203 |
+
info ={
|
| 204 |
+
"episode_id":episode_id ,
|
| 205 |
+
"reward":reward ,
|
| 206 |
+
"connected":metrics .connected ,
|
| 207 |
+
"stability":metrics .stability_ratio ,
|
| 208 |
+
"throughput_mbps":metrics .throughput_mbps ,
|
| 209 |
+
"transport":cfg .transport_type ,
|
| 210 |
+
"dest":cfg .dest_domain ,
|
| 211 |
+
"fail_streak":self ._fail_streak ,
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
return obs ,reward ,done ,info
|
reward.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from shared .schemas import EpisodeMetrics
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def compute_reward (metrics :EpisodeMetrics ,baseline_mbps :float =1.0 )->float :
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
if not metrics .connected :
|
| 9 |
+
return -1.0
|
| 10 |
+
|
| 11 |
+
r =0.0
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
connect_score =max (0.0 ,1.0 -metrics .connect_time_ms /5000.0 )
|
| 16 |
+
|
| 17 |
+
loss_score =1.0 -metrics .packet_loss_ratio
|
| 18 |
+
|
| 19 |
+
ping_score =max (0.0 ,1.0 -metrics .avg_ping_ms /1000.0 )
|
| 20 |
+
|
| 21 |
+
connection_component =(connect_score *0.3 +loss_score *0.4 +ping_score *0.3 )
|
| 22 |
+
r +=0.50 *connection_component
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
stability_score =metrics .stability_ratio
|
| 27 |
+
|
| 28 |
+
reconnect_penalty =max (0.0 ,1.0 -metrics .reconnect_count /5.0 )
|
| 29 |
+
stability_component =stability_score *0.7 +reconnect_penalty *0.3
|
| 30 |
+
r +=0.35 *stability_component
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
safe_baseline =max (baseline_mbps ,0.1 )
|
| 34 |
+
speed_ratio =min (metrics .throughput_mbps /safe_baseline ,1.0 )
|
| 35 |
+
|
| 36 |
+
import math
|
| 37 |
+
speed_score =math .log1p (speed_ratio *9 )/math .log1p (9 )
|
| 38 |
+
r +=0.15 *speed_score
|
| 39 |
+
|
| 40 |
+
return round (r ,4 )
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def reward_to_label (r :float )->str :
|
| 44 |
+
if r <-0.5 :
|
| 45 |
+
return "💀 FAIL"
|
| 46 |
+
if r <0.0 :
|
| 47 |
+
return "❌ bad"
|
| 48 |
+
if r <0.3 :
|
| 49 |
+
return "⚠️ weak"
|
| 50 |
+
if r <0.6 :
|
| 51 |
+
return "🟡 ok"
|
| 52 |
+
if r <0.8 :
|
| 53 |
+
return "🟢 good"
|
| 54 |
+
return "🏆 great"
|
schemas.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
from dataclasses import dataclass ,field ,asdict
|
| 4 |
+
from typing import Optional ,List ,Dict ,Any
|
| 5 |
+
import json
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
TRANSPORT_TYPES =["tcp","grpc"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
CANDIDATE_PORTS =[443 ,80 ,8443 ,2053 ,2083 ,2087 ,9443 ]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
SNI_DOMAINS =[
|
| 20 |
+
|
| 21 |
+
"download.nvidia.com",
|
| 22 |
+
"swscan.apple.com",
|
| 23 |
+
"updates.cdn-apple.com",
|
| 24 |
+
"steamcdn-a.akamaihd.net",
|
| 25 |
+
"dl.delivery.mp.microsoft.com",
|
| 26 |
+
"download.windowsupdate.com",
|
| 27 |
+
"cdn.cloudflare.steamstatic.com",
|
| 28 |
+
"origin-a.akamaihd.net",
|
| 29 |
+
"pkg-containers.githubusercontent.com",
|
| 30 |
+
"download.jetbrains.com",
|
| 31 |
+
"packages.ubuntu.com",
|
| 32 |
+
|
| 33 |
+
"ajax.aspnetcdn.com",
|
| 34 |
+
"github-releases.githubusercontent.com",
|
| 35 |
+
"objects.githubusercontent.com",
|
| 36 |
+
"software.download.prss.microsoft.com",
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
]
|
| 41 |
+
|
| 42 |
+
FINGERPRINTS =["chrome","firefox","edge","safari","ios","random","randomized"]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
ALPN_OPTIONS =[
|
| 46 |
+
["h2","http/1.1"],
|
| 47 |
+
["h2"],
|
| 48 |
+
["http/1.1"],
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
FRAGMENT_STRATEGIES =["none","tlshello","all"]
|
| 52 |
+
|
| 53 |
+
MUX_CONCURRENCY_VALUES =[0 ,1 ,2 ,4 ,8 ,16 ,32 ]
|
| 54 |
+
|
| 55 |
+
SHORT_ID_LENGTHS =[4 ,8 ,16 ]
|
| 56 |
+
|
| 57 |
+
XHTTP_MODES =["packet-up","streaming"]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
|
| 65 |
+
class VlessConfig :
|
| 66 |
+
|
| 67 |
+
transport_type :str ="tcp"
|
| 68 |
+
proxy_port :int =443
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
dest_domain :str ="download.nvidia.com"
|
| 72 |
+
short_id :str ="abcdef01"
|
| 73 |
+
spider_x :str ="/"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
fingerprint :str ="chrome"
|
| 77 |
+
alpn :List [str ]=field (default_factory =lambda :["h2","http/1.1"])
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
grpc_service_name :str ="grpc"
|
| 81 |
+
xhttp_mode :str ="packet-up"
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
fragment_strategy :str ="none"
|
| 85 |
+
fragment_length_min :int =50
|
| 86 |
+
fragment_length_max :int =100
|
| 87 |
+
fragment_interval_min :int =1
|
| 88 |
+
fragment_interval_max :int =5
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
padding_enabled :bool =False
|
| 92 |
+
padding_min :int =0
|
| 93 |
+
padding_max :int =0
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
mux_concurrency :int =0
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
extra_headers :Dict [str ,str ]=field (default_factory =dict )
|
| 100 |
+
|
| 101 |
+
def to_dict (self )->dict :
|
| 102 |
+
return asdict (self )
|
| 103 |
+
|
| 104 |
+
@classmethod
|
| 105 |
+
def from_dict (cls ,d :dict )->"VlessConfig":
|
| 106 |
+
return cls (**d )
|
| 107 |
+
|
| 108 |
+
def to_json (self )->str :
|
| 109 |
+
return json .dumps (self .to_dict ())
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def from_json (cls ,s :str )->"VlessConfig":
|
| 113 |
+
return cls .from_dict (json .loads (s ))
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@dataclass
|
| 121 |
+
class EpisodeMetrics :
|
| 122 |
+
episode_id :str =""
|
| 123 |
+
timestamp :float =field (default_factory =time .time )
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
connected :bool =False
|
| 127 |
+
connect_time_ms :float =0.0
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
stability_ratio :float =0.0
|
| 131 |
+
reconnect_count :int =0
|
| 132 |
+
drop_count :int =0
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
throughput_mbps :float =0.0
|
| 136 |
+
throughput_ratio :float =0.0
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
avg_ping_ms :float =0.0
|
| 140 |
+
max_ping_ms :float =0.0
|
| 141 |
+
packet_loss_ratio :float =0.0
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
error_message :Optional [str ]=None
|
| 145 |
+
samples :int =0
|
| 146 |
+
|
| 147 |
+
def to_dict (self )->dict :
|
| 148 |
+
return asdict (self )
|
| 149 |
+
|
| 150 |
+
@classmethod
|
| 151 |
+
def from_dict (cls ,d :dict )->"EpisodeMetrics":
|
| 152 |
+
return cls (**d )
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
@dataclass
|
| 160 |
+
class EpisodeCommand :
|
| 161 |
+
episode_id :str =""
|
| 162 |
+
config :Optional [dict ]=None
|
| 163 |
+
duration_seconds :int =90
|
| 164 |
+
server_ip :str =""
|
| 165 |
+
server_port :int =443
|
| 166 |
+
uuid :str =""
|
| 167 |
+
|
| 168 |
+
def to_dict (self )->dict :
|
| 169 |
+
return asdict (self )
|
| 170 |
+
|
| 171 |
+
@classmethod
|
| 172 |
+
def from_dict (cls ,d :dict )->"EpisodeCommand":
|
| 173 |
+
obj =cls (**{k :v for k ,v in d .items ()if k !="config"})
|
| 174 |
+
obj .config =d .get ("config")
|
| 175 |
+
return obj
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclass
|
| 179 |
+
class ClientStatus :
|
| 180 |
+
episode_id :str =""
|
| 181 |
+
partial_metrics :Optional [dict ]=None
|
| 182 |
+
phase :str ="idle"
|