// Reboot and rollback logic // /data/update does not exist? => MONITOR // CHECK: // Connection to tcp!ipv4.google.com!80 unsuccessful? => FAIL // Connection to tcp!ipv6.google.com!80 unsuccessful? => FAIL // No connection from updater? => FAIL // => MONITOR // FAIL: // If 5m passed since first CHECK? => TIMEOUT // Wait 30s => CHECK // TIMEOUT: // Trigger rollback (can use /data/admind.passwd) => REBOOT // MONITOR: // Connection to tcp!ipv4.google.com!80 unsuccessful? => DROPOUT // Connection to tcp!ipv6.google.com!80 unsuccessful? => DROPOUT // Wait 5m => MONITOR // DROPOUT: // If DROPOUT for 1h? => REBOOT // Wait 5m => MONITOR use std::fs; use std::io::{self, Read, Write}; use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream}; use std::time::{Duration, Instant}; use nix::sys::signal::Signal; use nix::unistd::Pid; const UPDATE_FILE: &str = "/data/update"; const LISTEN_SOCKET: &str = "[::]:12808"; const MAGIC: [u8; 4] = [0x32, 0x7f, 0xfe, 0x4c]; const RESP_OK: [u8; 5] = [0x32, 0x7f, 0xfe, 0x4c, 0x00]; const RESP_NORMAL: [u8; 5] = [0x32, 0x7f, 0xfe, 0x4c, 0x01]; const CHECK_INTERVAL: Duration = Duration::from_secs(30); const CHECK_TIMEOUT: Duration = Duration::from_secs(600); const MONITOR_INTERVAL: Duration = Duration::from_secs(600); const MONITOR_TIMEOUT: Duration = Duration::from_secs(3600); const TCP_TIMEOUT: Duration = Duration::from_secs(8); const POLL_INTERVAL: Duration = Duration::from_millis(500); const PING_V4: &str = "1.1.1.1:80"; const PING_V6: &str = "[2606:4700:4700::1111]:80"; fn main() { println!("[info] init"); match run() { Ok(_) => eprintln!("[warn] logic terminated unexpectedly"), Err(e) => eprintln!("[warn] {}", e), } } fn run() -> io::Result<()> { let ln = TcpListener::bind(LISTEN_SOCKET)?; ln.set_nonblocking(true)?; if fs::exists(UPDATE_FILE)? { eprintln!("[info] update detected"); check_rollback(&ln)?; eprintln!("[info] no rollback needed"); } else { println!("[info] no update, skipping rollback check"); } monitor(&ln) } fn check_rollback(ln: &TcpListener) -> io::Result<()> { let mut outbound_healthy_v4 = false; let mut outbound_healthy_v6 = false; let mut inbound_healthy = false; let t_start = Instant::now(); let mut t = t_start; loop { match ln.accept() { Ok((conn, raddr)) => match handle_conn_check(conn, raddr, &mut inbound_healthy) { Ok(_) => {} Err(e) => eprintln!("[warn] handle (rollback) {}: {}", raddr, e), }, Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} Err(e) => return Err(e), } let now = Instant::now(); if now.duration_since(t) >= CHECK_INTERVAL { check_connectivity(&mut outbound_healthy_v4, &mut outbound_healthy_v6)?; t = now; } if outbound_healthy_v4 && outbound_healthy_v6 && inbound_healthy { break; } if now.duration_since(t_start) >= CHECK_TIMEOUT { eprintln!( "rollback, IPv4: {}, IPv6: {}, inbound: {}", if outbound_healthy_v4 { "OK" } else { "ERR" }, if outbound_healthy_v6 { "OK" } else { "ERR" }, if inbound_healthy { "OK" } else { "ERR" } ); return rollback(); } std::thread::sleep(POLL_INTERVAL); } Ok(()) } fn monitor(ln: &TcpListener) -> io::Result<()> { let mut t_healthy = Instant::now(); let mut t = t_healthy; loop { match ln.accept() { Ok((conn, raddr)) => match handle_conn_monitor(conn, raddr) { Ok(_) => {} Err(e) => eprintln!("[warn] handle (monitor) {}: {}", raddr, e), }, Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} Err(e) => return Err(e), } let mut outbound_healthy_v4 = false; let mut outbound_healthy_v6 = false; let now = Instant::now(); if now.duration_since(t) >= MONITOR_INTERVAL { check_connectivity(&mut outbound_healthy_v4, &mut outbound_healthy_v6)?; t = now; } if outbound_healthy_v4 && outbound_healthy_v6 { t_healthy = now; continue; } if now.duration_since(t_healthy) >= MONITOR_TIMEOUT { return reboot(); } std::thread::sleep(POLL_INTERVAL); } } fn handle_conn_check( mut conn: TcpStream, raddr: SocketAddr, inbound_healthy: &mut bool, ) -> io::Result<()> { conn.set_read_timeout(Some(TCP_TIMEOUT))?; conn.set_write_timeout(Some(TCP_TIMEOUT))?; let mut buf = [0; 4]; loop { match conn.read_exact(&mut buf) { Ok(_) => break, Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} Err(e) => return Err(e), } } if buf != MAGIC { eprintln!("[warn] handle {raddr}: bad magic {buf:?}"); return Ok(()); } *inbound_healthy = true; conn.write_all(&RESP_OK)?; conn.shutdown(Shutdown::Both)?; println!("[info] inbound: {raddr}"); Ok(()) } fn handle_conn_monitor(mut conn: TcpStream, raddr: SocketAddr) -> io::Result<()> { conn.set_read_timeout(Some(TCP_TIMEOUT))?; conn.set_write_timeout(Some(TCP_TIMEOUT))?; let mut buf = [0; 4]; loop { match conn.read_exact(&mut buf) { Ok(_) => break, Err(e) if e.kind() == io::ErrorKind::WouldBlock => {} Err(e) => return Err(e), } } if buf != MAGIC { eprintln!("[warn] handle {raddr}: bad magic {buf:?}"); return Ok(()); } conn.write_all(&RESP_NORMAL)?; conn.shutdown(Shutdown::Both)?; println!("[info] redundant inbound: {raddr}"); Ok(()) } fn check_connectivity( outbound_healthy_v4: &mut bool, outbound_healthy_v6: &mut bool, ) -> io::Result<()> { let mut buf = [0; 1024]; let conn4 = match TcpStream::connect_timeout(&PING_V4.parse().expect("PING_V4 invalid"), TCP_TIMEOUT) { Ok(conn) => Some(conn), Err(e) => { eprintln!("[warn] IPv4: connect: {}", e); *outbound_healthy_v4 = false; None } }; if let Some(mut conn4) = conn4 { conn4.set_read_timeout(Some(TCP_TIMEOUT))?; conn4.set_write_timeout(Some(TCP_TIMEOUT))?; conn4.write_all(b"GET / HTTP/1.1\n\n").ok(); loop { match conn4.read(&mut buf) { Ok(0) => { eprintln!("[warn] IPv4: connection closed"); *outbound_healthy_v4 = false; } Ok(_) => { *outbound_healthy_v4 = true; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue, Err(e) => { eprintln!("[warn] IPv4: read: {}", e); *outbound_healthy_v4 = false; } } break; } conn4.shutdown(Shutdown::Both).ok(); } let conn6 = match TcpStream::connect_timeout(&PING_V6.parse().expect("PING_V6 invalid"), TCP_TIMEOUT) { Ok(conn) => Some(conn), Err(e) => { eprintln!("[warn] IPv6: connect: {}", e); *outbound_healthy_v6 = false; None } }; if let Some(mut conn6) = conn6 { conn6.set_read_timeout(Some(TCP_TIMEOUT))?; conn6.set_write_timeout(Some(TCP_TIMEOUT))?; conn6.write_all(b"GET / HTTP/1.1\n\n").ok(); loop { match conn6.read(&mut buf) { Ok(0) => { eprintln!("[warn] IPv6: connection closed"); *outbound_healthy_v6 = false; } Ok(_) => { *outbound_healthy_v6 = true; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => continue, Err(e) => { eprintln!("[warn] IPv6: read: {}", e); *outbound_healthy_v6 = false; } } break; } conn6.shutdown(Shutdown::Both).ok(); } Ok(()) } fn rollback() -> io::Result<()> { switch_to_inactive_root()?; reboot() } fn reboot() -> io::Result<()> { eprintln!("[info] connection unhealthy, rebooting"); match nix::sys::signal::kill(Pid::from_raw(1), Signal::SIGUSR1) { Ok(_) => Ok(()), Err(e) => Err(io::Error::from(e)), } } fn modify_cmdline(new: &str) -> io::Result<()> { let boot = boot_dev().ok_or(io::Error::from(io::ErrorKind::NotFound))?; let boot_partition = std::fs::OpenOptions::new() .read(true) .write(true) .open(boot)?; let buf_stream = fscommon::BufStream::new(boot_partition); let bootfs = fatfs::FileSystem::new(buf_stream, fatfs::FsOptions::new())?; let mut file = bootfs.root_dir().open_file("cmdline.txt")?; file.write_all(new.as_bytes())?; Ok(()) } fn dev() -> Option<&'static str> { let devs = ["/dev/mmcblk0", "/dev/sda", "/dev/vda"]; devs.into_iter().find(|&dev| std::fs::metadata(dev).is_ok()) } fn boot_dev() -> Option<&'static str> { Some(match dev()? { "/dev/mmcblk0" => "/dev/mmcblk0p1", "/dev/sda" => "/dev/sda1", "/dev/vda" => "/dev/vda1", _ => unreachable!(), }) } fn inactive_root() -> io::Result { let cmdline = std::fs::read_to_string("/proc/cmdline")?; for seg in cmdline.split(' ') { if seg.starts_with("root=PARTUUID=00000000-") { let root_id = match seg .split("root=PARTUUID=00000000-0") .collect::>() .into_iter() .next_back() .expect("no root device") { "2" => "3", "3" => "2", _ => unreachable!(), }; return Ok( match dev().ok_or(io::Error::from(io::ErrorKind::NotFound))? { "/dev/mmcblk0" => format!("/dev/mmcblk0p{}", root_id), "/dev/sda" => format!("/dev/sda{}", root_id), "/dev/vda" => format!("/dev/vda{}", root_id), _ => unreachable!(), }, ); } } Err(io::Error::from(io::ErrorKind::NotFound)) } fn switch_to_inactive_root() -> io::Result<()> { let new = inactive_root()?; let new = String::from("root=PARTUUID=00000000-0") + &new.chars().last().unwrap().to_string(); let cmdline = format!("{} init=/bin/init rootwait console=tty1", new); modify_cmdline(&cmdline)?; nix::unistd::sync(); Ok(()) }