diff --git a/sys/kern/uipc_socket2.c b/sys/kern/uipc_socket2.c index 6da17be1dc4c..6b6b7dff38c9 100644 --- a/sys/kern/uipc_socket2.c +++ b/sys/kern/uipc_socket2.c @@ -1506,22 +1506,6 @@ solocked2(const struct socket *so1, const struct socket *so2) return mutex_owned(lock); } -/* - * sosetlock: assign a default lock to a new socket. - */ -void -sosetlock(struct socket *so) -{ - if (so->so_lock == NULL) { - kmutex_t *lock = softnet_lock; - - so->so_lock = lock; - mutex_obj_hold(lock); - mutex_enter(lock); - } - KASSERT(solocked(so)); -} - /* * Set lock on sockbuf sb; sleep if lock is already held. * Unless SB_NOINTR is set on sockbuf, sleep is interruptible. diff --git a/sys/net/if_llatbl.c b/sys/net/if_llatbl.c index 54498d21ac32..aa6c3718b933 100644 --- a/sys/net/if_llatbl.c +++ b/sys/net/if_llatbl.c @@ -47,9 +47,13 @@ #include #include #include +#include #include #include #include +#include +#include +#include #ifdef DDB #include @@ -72,6 +76,29 @@ static struct pool llentry_pool; static void lltable_unlink(struct lltable *llt); static void llentries_unlink(struct lltable *llt, struct llentries *head); +static struct workqueue *llatbl_wq; + +/* + * Deferred llentry free. Entries are collected on a per-lltable + * pending list and freed in batches behind a single pserialize_perform. 
+ */ +static void +llentry_destroy_cb(struct work *wk, void *arg) +{ + struct lltable *llt = container_of(wk, struct lltable, llt_destroy_work); + struct llentry *lle; + + pserialize_perform(llt->llt_psz); + + mutex_enter(&llt->llt_destroy_lock); + while ((lle = LIST_FIRST(&llt->llt_destroy_list)) != NULL) { + LIST_REMOVE(lle, lle_chain); + PSLIST_ENTRY_DESTROY(lle, lle_next); + pool_put(&llentry_pool, lle); + } + mutex_exit(&llt->llt_destroy_lock); +} + static void htable_unlink_entry(struct llentry *lle); static void htable_link_entry(struct lltable *llt, struct llentry *lle); static int htable_foreach_lle(struct lltable *llt, llt_foreach_cb_t *f, @@ -189,13 +216,14 @@ done: static int htable_foreach_lle(struct lltable *llt, llt_foreach_cb_t *f, void *farg) { - struct llentry *lle, *next; + struct llentry *lle; int i, error; error = 0; for (i = 0; i < llt->llt_hsize; i++) { - LIST_FOREACH_SAFE(lle, &llt->lle_head[i], lle_next, next) { + PSLIST_WRITER_FOREACH(lle, &llt->lle_head[i], struct llentry, + lle_next) { error = f(llt, lle, farg); if (error != 0) break; @@ -208,7 +236,7 @@ htable_foreach_lle(struct lltable *llt, llt_foreach_cb_t *f, void *farg) static void htable_link_entry(struct lltable *llt, struct llentry *lle) { - struct llentries *lleh; + struct pslist_head *lleh; uint32_t hashidx; if ((lle->la_flags & LLE_LINKED) != 0) @@ -222,7 +250,8 @@ htable_link_entry(struct lltable *llt, struct llentry *lle) lle->lle_tbl = llt; lle->lle_head = lleh; lle->la_flags |= LLE_LINKED; - LIST_INSERT_HEAD(lleh, lle, lle_next); + PSLIST_ENTRY_INIT(lle, lle_next); + PSLIST_WRITER_INSERT_HEAD(lleh, lle, lle_next); llt->llt_lle_count++; } @@ -233,7 +262,7 @@ htable_unlink_entry(struct llentry *lle) if ((lle->la_flags & LLE_LINKED) != 0) { IF_AFDATA_WLOCK_ASSERT(lle->lle_tbl->llt_ifp); - LIST_REMOVE(lle, lle_next); + PSLIST_WRITER_REMOVE(lle, lle_next); lle->la_flags &= ~(LLE_VALID | LLE_LINKED); #if 0 lle->lle_tbl = NULL; @@ -295,6 +324,10 @@ static void 
htable_free_tbl(struct lltable *llt) { + /* Wait for any pending deferred-free work to complete */ + workqueue_wait(llatbl_wq, &llt->llt_destroy_work); + pserialize_destroy(llt->llt_psz); + mutex_destroy(&llt->llt_destroy_lock); free(llt->lle_head, M_LLTABLE); free(llt, M_LLTABLE); } @@ -352,6 +385,23 @@ void llentry_pool_put(struct llentry *lle) { + /* + * If this entry was linked into a PSLIST hash chain, readers + * may still be traversing it. Queue on the per-lltable + * destroy list for batched free after pserialize_perform. + */ + if (lle->lle_tbl != NULL) { + struct lltable *llt = lle->lle_tbl; + bool first; + mutex_enter(&llt->llt_destroy_lock); + first = LIST_EMPTY(&llt->llt_destroy_list); + LIST_INSERT_HEAD(&llt->llt_destroy_list, lle, lle_chain); + mutex_exit(&llt->llt_destroy_lock); + if (first) + workqueue_enqueue(llatbl_wq, &llt->llt_destroy_work, + NULL); + return; + } pool_put(&llentry_pool, lle); } @@ -492,13 +542,18 @@ lltable_drain(int af) if (llt->llt_af != af) continue; - for (i=0; i < llt->llt_hsize; i++) { - LIST_FOREACH(lle, &llt->lle_head[i], lle_next) { + IF_AFDATA_WLOCK(llt->llt_ifp); + for (i = 0; i < llt->llt_hsize; i++) { + PSLIST_WRITER_FOREACH(lle, &llt->lle_head[i], + struct llentry, lle_next) { + if (lle->la_flags & LLE_DELETED) + continue; LLE_WLOCK(lle); lltable_drop_entry_queue(lle); LLE_WUNLOCK(lle); } } + IF_AFDATA_WUNLOCK(llt->llt_ifp); } LLTABLE_RUNLOCK(); } @@ -527,11 +582,15 @@ lltable_allocate_htbl(uint32_t hsize) llt = malloc(sizeof(struct lltable), M_LLTABLE, M_WAITOK | M_ZERO); llt->llt_hsize = hsize; - llt->lle_head = malloc(sizeof(struct llentries) * hsize, + llt->lle_head = malloc(sizeof(struct pslist_head) * hsize, M_LLTABLE, M_WAITOK | M_ZERO); for (i = 0; i < llt->llt_hsize; i++) - LIST_INIT(&llt->lle_head[i]); + PSLIST_INIT(&llt->lle_head[i]); + + llt->llt_psz = pserialize_create(); + mutex_init(&llt->llt_destroy_lock, MUTEX_DEFAULT, IPL_SOFTNET); + LIST_INIT(&llt->llt_destroy_list); /* Set some default 
callbacks */ llt->llt_link_entry = htable_link_entry; @@ -787,12 +846,17 @@ out: void lltableinit(void) { + int error; SLIST_INIT(&lltables); rw_init(&lltable_rwlock); pool_init(&llentry_pool, sizeof(struct llentry), 0, 0, 0, "llentrypl", NULL, IPL_SOFTNET); + + error = workqueue_create(&llatbl_wq, "llatbl_free", + llentry_destroy_cb, NULL, PRI_SOFTNET, IPL_SOFTNET, WQ_MPSAFE); + KASSERT(error == 0); } #ifdef __FreeBSD__ @@ -810,7 +874,7 @@ llatbl_lle_show(struct llentry_sa *la) lle = &la->base; db_printf("lle=%p\n", lle); - db_printf(" lle_next=%p\n", lle->lle_next.le_next); + db_printf(" lle_next=%p\n", lle->lle_next.ple_next); db_printf(" lle_lock=%p\n", &lle->lle_lock); db_printf(" lle_tbl=%p\n", lle->lle_tbl); db_printf(" lle_head=%p\n", lle->lle_head); @@ -882,7 +946,8 @@ llatbl_llt_show(struct lltable *llt) llt, llt->llt_af, llt->llt_ifp); for (i = 0; i < llt->llt_hsize; i++) { - LIST_FOREACH(lle, &llt->lle_head[i], lle_next) { + PSLIST_WRITER_FOREACH(lle, &llt->lle_head[i], + struct llentry, lle_next) { llatbl_lle_show((struct llentry_sa *)lle); if (db_pager_quit) diff --git a/sys/net/if_llatbl.h b/sys/net/if_llatbl.h index 640ae6a4d94b..4f758863ae5a 100644 --- a/sys/net/if_llatbl.h +++ b/sys/net/if_llatbl.h @@ -39,6 +39,9 @@ #include #include +#include +#include +#include #include @@ -49,7 +52,7 @@ struct rt_addrinfo; struct rt_walkarg; struct llentry; -LIST_HEAD(llentries, llentry); +LIST_HEAD(llentries, llentry); /* legacy - retained for lle_chain */ extern krwlock_t lltable_rwlock; #define LLTABLE_RLOCK() rw_enter(&lltable_rwlock, RW_READER) @@ -63,7 +66,7 @@ extern krwlock_t lltable_rwlock; * a shared lock */ struct llentry { - LIST_ENTRY(llentry) lle_next; + struct pslist_entry lle_next; /* hash chain (pserialize-safe) */ union l3addr { struct in_addr addr4; struct in6_addr addr6; @@ -77,7 +80,7 @@ struct llentry { uint64_t spare1; struct lltable *lle_tbl; - struct llentries *lle_head; + struct pslist_head *lle_head; /* back-pointer to hash bucket */ 
void (*lle_free)(struct llentry *); void (*lle_ll_free)(struct llentry *); struct mbuf *la_hold; @@ -216,8 +219,12 @@ struct lltable { SLIST_ENTRY(lltable) llt_link; int llt_af; int llt_hsize; - struct llentries *lle_head; + struct pslist_head *lle_head; /* hash buckets (pserialize-safe) */ unsigned int llt_lle_count; + pserialize_t llt_psz; /* pserialize for lock-free reads */ + kmutex_t llt_destroy_lock; + LIST_HEAD(, llentry) llt_destroy_list; /* batched deferred frees */ + struct work llt_destroy_work; struct ifnet *llt_ifp; llt_lookup_t *llt_lookup; diff --git a/sys/net/link_proto.c b/sys/net/link_proto.c index 6455acc4e629..a5daa6becb50 100644 --- a/sys/net/link_proto.c +++ b/sys/net/link_proto.c @@ -262,7 +262,10 @@ link_control(struct socket *so, unsigned long cmd, void *data, static int link_attach(struct socket *so, int proto) { - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } KASSERT(solocked(so)); return 0; } diff --git a/sys/net/npf/npf_socket.c b/sys/net/npf/npf_socket.c index 853e3c778adf..201e00d8480c 100644 --- a/sys/net/npf/npf_socket.c +++ b/sys/net/npf/npf_socket.c @@ -167,6 +167,7 @@ npf_ip_socket(npf_cache_t *npc, int dir) } so = inp->inp_socket; + inp_lookup_unlock(inp); return so; } @@ -229,6 +230,7 @@ npf_ip6_socket(npf_cache_t *npc, int dir) } so = in6p->inp_socket; + inp_lookup_unlock(in6p); return so; } #endif diff --git a/sys/net/raw_cb.c b/sys/net/raw_cb.c index f9543d4e1229..f55d8c95259a 100644 --- a/sys/net/raw_cb.c +++ b/sys/net/raw_cb.c @@ -77,7 +77,10 @@ raw_attach(struct socket *so, int proto, struct rawcbhead *rawcbhead) */ rp = sotorawcb(so); KASSERT(rp != NULL); - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } if ((error = soreserve(so, raw_sendspace, raw_recvspace)) != 0) { return error; diff --git a/sys/net/raw_usrreq.c b/sys/net/raw_usrreq.c index 
716e2d5ce437..ef62d64828ae 100644 --- a/sys/net/raw_usrreq.c +++ b/sys/net/raw_usrreq.c @@ -95,7 +95,10 @@ raw_input(struct mbuf *m0, struct sockproto *proto, struct sockaddr *src, continue; if (last != NULL) { struct mbuf *n; + bool locked = solocked(last); + if (!locked) + solock(last); if ((n = m_copypacket(m, M_DONTWAIT)) == NULL || sbappendaddr(&last->so_rcv, src, n, NULL) == 0) { @@ -103,15 +106,23 @@ raw_input(struct mbuf *m0, struct sockproto *proto, struct sockaddr *src, soroverflow(last); } else sorwakeup(last); + if (!locked) + sounlock(last); } last = rp->rcb_socket; } if (last != NULL) { + bool locked = solocked(last); + + if (!locked) + solock(last); if (sbappendaddr(&last->so_rcv, src, m, NULL) == 0) { m_freem(m); soroverflow(last); } else sorwakeup(last); + if (!locked) + sounlock(last); } else { m_freem(m); } diff --git a/sys/netatalk/ddp_input.c b/sys/netatalk/ddp_input.c index 0a1297f9dd5f..ccb3a86a6b42 100644 --- a/sys/netatalk/ddp_input.c +++ b/sys/netatalk/ddp_input.c @@ -308,10 +308,12 @@ ddp_input(struct mbuf *m, struct ifnet *ifp, struct elaphdr *elh, int phase) m_freem(m); return; } + solock(ddp->ddp_socket); if (sbappendaddr(&ddp->ddp_socket->so_rcv, (struct sockaddr *) & from, m, (struct mbuf *) 0) == 0) { DDP_STATINC(DDP_STAT_NOSOCKSPACE); soroverflow(ddp->ddp_socket); + sounlock(ddp->ddp_socket); m_freem(m); return; } @@ -320,6 +322,7 @@ ddp_input(struct mbuf *m, struct ifnet *ifp, struct elaphdr *elh, int phase) aa->aa_ifa.ifa_data.ifad_inbytes += dlen; #endif sorwakeup(ddp->ddp_socket); + sounlock(ddp->ddp_socket); } #if 0 diff --git a/sys/netatalk/ddp_usrreq.c b/sys/netatalk/ddp_usrreq.c index 27dc052afa90..3662f9b271ee 100644 --- a/sys/netatalk/ddp_usrreq.c +++ b/sys/netatalk/ddp_usrreq.c @@ -287,7 +287,10 @@ ddp_attach(struct socket *so, int proto) int error; KASSERT(sotoddpcb(so) == NULL); - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } #ifdef 
MBUFTRACE so->so_rcv.sb_mowner = &atalk_rx_mowner; so->so_snd.sb_mowner = &atalk_tx_mowner; @@ -617,6 +620,9 @@ ddp_init(void) MOWNER_ATTACH(&aarp_mowner); } +#ifdef NET_MPSAFE +/* No PR_WRAP_USRREQS: per-socket lock, no KERNEL_LOCK needed */ +#else PR_WRAP_USRREQS(ddp) #define ddp_attach ddp_attach_wrapper #define ddp_detach ddp_detach_wrapper @@ -637,6 +643,7 @@ PR_WRAP_USRREQS(ddp) #define ddp_send ddp_send_wrapper #define ddp_sendoob ddp_sendoob_wrapper #define ddp_purgeif ddp_purgeif_wrapper +#endif const struct pr_usrreqs ddp_usrreqs = { .pr_attach = ddp_attach, diff --git a/sys/netcan/can.c b/sys/netcan/can.c index 1eae6d5e69f8..31b09da76300 100644 --- a/sys/netcan/can.c +++ b/sys/netcan/can.c @@ -386,6 +386,7 @@ canintr(void *arg __unused) mc = m; m = NULL; } + solock(canp->canp_socket); if (sbappendaddr(&canp->canp_socket->so_rcv, (struct sockaddr *) &from, mc, (struct mbuf *) 0) == 0) { @@ -393,6 +394,7 @@ canintr(void *arg __unused) m_freem(mc); } else sorwakeup(canp->canp_socket); + sounlock(canp->canp_socket); mutex_exit(&canp->canp_mtx); if (m == NULL) break; @@ -431,8 +433,11 @@ can_attach(struct socket *so, int proto) KASSERT(sotocanpcb(so) == NULL); - /* Assign the lock (must happen even if we will error out). */ - sosetlock(so); + /* Assign a per-socket lock. 
*/ + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } #ifdef MBUFTRACE so->so_mowner = &can_mowner; @@ -814,6 +819,7 @@ release: #endif #if 0 +/* XXXKB this needs to be refactored to use MPSAFE approach if desired */ static void can_notify(struct canpcb *canp, int errno) { @@ -959,6 +965,9 @@ can_ctloutput(int op, struct socket *so, struct sockopt *sopt) return error; } +#ifdef NET_MPSAFE +/* No PR_WRAP_USRREQS: per-socket lock, no KERNEL_LOCK needed */ +#else PR_WRAP_USRREQS(can) #define can_attach can_attach_wrapper #define can_detach can_detach_wrapper @@ -979,6 +988,7 @@ PR_WRAP_USRREQS(can) #define can_send can_send_wrapper #define can_sendoob can_sendoob_wrapper #define can_purgeif can_purgeif_wrapper +#endif const struct pr_usrreqs can_usrreqs = { .pr_attach = can_attach, diff --git a/sys/netinet/dccp_usrreq.c b/sys/netinet/dccp_usrreq.c index eee6e35f37f2..d555d612ce06 100644 --- a/sys/netinet/dccp_usrreq.c +++ b/sys/netinet/dccp_usrreq.c @@ -121,10 +121,10 @@ __KERNEL_RCSID(0, "$NetBSD: dccp_usrreq.c,v 1.27 2024/07/05 04:31:54 rin Exp $") #define INP_INFO_WUNLOCK(x) #define INP_INFO_RLOCK(x) #define INP_INFO_RUNLOCK(x) -#define INP_LOCK(x) -#define IN6P_LOCK(x) -#define INP_UNLOCK(x) -#define IN6P_UNLOCK(x) +#define INP_LOCK(x) solock((x)->inp_socket) +#define IN6P_LOCK(x) solock((x)->inp_socket) +#define INP_UNLOCK(x) sounlock((x)->inp_socket) +#define IN6P_UNLOCK(x) sounlock((x)->inp_socket) /* Congestion control switch table */ extern struct dccp_cc_sw cc_sw[]; @@ -406,7 +406,7 @@ dccp_input(struct mbuf *m, int off, int proto) INP_INFO_WUNLOCK(&dccpbinfo); goto badunlocked; } - INP_LOCK(inp); + /* solock held from inpcb_lookup */ dp = intodccpcb(inp); if (dp == NULL) { @@ -447,8 +447,10 @@ dccp_input(struct mbuf *m, int off, int proto) in6p_faddr(inp) = ip6->ip6_src; inp->inp_lport = dh->dh_dport; inp->inp_fport = dh->dh_sport; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, 
INP_CONNECTED); - } else + INP_HASH_UNLOCK(inp->inp_table); + } else #endif { inp = sotoinpcb(so); @@ -458,8 +460,11 @@ dccp_input(struct mbuf *m, int off, int proto) inp->inp_fport = dh->dh_sport; } - if (!isipv6) + if (!isipv6) { + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); + } dp = inp->inp_ppcb; @@ -637,7 +642,7 @@ dccp_input(struct mbuf *m, int off, int proto) DCCP_DEBUG((LOG_INFO, "Got DCCP RESET\n")); dp->state = DCCPS_TIME_WAIT; dp = dccp_close(dp); - return; + goto badunlocked; default: DCCP_DEBUG((LOG_INFO, "Got a %u packet while in listen stage!\n", dh->dh_type)); @@ -672,7 +677,7 @@ dccp_input(struct mbuf *m, int off, int proto) DCCP_DEBUG((LOG_INFO, "Got DCCP RESET\n")); dp->state = DCCPS_TIME_WAIT; dp = dccp_close(dp); - return; + goto badunlocked; default: DCCP_DEBUG((LOG_INFO, "Got a %u packet while in REQUEST stage!\n", dh->dh_type)); @@ -680,7 +685,7 @@ dccp_input(struct mbuf *m, int off, int proto) dccp_output(dp, DCCP_TYPE_RESET + 2); if (dh->dh_type == DCCP_TYPE_CLOSE) { dp = dccp_close(dp); - return; + goto badunlocked; } else { callout_stop(&dp->retrans_timer); dp->state = DCCPS_TIME_WAIT; @@ -895,10 +900,8 @@ dccp_input(struct mbuf *m, int off, int proto) m_freem(m); m_freem(opts); } -#if defined(__FreeBSD__) && __FreeBSD_version >= 500000 if (dp) INP_UNLOCK(inp); -#endif return; @@ -915,10 +918,13 @@ badunlocked: void dccp_notify(struct inpcb *inp, int errno) { - inp->inp_socket->so_error = errno; - sorwakeup(inp->inp_socket); - sowwakeup(inp->inp_socket); - return; + struct socket *so = inp->inp_socket; + + solock(so); + so->so_error = errno; + sorwakeup(so); + sowwakeup(so); + sounlock(so); } /* @@ -946,20 +952,20 @@ dccp_ctlinput(int cmd, const struct sockaddr *sa, void *vip) else if ((unsigned)cmd >= PRC_NCMDS || inetctlerrmap[cmd] == 0) return NULL; if (ip) { - /*s = splsoftnet();*/ dh = (struct dccphdr *)((vaddr_t)ip + (ip->ip_hl << 2)); - INP_INFO_RLOCK(&dccpbinfo); - 
inpcb_notify(&dccpbtable, faddr, dh->dh_dport, - ip->ip_src, dh->dh_sport, inetctlerrmap[cmd], notify); + inp = inpcb_lookup(&dccpbtable, faddr, dh->dh_dport, + ip->ip_src, dh->dh_sport, NULL); if (inp != NULL) { - INP_LOCK(inp); - if (inp->inp_socket != NULL) { - (*notify)(inp, inetctlerrmap[cmd]); - } - INP_UNLOCK(inp); + kmutex_t *lock = inp->inp_socket->so_lock; + mutex_obj_hold(lock); + bool acquired __diagused = inpcb_ref_acquire(inp); + KASSERT(acquired); + mutex_exit(lock); + (*notify)(inp, inetctlerrmap[cmd]); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + mutex_obj_free(lock); } - INP_INFO_RUNLOCK(&dccpbinfo); - /*splx(s);*/ } else inpcb_notifyall(&dccpbtable, faddr, inetctlerrmap[cmd], notify); @@ -1575,11 +1581,11 @@ dccp_close(struct dccpcb *dp) DCCP_DEBUG((LOG_INFO, "Entering dccp_close!\n")); - /* Stop all timers */ - callout_stop(&dp->connect_timer); - callout_stop(&dp->retrans_timer); - callout_stop(&dp->close_timer); - callout_stop(&dp->timewait_timer); + /* Halt all timers, must wait for running callbacks on other CPUs */ + callout_halt(&dp->connect_timer, so->so_lock); + callout_halt(&dp->retrans_timer, so->so_lock); + callout_halt(&dp->close_timer, so->so_lock); + callout_halt(&dp->timewait_timer, so->so_lock); if (dp->cc_in_use[0] > 0) (*cc_sw[dp->cc_in_use[0]].cc_send_free)(dp->cc_state[0]); @@ -1607,7 +1613,14 @@ dccp_attach(struct socket *so, int proto) DCCP_DEBUG((LOG_INFO, "Entering dccp_attach(proto=%d)!\n", proto)); INP_INFO_WLOCK(&dccpbinfo); s = splsoftnet(); - sosetlock(so); + /* + * Assign a per-socket lock. Each DCCP socket gets its own mutex + * so that connections on different CPUs can be processed in parallel. + */ + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } inp = sotoinpcb(so); if (inp != 0) { @@ -1808,7 +1821,7 @@ dccp_doconnect(struct socket *so, struct sockaddr *nam, * Detaches the DCCP protocol from the socket. 
* */ -int +void dccp_detach(struct socket *so) { struct inpcb *inp; @@ -1817,14 +1830,13 @@ dccp_detach(struct socket *so) DCCP_DEBUG((LOG_INFO, "Entering dccp_detach!\n")); inp = sotoinpcb(so); if (inp == NULL) { - return EINVAL; + return; } dp = inp->inp_ppcb; if (! dccp_disconnect2(dp)) { INP_UNLOCK(inp); } INP_INFO_WUNLOCK(&dccpbinfo); - return 0; } /* @@ -2630,22 +2642,18 @@ void dccp_retrans_t(void *dcb) { struct dccpcb *dp = dcb; - /*struct inpcb *inp;*/ + struct inpcb *inp; DCCP_DEBUG((LOG_INFO, "Entering dccp_retrans_t!\n")); - mutex_enter(softnet_lock); INP_INFO_RLOCK(&dccpbinfo); - /*inp = dp->d_inpcb;*/ + inp = dp->d_inpcb; INP_LOCK(inp); INP_INFO_RUNLOCK(&dccpbinfo); callout_stop(&dp->retrans_timer); - KERNEL_LOCK(1, NULL); dccp_output(dp, 0); - KERNEL_UNLOCK_ONE(NULL); dp->retrans = dp->retrans * 2; callout_reset(&dp->retrans_timer, dp->retrans, dccp_retrans_t, dp); INP_UNLOCK(inp); - mutex_exit(softnet_lock); } static int @@ -3095,6 +3103,15 @@ SYSCTL_SETUP(sysctl_net_inet_dccp_setup, "sysctl net.inet.dccp subtree setup") CTL_EOL); } +#ifdef NET_MPSAFE +/* + * DCCP is fully MP-safe with per-socket locking. All usrreqs run + * under solock (held by the socket layer), so KERNEL_LOCK wrappers + * must NOT be used. Acquiring KERNEL_LOCK inside solock inverts + * the lock order vs ctlinput (which takes KERNEL_LOCK then solock). 
+ */ +/* No PR_WRAP_USRREQS, all functions used directly */ +#else PR_WRAP_USRREQS(dccp) #define dccp_attach dccp_attach_wrapper #define dccp_detach dccp_detach_wrapper @@ -3115,6 +3132,7 @@ PR_WRAP_USRREQS(dccp) #define dccp_send dccp_send_wrapper #define dccp_sendoob dccp_sendoob_wrapper #define dccp_purgeif dccp_purgeif_wrapper +#endif const struct pr_usrreqs dccp_usrreqs = { .pr_attach = dccp_attach, diff --git a/sys/netinet/dccp_var.h b/sys/netinet/dccp_var.h index 61017f8f9394..813c4ffe6ea0 100644 --- a/sys/netinet/dccp_var.h +++ b/sys/netinet/dccp_var.h @@ -295,7 +295,7 @@ int dccp_doconnect(struct socket *, struct sockaddr *, struct lwp *, int); int dccp_add_option(struct dccpcb *, u_int8_t, char *, u_int8_t); int dccp_add_feature(struct dccpcb *, u_int8_t, u_int8_t, char *, u_int8_t); -int dccp_detach(struct socket *); +void dccp_detach(struct socket *); int dccp_attach(struct socket *, int); int dccp_abort(struct socket *); int dccp_disconnect(struct socket *); diff --git a/sys/netinet/if_arp.c b/sys/netinet/if_arp.c index 4dca12f34a8c..2d3d91fcb27e 100644 --- a/sys/netinet/if_arp.c +++ b/sys/netinet/if_arp.c @@ -155,8 +155,9 @@ static struct nd_domain arp_nd_domain = { .nd_umaxtries = 3, /* maximum unicast query */ .nd_retransmultiple = BACKOFF_MULTIPLE, .nd_maxretrans = MAX_RETRANS_TIMER, - .nd_maxnudhint = 0, /* max # of subsequent upper layer hints */ - .nd_maxqueuelen = 1, /* max # of packets in unresolved ND entries */ + .nd_gctimer = 20*60, /* stale ARP entry GC: 20 minutes */ + .nd_maxnudhint = 5, /* max # of subsequent upper layer hints */ + .nd_maxqueuelen = 16, /* max # of packets in unresolved ARP entries */ .nd_nud_enabled = arp_nud_enabled, .nd_reachable = arp_llinfo_reachable, .nd_retrans = arp_llinfo_retrans, @@ -598,6 +599,29 @@ arpannounce1(struct ifaddr *ifa) * packet has been held pending resolution. Any other value indicates an * error. */ + +/* + * Lock-free ARP lookup for the fast path. 
Must be called inside a + * pserialize read section. Returns the llentry if found, NULL otherwise. + */ +static struct llentry * +arplookup_psz(struct ifnet *ifp, const struct sockaddr *sa) +{ + struct lltable *llt = LLTABLE(ifp); + struct in_addr dst = satocsin(sa)->sin_addr; + uint32_t hashidx = in_lltable_hash_dst(dst, llt->llt_hsize); + struct pslist_head *head = &llt->lle_head[hashidx]; + struct llentry *lle; + + PSLIST_READER_FOREACH(lle, head, struct llentry, lle_next) { + if (lle->la_flags & LLE_DELETED) + continue; + if (lle->r_l3addr.addr4.s_addr == dst.s_addr) + return lle; + } + return NULL; +} + int arpresolve(struct ifnet *ifp, const struct rtentry *rt, struct mbuf *m, const struct sockaddr *dst, void *desten, size_t destlen) @@ -605,6 +629,7 @@ arpresolve(struct ifnet *ifp, const struct rtentry *rt, struct mbuf *m, struct llentry *la; const char *create_lookup; int error; + int s; #if NCARP > 0 if (rt != NULL && rt->rt_ifp->if_type == IFT_CARP) @@ -613,25 +638,38 @@ arpresolve(struct ifnet *ifp, const struct rtentry *rt, struct mbuf *m, KASSERT(m != NULL); - la = arplookup(ifp, NULL, dst, 0); - if (la == NULL) - goto notfound; - - if (la->la_flags & LLE_VALID && la->ln_state == ND_LLINFO_REACHABLE) { + /* + * Fast path: pserialize read section, no locks. + * If the entry exists and is REACHABLE, just copy the MAC. + * pserialize + deferred free guarantees the entry won't be + * freed under us. + */ + s = pserialize_read_enter(); + la = arplookup_psz(ifp, dst); + if (la != NULL && + (la->la_flags & LLE_VALID) != 0 && + la->ln_state == ND_LLINFO_REACHABLE) { KASSERT(destlen >= ifp->if_addrlen); + membar_acquire(); memcpy(desten, &la->ll_addr, ifp->if_addrlen); - LLE_RUNLOCK(la); + pserialize_read_exit(s); return 0; } + pserialize_read_exit(s); -notfound: + /* + * Slow path: take locks, create or update entry. 
+ */ if (ifp->if_flags & IFF_NOARP) { - if (la != NULL) - LLE_RUNLOCK(la); error = ENOTSUP; goto bad; } + /* Look up with write lock for the slow path */ + IF_AFDATA_RLOCK(ifp); + la = lla_lookup(LLTABLE(ifp), LLE_EXCLUSIVE, dst); + IF_AFDATA_RUNLOCK(ifp); + if (la == NULL) { struct rtentry *_rt; @@ -646,12 +684,6 @@ notfound: ARP_STATINC(ARP_STAT_ALLOCFAIL); else la->ln_state = ND_LLINFO_NOSTATE; - } else if (LLE_TRY_UPGRADE(la) == 0) { - create_lookup = "lookup"; - LLE_RUNLOCK(la); - IF_AFDATA_RLOCK(ifp); - la = lla_lookup(LLTABLE(ifp), LLE_EXCLUSIVE, dst); - IF_AFDATA_RUNLOCK(ifp); } error = EINVAL; @@ -1055,6 +1087,7 @@ again: KASSERT(sizeof(la->ll_addr) >= ifp->if_addrlen); memcpy(&la->ll_addr, ar_sha(ah), ifp->if_addrlen); + membar_release(); la->la_flags |= LLE_VALID; la->la_flags &= ~LLE_UNRESOLVED; la->ln_asked = 0; diff --git a/sys/netinet/in.c b/sys/netinet/in.c index 02f0ff4f4582..50b310583230 100644 --- a/sys/netinet/in.c +++ b/sys/netinet/in.c @@ -2142,13 +2142,6 @@ error: return error; } -static inline uint32_t -in_lltable_hash_dst(const struct in_addr dst, uint32_t hsize) -{ - - return (IN_LLTBL_HASH(dst.s_addr, hsize)); -} - static uint32_t in_lltable_hash(const struct llentry *lle, uint32_t hsize) { @@ -2172,12 +2165,12 @@ static inline struct llentry * in_lltable_find_dst(struct lltable *llt, struct in_addr dst) { struct llentry *lle; - struct llentries *lleh; + struct pslist_head *lleh; u_int hashidx; hashidx = in_lltable_hash_dst(dst, llt->llt_hsize); lleh = &llt->lle_head[hashidx]; - LIST_FOREACH(lle, lleh, lle_next) { + PSLIST_WRITER_FOREACH(lle, lleh, struct llentry, lle_next) { if (lle->la_flags & LLE_DELETED) continue; if (lle->r_l3addr.addr4.s_addr == dst.s_addr) diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c index 6cf3426fee53..f21bb029613c 100644 --- a/sys/netinet/in_pcb.c +++ b/sys/netinet/in_pcb.c @@ -102,6 +102,8 @@ __KERNEL_RCSID(0, "$NetBSD: in_pcb.c,v 1.202 2022/11/04 09:05:41 ozaki-r Exp $") #include #include +#include 
+#include #include #include #include @@ -160,6 +162,9 @@ static pool_cache_t in4pcb_pool_cache; #ifdef INET6 static pool_cache_t in6pcb_pool_cache; #endif +static struct workqueue *inpcb_destroy_wq; + +static void inpcb_destroy_cb(struct work *, void *); static int inpcb_poolinit(void) @@ -171,6 +176,9 @@ inpcb_poolinit(void) in6pcb_pool_cache = pool_cache_init(sizeof(struct in6pcb), coherency_unit, 0, 0, "in6pcbpl", NULL, IPL_NET, NULL, NULL, NULL); #endif + if (workqueue_create(&inpcb_destroy_wq, "inpcbfree", + inpcb_destroy_cb, NULL, PRI_SOFTNET, IPL_SOFTNET, WQ_MPSAFE)) + panic("inpcb_poolinit: workqueue_create failed"); return 0; } @@ -180,12 +188,15 @@ inpcb_init(struct inpcbtable *table, int bindhashsize, int connecthashsize) static ONCE_DECL(control); TAILQ_INIT(&table->inpt_queue); - table->inpt_porthashtbl = hashinit(bindhashsize, HASH_LIST, true, + PSLIST_INIT(&table->inpt_queue_pslist); + table->inpt_porthashtbl = hashinit(bindhashsize, HASH_PSLIST, true, &table->inpt_porthash); - table->inpt_bindhashtbl = hashinit(bindhashsize, HASH_LIST, true, + table->inpt_bindhashtbl = hashinit(bindhashsize, HASH_PSLIST, true, &table->inpt_bindhash); - table->inpt_connecthashtbl = hashinit(connecthashsize, HASH_LIST, true, + table->inpt_connecthashtbl = hashinit(connecthashsize, HASH_PSLIST, true, &table->inpt_connecthash); + mutex_init(&table->inpt_hash_lock, MUTEX_DEFAULT, IPL_SOFTNET); + table->inpt_psz = pserialize_create(); table->inpt_lastlow = IPPORT_RESERVEDMAX; table->inpt_lastport = (in_port_t)anonportmax; @@ -227,6 +238,12 @@ inpcb_create(struct socket *so, void *v) inp->inp_socket = so; inp->inp_portalgo = PORTALGO_DEFAULT; inp->inp_bindportonsend = false; + INP_LOCK_INIT(inp); + inp->inp_refcount = 1; + PSLIST_ENTRY_INIT(inp, inp_bind_hash); + PSLIST_ENTRY_INIT(inp, inp_connect_hash); + PSLIST_ENTRY_INIT(inp, inp_port_hash); + PSLIST_ENTRY_INIT(inp, inp_queue_hash); if (inp->inp_af == AF_INET) { in4p_errormtu(inp) = -1; @@ -259,10 +276,14 @@ 
inpcb_create(struct socket *so, void *v) #endif so->so_pcb = inp; s = splsoftnet(); + INP_HASH_LOCK(table); TAILQ_INSERT_HEAD(&table->inpt_queue, inp, inp_queue); - LIST_INSERT_HEAD(INPCBHASH_PORT(table, inp->inp_lport), inp, - inp_lhash); + PSLIST_WRITER_INSERT_HEAD(&table->inpt_queue_pslist, + inp, inp_queue_hash); + PSLIST_WRITER_INSERT_HEAD(INPCBHASH_PORT(table, inp->inp_lport), + inp, inp_port_hash); inpcb_set_state(inp, INP_ATTACHED); + INP_HASH_UNLOCK(table); splx(s); return 0; } @@ -308,7 +329,9 @@ inpcb_set_port(struct sockaddr_in *sin, struct inpcb *inp, kauth_cred_t cred) *lastport = lport; lport = htons(lport); inp->inp_lport = lport; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); return 0; } @@ -414,8 +437,12 @@ inpcb_bind_port(struct inpcb *inp, struct sockaddr_in *sin, kauth_cred_t cred) #ifdef INET6 in6_in_2_v4mapin6(&sin->sin_addr, &mapped); t6 = in6pcb_lookup_local(table, &mapped, sin->sin_port, wild, &vestige); - if (t6 && (reuseport & t6->inp_socket->so_options) == 0) + if (t6 && (reuseport & t6->inp_socket->so_options) == 0) { + inp_lookup_unlock(t6); return EADDRINUSE; + } + if (t6) + inp_lookup_unlock(t6); if (!t6 && vestige.valid) { if (!!reuseport != !!vestige.reuse_port) { return EADDRINUSE; @@ -436,8 +463,11 @@ inpcb_bind_port(struct inpcb *inp, struct sockaddr_in *sin, kauth_cred_t cred) !in_nullhost(in4p_laddr(t)) || (t->inp_socket->so_options & SO_REUSEPORT) == 0) && (so->so_uidinfo->ui_uid != t->inp_socket->so_uidinfo->ui_uid)) { + inp_lookup_unlock(t); return EADDRINUSE; } + if (t) + inp_lookup_unlock(t); if (!t && vestige.valid) { if ((!in_nullhost(sin->sin_addr) || !in_nullhost(vestige.laddr.v4) @@ -448,20 +478,29 @@ inpcb_bind_port(struct inpcb *inp, struct sockaddr_in *sin, kauth_cred_t cred) } } t = inpcb_lookup_local(table, sin->sin_addr, sin->sin_port, wild, &vestige); - if (t && (reuseport & t->inp_socket->so_options) == 0) + if (t && (reuseport & 
t->inp_socket->so_options) == 0) { + inp_lookup_unlock(t); return EADDRINUSE; + } + if (t) + inp_lookup_unlock(t); if (!t && vestige.valid && !(reuseport && vestige.reuse_port)) return EADDRINUSE; inp->inp_lport = sin->sin_port; + INP_HASH_LOCK(table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(table); } - LIST_REMOVE(inp, inp_lhash); - LIST_INSERT_HEAD(INPCBHASH_PORT(table, inp->inp_lport), inp, - inp_lhash); + INP_HASH_LOCK(table); + PSLIST_WRITER_REMOVE(inp, inp_port_hash); + PSLIST_ENTRY_INIT(inp, inp_port_hash); + PSLIST_WRITER_INSERT_HEAD(INPCBHASH_PORT(table, inp->inp_lport), + inp, inp_port_hash); + INP_HASH_UNLOCK(table); return 0; } @@ -523,6 +562,7 @@ int inpcb_connect(void *v, struct sockaddr_in *sin, struct lwp *l) { struct inpcb *inp = v; + struct inpcb *t; vestigial_inpcb_t vestige; int error; struct in_addr laddr; @@ -610,11 +650,15 @@ inpcb_connect(void *v, struct sockaddr_in *sin, struct lwp *l) curlwp_bindx(bound); } else laddr = in4p_laddr(inp); - if (inpcb_lookup(inp->inp_table, sin->sin_addr, sin->sin_port, - laddr, inp->inp_lport, &vestige) != NULL || - vestige.valid) { + + t = inpcb_lookup(inp->inp_table, sin->sin_addr, sin->sin_port, + laddr, inp->inp_lport, &vestige); + if (t != NULL) { + inp_lookup_unlock(t); return EADDRINUSE; } + if (vestige.valid) + return EADDRINUSE; if (in_nullhost(in4p_laddr(inp))) { if (inp->inp_lport == 0) { error = inpcb_bind(inp, NULL, l); @@ -643,7 +687,9 @@ inpcb_connect(void *v, struct sockaddr_in *sin, struct lwp *l) return error; } + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_CONNECTED); + INP_HASH_UNLOCK(inp->inp_table); #if defined(IPSEC) if (ipsec_enabled && inp->inp_socket->so_type == SOCK_STREAM) ipsec_pcbconn(inp->inp_sp); @@ -666,7 +712,9 @@ inpcb_disconnect(void *v) in4p_faddr(inp) = zeroin_addr; inp->inp_fport = 0; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); #if defined(IPSEC) if (ipsec_enabled) 
ipsec_pcbdisconn(inp->inp_sp); @@ -677,6 +725,11 @@ inpcb_disconnect(void *v) /* * inpcb_destroy: destroy PCB as well as the associated socket. + * + * Immediately frees all resources (socket, options, routes, etc). + * The inpcb memory itself is deferred via workqueue until after a + * pserialize grace period, so that lock-free hash table readers + * can safely complete traversal of stale entries. */ void inpcb_destroy(void *v) @@ -686,19 +739,41 @@ inpcb_destroy(void *v) int s; KASSERT(inp->inp_af == AF_INET || inp->inp_af == AF_INET6); + KASSERT(so == NULL || solocked(so)); #if defined(IPSEC) if (ipsec_enabled) ipsec_delete_pcbpolicy(inp); #endif - so->so_pcb = NULL; + /* + * Sever socket->PCB link and handle socket lifecycle. + * inp_socket may be NULL if the caller (tcp_close) already + * detached it to handle concurrent soclose safely. + */ + if (so != NULL) { + so->so_pcb = NULL; + } + + /* Remove from all hash tables and mark dead */ s = splsoftnet(); + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_ATTACHED); - LIST_REMOVE(inp, inp_lhash); + PSLIST_WRITER_REMOVE(inp, inp_port_hash); + PSLIST_WRITER_REMOVE(inp, inp_queue_hash); TAILQ_REMOVE(&inp->inp_table->inpt_queue, inp, inp_queue); + atomic_store_release(&inp->inp_state, INP_FREED); + INP_HASH_UNLOCK(inp->inp_table); splx(s); + /* + * Free resources immediately. INP_FREED is set with release + * semantics above. Pserialize readers that see INP_FREED skip + * the entry. Readers that don't see it yet will find tryenter + * fails (caller holds solock) or use the refcount fallback. + * The deferred free via workqueue ensures the inpcb memory + * survives until all pserialize readers are done. + */ if (inp->inp_options) { m_free(inp->inp_options); } @@ -713,7 +788,53 @@ inpcb_destroy(void *v) ip6_freemoptions(in6p_moptions(inp)); } #endif - sofree(so); /* drops the socket's lock */ + + /* + * Handle socket lock lifecycle. 
If the socket is still attached + * (inp_socket != NULL), call sofree to drop the lock reference. + * sofree reacquires the lock when the socket survives. + * If inp_socket is NULL (tcp_close detached it), skip; the + * caller handles the socket lifecycle directly. + */ + if (so != NULL) { + bool survives = (so->so_state & SS_NOFDREF) == 0; + kmutex_t *lock = so->so_lock; + if (survives) + mutex_obj_hold(lock); + sofree(so); /* drops the socket's lock */ + if (survives) { + mutex_enter(lock); + mutex_obj_free(lock); + } + } + + /* + * Defer only the inpcb struct memory free. Pserialize readers + * may still be traversing stale PSLIST entries that point into + * this inpcb. The memory must remain valid until the grace + * period completes. + */ + workqueue_enqueue(inpcb_destroy_wq, &inp->inp_destroy_work, NULL); +} + +/* + * inpcb_destroy_cb: workqueue callback for deferred inpcb memory free. + * Called after the inpcb has been removed from all hash tables and all + * resources freed. Waits for pserialize grace period, then frees memory. + */ +/* + * inpcb_pool_put: final free of inpcb memory. Called when the last + * reference is released and the pserialize grace period has completed. + */ +void +inpcb_pool_put(struct inpcb *inp) +{ + + PSLIST_ENTRY_DESTROY(inp, inp_bind_hash); + PSLIST_ENTRY_DESTROY(inp, inp_connect_hash); + PSLIST_ENTRY_DESTROY(inp, inp_port_hash); + PSLIST_ENTRY_DESTROY(inp, inp_queue_hash); + INP_LOCK_DESTROY(inp); #ifdef INET6 if (inp->inp_af == AF_INET) @@ -724,7 +845,25 @@ inpcb_destroy(void *v) KASSERT(inp->inp_af == AF_INET); pool_cache_put(in4pcb_pool_cache, inp); #endif - mutex_enter(softnet_lock); /* reacquire the softnet_lock */ +} + +static void +inpcb_destroy_cb(struct work *wk, void *arg) +{ + struct inpcb *inp = container_of(wk, struct inpcb, inp_destroy_work); + + /* Wait for all pserialize readers to drain */ + pserialize_perform(inp->inp_table->inpt_psz); + + /* + * Release the "existence" reference. 
If a lookup grabbed a + * reference before pserialize completed, it will call + * inpcb_pool_put when it releases the last ref. + */ + if (!inpcb_ref_release(inp)) + return; + + inpcb_pool_put(inp); } /* @@ -764,13 +903,17 @@ inpcb_fetch_peeraddr(struct inpcb *inp, struct sockaddr_in *sin) * report any errors for each matching socket. * * Must be called at splsoftnet. + * + * XXXKB No live callers remain all converted to direct inpcb_lookup + * Only dead reference in sys/netcan/can.c (#if 0). + * Candidate for removal. */ int inpcb_notify(struct inpcbtable *table, struct in_addr faddr, u_int fport_arg, struct in_addr laddr, u_int lport_arg, int errno, void (*notify)(struct inpcb *, int)) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp; in_port_t fport = fport_arg, lport = lport_arg; int nmatch; @@ -780,7 +923,7 @@ inpcb_notify(struct inpcbtable *table, struct in_addr faddr, u_int fport_arg, nmatch = 0; head = INPCBHASH_CONNECT(table, faddr, fport, laddr, lport); - LIST_FOREACH(inp, head, inp_hash) { + PSLIST_WRITER_FOREACH(inp, head, struct inpcb, inp_connect_hash) { if (inp->inp_af != AF_INET) continue; @@ -795,21 +938,56 @@ inpcb_notify(struct inpcbtable *table, struct in_addr faddr, u_int fport_arg, return nmatch; } +/* + * Notify all connections matching faddr. Uses hand-over-hand + * refcount pattern (like FreeBSD's inp_next): ref both current + * and next inside pserialize, exit pserialize to call the blocking + * notify callback, re-enter pserialize and advance to next. 
+ */ void inpcb_notifyall(struct inpcbtable *table, struct in_addr faddr, int errno, void (*notify)(struct inpcb *, int)) { - struct inpcb *inp; + struct inpcb *inp, *next; + int s; if (in_nullhost(faddr) || notify == NULL) return; - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { - if (inp->inp_af != AF_INET) + s = pserialize_read_enter(); + for (inp = PSLIST_READER_FIRST(&table->inpt_queue_pslist, + struct inpcb, inp_queue_hash); inp != NULL; inp = next) { + if (inp->inp_af != AF_INET || !in_hosteq(in4p_faddr(inp), + faddr)) { + next = PSLIST_READER_NEXT(inp, struct inpcb, + inp_queue_hash); continue; - if (in_hosteq(in4p_faddr(inp), faddr)) - (*notify)(inp, errno); + } + /* Pin next element before exiting pserialize */ + next = PSLIST_READER_NEXT(inp, struct inpcb, inp_queue_hash); + while (next != NULL && !inpcb_ref_acquire(next)) + next = PSLIST_READER_NEXT(next, struct inpcb, + inp_queue_hash); + if (!inpcb_ref_acquire(inp)) { + /* Current dying, skip to next anchor */ + if (next != NULL && inpcb_ref_release(next)) + inpcb_pool_put(next); + break; + } + pserialize_read_exit(s); + + (*notify)(inp, errno); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + + s = pserialize_read_enter(); + /* Release next's ref now that we're back in pserialize */ + if (next != NULL) { + bool last __diagused = inpcb_ref_release(next); + KASSERT(!last); + } } + pserialize_read_exit(s); } void @@ -849,7 +1027,9 @@ inpcb_purgeif0(struct inpcbtable *table, struct ifnet *ifp) { struct inpcb *inp; - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { + INP_HASH_LOCK(table); + PSLIST_WRITER_FOREACH(inp, &table->inpt_queue_pslist, + struct inpcb, inp_queue_hash) { bool need_unlock = false; if (inp->inp_af != AF_INET) @@ -866,6 +1046,7 @@ inpcb_purgeif0(struct inpcbtable *table, struct ifnet *ifp) if (need_unlock) inp_unlock(inp); } + INP_HASH_UNLOCK(table); } void @@ -873,8 +1054,9 @@ inpcb_purgeif(struct inpcbtable *table, struct ifnet *ifp) { struct rtentry *rt; struct inpcb 
*inp; - - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { + INP_HASH_LOCK(table); + PSLIST_WRITER_FOREACH(inp, &table->inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET) continue; if ((rt = rtcache_validate(&inp->inp_route)) != NULL && @@ -884,6 +1066,7 @@ inpcb_purgeif(struct inpcbtable *table, struct ifnet *ifp) } else rtcache_unref(rt, &inp->inp_route); } + INP_HASH_UNLOCK(table); } /* @@ -941,8 +1124,6 @@ inpcb_rtchange(struct inpcb *inp, int errno) return; rtcache_free(&inp->inp_route); - - /* XXX SHOULD NOTIFY HIGHER-LEVEL PROTOCOLS */ } /* @@ -954,18 +1135,22 @@ struct inpcb * inpcb_lookup_local(struct inpcbtable *table, struct in_addr laddr, u_int lport_arg, int lookup_wildcard, vestigial_inpcb_t *vp) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp; struct inpcb *match = NULL; int matchwild = 3; int wildcard; in_port_t lport = lport_arg; + int s; if (vp) vp->valid = 0; head = INPCBHASH_PORT(table, lport); - LIST_FOREACH(inp, head, inp_lhash) { + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_port_hash) { + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (inp->inp_af != AF_INET) continue; if (inp->inp_lport != lport) @@ -1007,11 +1192,26 @@ inpcb_lookup_local(struct inpcbtable *table, struct in_addr laddr, break; } } + if (match != NULL) { + struct socket *so = match->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &match->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + match = NULL; + } + } else { + match = NULL; + } + } + pserialize_read_exit(s); if (match && matchwild == 0) return match; if (vp && table->vestige) { - void *state = (*table->vestige->init_ports4)(laddr, lport_arg, lookup_wildcard); + struct tcp_ports_iterator ports_it; + void *state = (*table->vestige->init_ports4)(laddr, lport_arg, lookup_wildcard, &ports_it); vestigial_inpcb_t better; 
bool has_better = false; @@ -1047,6 +1247,9 @@ inpcb_lookup_local(struct inpcbtable *table, struct in_addr laddr, } if (has_better) { + /* Release solock on live match before returning vestige */ + if (match != NULL) + sounlock(match->inp_socket); *vp = better; return 0; } @@ -1068,24 +1271,81 @@ inpcb_lookup(struct inpcbtable *table, struct in_addr laddr, u_int lport_arg, vestigial_inpcb_t *vp) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp; in_port_t fport = fport_arg, lport = lport_arg; + int s; if (vp) vp->valid = 0; head = INPCBHASH_CONNECT(table, faddr, fport, laddr, lport); - LIST_FOREACH(inp, head, inp_hash) { + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_connect_hash) { if (inp->inp_af != AF_INET) continue; + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (in_hosteq(in4p_faddr(inp), faddr) && inp->inp_fport == fport && inp->inp_lport == lport && - in_hosteq(in4p_laddr(inp), laddr)) - goto out; + in_hosteq(in4p_laddr(inp), laddr)) { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + inp = NULL; + } + break; + } + /* + * Exact 4-tuple match but tryenter failed + * (lock contended). Grab a refcount on the + * inpcb to keep it alive across the blocking + * lock acquire outside pserialize. + */ + if (!inpcb_ref_acquire(inp)) + continue; + { + kmutex_t *lock = so->so_lock; + bool last; + mutex_obj_hold(lock); + pserialize_read_exit(s); + mutex_enter(lock); + /* + * Verify we hold the correct lock, solockreset + * in tcp_accept could have changed so->so_lock. 
+ */ + if (__predict_false(lock != so->so_lock)) { + mutex_exit(lock); + mutex_obj_free(lock); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + return 0; + } + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(lock); + mutex_obj_free(lock); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + return 0; + } + mutex_obj_free(lock); + last = inpcb_ref_release(inp); + KASSERT(!last); + return inp; + } + } } + pserialize_read_exit(s); + + if (inp != NULL) + return inp; + if (vp && table->vestige) { if ((*table->vestige->lookup4)(faddr, fport_arg, laddr, lport_arg, vp)) @@ -1100,14 +1360,6 @@ inpcb_lookup(struct inpcbtable *table, } #endif return 0; - -out: - /* Move this PCB to the head of hash chain. */ - if (inp != LIST_FIRST(head)) { - LIST_REMOVE(inp, inp_hash); - LIST_INSERT_HEAD(head, inp, inp_hash); - } - return inp; } /* @@ -1118,28 +1370,62 @@ struct inpcb * inpcb_lookup_bound(struct inpcbtable *table, struct in_addr laddr, u_int lport_arg) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp; in_port_t lport = lport_arg; + int s; + s = pserialize_read_enter(); head = INPCBHASH_BIND(table, laddr, lport); - LIST_FOREACH(inp, head, inp_hash) { + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_bind_hash) { if (inp->inp_af != AF_INET) continue; + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (inp->inp_lport == lport && - in_hosteq(in4p_laddr(inp), laddr)) - goto out; + in_hosteq(in4p_laddr(inp), laddr)) { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + continue; + } + goto out; + } + continue; + } } head = INPCBHASH_BIND(table, zeroin_addr, lport); - LIST_FOREACH(inp, head, inp_hash) { + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_bind_hash) { if (inp->inp_af != AF_INET) continue; + if 
(atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (inp->inp_lport == lport && - in_hosteq(in4p_laddr(inp), zeroin_addr)) - goto out; + in_hosteq(in4p_laddr(inp), zeroin_addr)) { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + inp = NULL; + goto out; + } + goto out; + } + continue; + } } + inp = NULL; +out: + pserialize_read_exit(s); + if (inp != NULL) + return inp; #ifdef DIAGNOSTIC if (inpcb_notifymiss) { printf("inpcb_lookup_bound: laddr=%08x lport=%d\n", @@ -1147,14 +1433,6 @@ inpcb_lookup_bound(struct inpcbtable *table, } #endif return 0; - -out: - /* Move this PCB to the head of hash chain. */ - if (inp != LIST_FIRST(head)) { - LIST_REMOVE(inp, inp_hash); - LIST_INSERT_HEAD(head, inp, inp_hash); - } - return inp; } void @@ -1171,20 +1449,37 @@ inpcb_set_state(struct inpcb *inp, int state) return; #endif - if (inp->inp_state > INP_ATTACHED) - LIST_REMOVE(inp, inp_hash); + KASSERT(INP_HASH_LOCKED(inp->inp_table)); + + /* Remove from current hash (uses separate entry per hash table) */ + switch (inp->inp_state) { + case INP_BOUND: + PSLIST_WRITER_REMOVE(inp, inp_bind_hash); + /* + * Re-init entry for future reuse. Safe under hash lock; + * pserialize_perform() in deferred free ensures readers + * are done before re-init. 
+ */ + PSLIST_ENTRY_INIT(inp, inp_bind_hash); + break; + case INP_CONNECTED: + PSLIST_WRITER_REMOVE(inp, inp_connect_hash); + PSLIST_ENTRY_INIT(inp, inp_connect_hash); + break; + } + /* Insert into new hash */ switch (state) { case INP_BOUND: - LIST_INSERT_HEAD(INPCBHASH_BIND(inp->inp_table, + PSLIST_WRITER_INSERT_HEAD(INPCBHASH_BIND(inp->inp_table, in4p_laddr(inp), inp->inp_lport), inp, - inp_hash); + inp_bind_hash); break; case INP_CONNECTED: - LIST_INSERT_HEAD(INPCBHASH_CONNECT(inp->inp_table, + PSLIST_WRITER_INSERT_HEAD(INPCBHASH_CONNECT(inp->inp_table, in4p_faddr(inp), inp->inp_fport, in4p_laddr(inp), inp->inp_lport), inp, - inp_hash); + inp_connect_hash); break; } diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h index 8ac387dea410..ad67b25b4eff 100644 --- a/sys/netinet/in_pcb.h +++ b/sys/netinet/in_pcb.h @@ -64,6 +64,12 @@ #define _NETINET_IN_PCB_H_ #include +#include +#include +#include +#include +#include +#include #include @@ -87,15 +93,20 @@ struct icmp6_filter; */ struct inpcb { - LIST_ENTRY(inpcb) inp_hash; - LIST_ENTRY(inpcb) inp_lhash; - TAILQ_ENTRY(inpcb) inp_queue; + struct pslist_entry inp_bind_hash; /* bind hash (pserialize-safe) */ + struct pslist_entry inp_connect_hash; /* connect hash (pserialize-safe) */ + struct pslist_entry inp_port_hash; /* port hash (pserialize-safe) */ + struct pslist_entry inp_queue_hash; /* all-PCB list (pserialize-safe) */ + TAILQ_ENTRY(inpcb) inp_queue; /* all-PCB list (kvm compat) */ + krwlock_t inp_lock; /* per-inpcb rwlock (future use) */ + u_int inp_refcount; /* reference count (future use) */ int inp_af; /* address family - AF_INET or AF_INET6 */ void * inp_ppcb; /* pointer to per-protocol pcb */ int inp_state; /* bind/connect state */ #define INP_ATTACHED 0 #define INP_BOUND 1 #define INP_CONNECTED 2 +#define INP_FREED -1 /* marked for deferred free */ int inp_portalgo; struct socket *inp_socket; /* back pointer to socket */ struct inpcbtable *inp_table; @@ -112,6 +123,8 @@ struct inpcb { 
pcb_overudp_cb_t inp_overudp_cb; void *inp_overudp_arg; + + struct work inp_destroy_work; /* deferred free after pserialize */ }; struct in4pcb { @@ -154,7 +167,10 @@ struct in6pcb { #define in6p_outputopts(inpcb) (((struct in6pcb *)(inpcb))->in6p_outputopts) #define in6p_moptions(inpcb) (((struct in6pcb *)(inpcb))->in6p_moptions) -LIST_HEAD(inpcbhead, inpcb); +/* Legacy type retained for DCCP compat, not used by hash tables */ +struct inpcbhead { + struct pslist_head head; +}; /* flags in inp_flags: */ #define INP_RECVOPTS 0x0001 /* receive incoming IP options */ @@ -219,19 +235,43 @@ LIST_HEAD(inpcbhead, inpcb); #define inp_unlock(inp) sounlock((inp)->inp_socket) #define inp_locked(inp) solocked((inp)->inp_socket) +/* + * Release the solock acquired by inpcb lookup functions. + * All inpcb-using protocols (TCP, UDP, raw IP, DCCP) use per-socket + * locks. SCTP has its own PCB infrastructure and does not use these. + */ +#define inp_lookup_unlock(inp) sounlock((inp)->inp_socket) + +/* Per-inpcb rwlock macros (used for hash-table write-side operations) */ +#define INP_WLOCK(inp) rw_enter(&(inp)->inp_lock, RW_WRITER) +#define INP_RLOCK(inp) rw_enter(&(inp)->inp_lock, RW_READER) +#define INP_WUNLOCK(inp) rw_exit(&(inp)->inp_lock) +#define INP_RUNLOCK(inp) rw_exit(&(inp)->inp_lock) +#define INP_WLOCKED(inp) rw_write_held(&(inp)->inp_lock) +#define INP_LOCK_INIT(inp) rw_init(&(inp)->inp_lock) +#define INP_LOCK_DESTROY(inp) rw_destroy(&(inp)->inp_lock) + +/* Hash table lock macros */ +#define INP_HASH_LOCK(table) mutex_enter(&(table)->inpt_hash_lock) +#define INP_HASH_UNLOCK(table) mutex_exit(&(table)->inpt_hash_lock) +#define INP_HASH_LOCKED(table) mutex_owned(&(table)->inpt_hash_lock) + TAILQ_HEAD(inpcbqueue, inpcb); struct vestigial_hooks; /* It's still referenced by kvm users */ struct inpcbtable { - struct inpcbqueue inpt_queue; - struct inpcbhead *inpt_porthashtbl; - struct inpcbhead *inpt_bindhashtbl; - struct inpcbhead *inpt_connecthashtbl; + struct inpcbqueue 
inpt_queue; /* kvm compat (TAILQ) */ + struct pslist_head inpt_queue_pslist; /* all-PCB list (pserialize-safe) */ + struct pslist_head *inpt_porthashtbl; + struct pslist_head *inpt_bindhashtbl; + struct pslist_head *inpt_connecthashtbl; u_long inpt_porthash; u_long inpt_bindhash; u_long inpt_connecthash; + kmutex_t inpt_hash_lock; /* protects hash insert/remove */ + pserialize_t inpt_psz; /* pserialize for lock-free lookups */ in_port_t inpt_lastport; in_port_t inpt_lastlow; @@ -250,6 +290,7 @@ struct sockaddr_in; struct socket; struct vestigial_inpcb; +void inpcb_pool_put(struct inpcb *); void inpcb_losing(struct inpcb *); int inpcb_create(struct socket *, void *); int inpcb_bindableaddr(const struct inpcb *, struct sockaddr_in *, @@ -316,6 +357,38 @@ extern struct inpcb *in6pcb_lookup(struct inpcbtable *, extern struct inpcb *in6pcb_lookup_bound(struct inpcbtable *, const struct in6_addr *, u_int, int); +/* + * inpcb_ref_acquire: atomically increment the reference count only if + * it is currently > 0. Returns true on success, false if the inpcb is + * already being freed (refcount == 0). This follows the + * refcount_acquire_if_not_zero pattern from FreeBSD. + */ +static inline bool +inpcb_ref_acquire(struct inpcb *inp) +{ + u_int old; + + do { + old = atomic_load_relaxed(&inp->inp_refcount); + if (old == 0) + return false; + } while (atomic_cas_uint(&inp->inp_refcount, old, old + 1) != old); + return true; +} + +/* + * inpcb_ref_release: atomically decrement the reference count. + * Returns true if the count reached zero (caller should free). + */ +static inline bool +inpcb_ref_release(struct inpcb *inp) +{ + + KASSERT(atomic_load_relaxed(&inp->inp_refcount) > 0); + membar_release(); + return atomic_dec_uint_nv(&inp->inp_refcount) == 0; +} + static inline void inpcb_register_overudp_cb(struct inpcb *inp, pcb_overudp_cb_t cb, void *arg) { @@ -340,15 +413,19 @@ struct in6_addr; * If vestigial entries exist for a table (TCP only) * the vestigial pointer is set. 
*/ +struct tcp_ports_iterator; + typedef struct vestigial_hooks { /* IPv4 hooks */ - void *(*init_ports4)(struct in_addr, u_int, int); + void *(*init_ports4)(struct in_addr, u_int, int, + struct tcp_ports_iterator *); int (*next_port4)(void *, struct vestigial_inpcb *); int (*lookup4)(struct in_addr, uint16_t, struct in_addr, uint16_t, struct vestigial_inpcb *); /* IPv6 hooks */ - void *(*init_ports6)(const struct in6_addr*, u_int, int); + void *(*init_ports6)(const struct in6_addr*, u_int, int, + struct tcp_ports_iterator *); int (*next_port6)(void *, struct vestigial_inpcb *); int (*lookup6)(const struct in6_addr *, uint16_t, const struct in6_addr *, uint16_t, diff --git a/sys/netinet/in_proto.c b/sys/netinet/in_proto.c index 0c7be4b2e704..180a9d374e2a 100644 --- a/sys/netinet/in_proto.c +++ b/sys/netinet/in_proto.c @@ -146,50 +146,16 @@ __KERNEL_RCSID(0, "$NetBSD: in_proto.c,v 1.131 2022/09/03 02:53:18 thorpej Exp $ DOMAIN_DEFINE(inetdomain); /* forward declare and add to link set */ -/* Wrappers to acquire kernel_lock. 
*/ - -PR_WRAP_CTLINPUT(rip_ctlinput) -PR_WRAP_CTLINPUT(udp_ctlinput) -PR_WRAP_CTLINPUT(tcp_ctlinput) - -#define rip_ctlinput rip_ctlinput_wrapper -#define udp_ctlinput udp_ctlinput_wrapper -#define tcp_ctlinput tcp_ctlinput_wrapper - -PR_WRAP_CTLOUTPUT(rip_ctloutput) -PR_WRAP_CTLOUTPUT(udp_ctloutput) -PR_WRAP_CTLOUTPUT(tcp_ctloutput) - -#define rip_ctloutput rip_ctloutput_wrapper -#define udp_ctloutput udp_ctloutput_wrapper -#define tcp_ctloutput tcp_ctloutput_wrapper - -#ifdef DCCP -PR_WRAP_CTLINPUT(dccp_ctlinput) -PR_WRAP_CTLOUTPUT(dccp_ctloutput) - -#define dccp_ctlinput dccp_ctlinput_wrapper -#define dccp_ctloutput dccp_ctloutput_wrapper -#endif - -#ifdef SCTP -PR_WRAP_CTLINPUT(sctp_ctlinput) -PR_WRAP_CTLOUTPUT(sctp_ctloutput) - -#define sctp_ctlinput sctp_ctlinput_wrapper -#define sctp_ctloutput sctp_ctloutput_wrapper -#endif +/* + * All per-socket-lock protocols use internal locking for ctlinput + * (mutex_obj_hold + refcount for single lookup, hand-over-hand + * pserialize + refcount for broadcast notify). No KERNEL_LOCK + * or softnet_lock wrappers needed. + */ #ifdef NET_MPSAFE -PR_WRAP_INPUT(udp_input) -PR_WRAP_INPUT(tcp_input) -#ifdef DCCP -PR_WRAP_INPUT(dccp_input) -#endif -#ifdef SCTP -PR_WRAP_INPUT(sctp_input) -#endif -PR_WRAP_INPUT(rip_input) +/* All per-socket-lock protocols acquire solock internally, no input wrappers. 
*/ +/* rip_input uses pserialize for TAILQ iteration + tryenter for solock */ #if NPFSYNC > 0 PR_WRAP_INPUT(pfsync_input) #endif @@ -198,11 +164,6 @@ PR_WRAP_INPUT(igmp_input) PR_WRAP_INPUT(pim_input) #endif -#define udp_input udp_input_wrapper -#define tcp_input tcp_input_wrapper -#define dccp_input dccp_input_wrapper -#define sctp_input sctp_input_wrapper -#define rip_input rip_input_wrapper #define pfsync_input pfsync_input_wrapper #define igmp_input igmp_input_wrapper #define pim_input pim_input_wrapper diff --git a/sys/netinet/in_var.h b/sys/netinet/in_var.h index 98a7f53ad407..bff9292fcc34 100644 --- a/sys/netinet/in_var.h +++ b/sys/netinet/in_var.h @@ -495,6 +495,13 @@ int sysctl_inpcblist(SYSCTLFN_PROTO); #define LLTABLE(ifp) \ ((struct in_ifinfo *)(ifp)->if_afdata[AF_INET])->ii_llt +static inline uint32_t +in_lltable_hash_dst(const struct in_addr dst, uint32_t hsize) +{ + uint32_t k = dst.s_addr; + return (((((((k >> 8) ^ k) >> 8) ^ k) >> 8) ^ k) & (hsize - 1)); +} + #endif /* !_KERNEL */ /* INET6 stuff */ diff --git a/sys/netinet/ip_icmp.c b/sys/netinet/ip_icmp.c index e627ea892bff..d9cdd484a0b4 100644 --- a/sys/netinet/ip_icmp.c +++ b/sys/netinet/ip_icmp.c @@ -1143,11 +1143,16 @@ void icmp_mtudisc(struct icmp *icp, struct in_addr faddr) { struct icmp_mtudisc_callback *mc; - struct sockaddr *dst = sintosa(&icmpsrc); + struct sockaddr_in sin; + struct sockaddr *dst; struct rtentry *rt; u_long mtu = ntohs(icp->icmp_nextmtu); /* Why a long? 
IPv6 */ int error; + /* Use a local sockaddr instead of the global icmpsrc */ + sockaddr_in_init(&sin, &faddr, 0); + dst = sintosa(&sin); + rt = rtalloc1(dst, 1); if (rt == NULL) return; diff --git a/sys/netinet/portalgo.c b/sys/netinet/portalgo.c index 16c357540fc5..c67c18255454 100644 --- a/sys/netinet/portalgo.c +++ b/sys/netinet/portalgo.c @@ -47,6 +47,7 @@ __KERNEL_RCSID(0, "$NetBSD: portalgo.c,v 1.15 2022/11/04 09:01:53 ozaki-r Exp $" #include #include #include +#include #include @@ -263,6 +264,8 @@ check_suitable_port(uint16_t port, struct inpcb *inp, kauth_cred_t cred) sin.sin_addr = in4p_laddr(inp); pcb = inpcb_lookup_local(table, sin.sin_addr, htons(port), 1, &vestigial); + if (pcb != NULL) + inp_lookup_unlock(pcb); DPRINTF("%s inpcb_lookup_local returned %p and " "vestigial.valid %d\n", @@ -318,6 +321,8 @@ check_suitable_port(uint16_t port, struct inpcb *inp, kauth_cred_t cred) t = inpcb_lookup_local(table, *(struct in_addr *)&sin6.sin6_addr.s6_addr32[3], htons(port), wild, &vestigial); + if (t != NULL) + inp_lookup_unlock((struct inpcb *)t); if (!t && vestigial.valid) { DPRINTF("%s inpcb_lookup_local returned " "a result\n", __func__); @@ -328,6 +333,8 @@ check_suitable_port(uint16_t port, struct inpcb *inp, kauth_cred_t cred) { t = in6pcb_lookup_local(table, &sin6.sin6_addr, htons(port), wild, &vestigial); + if (t != NULL) + inp_lookup_unlock((struct inpcb *)t); if (!t && vestigial.valid) { DPRINTF("%s in6pcb_lookup_local returned " "a result\n", __func__); diff --git a/sys/netinet/raw_ip.c b/sys/netinet/raw_ip.c index 23f2647fe991..f5c7bea345f6 100644 --- a/sys/netinet/raw_ip.c +++ b/sys/netinet/raw_ip.c @@ -168,7 +168,7 @@ rip_input(struct mbuf *m, int off, int proto) struct inpcb *last = NULL; struct mbuf *n; struct sockaddr_in ripsrc; - int hlen; + int hlen, s; sockaddr_in_init(&ripsrc, &ip->ip_src, 0); @@ -181,7 +181,15 @@ rip_input(struct mbuf *m, int off, int proto) ip->ip_len = ntohs(ip->ip_len) - hlen; NTOHS(ip->ip_off); - TAILQ_FOREACH(inp, 
&rawcbtable.inpt_queue, inp_queue) { + /* + * pserialize protects the PSLIST iteration from concurrent + * insert/remove in inpcb_create/inpcb_destroy. Use tryenter + * for solock to avoid blocking inside the pserialize section. + * On tryenter failure, drop (rare contention). + */ + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, &rawcbtable.inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET) continue; if (in4p_ip(inp).ip_p && in4p_ip(inp).ip_p != proto) @@ -202,7 +210,16 @@ rip_input(struct mbuf *m, int off, int proto) } #endif else if ((n = m_copypacket(m, M_DONTWAIT)) != NULL) { - rip_sbappendaddr(last, ip, sintosa(&ripsrc), hlen, n); + bool locked = solocked(last->inp_socket); + if (locked || mutex_tryenter(last->inp_socket->so_lock)) { + rip_sbappendaddr(last, ip, sintosa(&ripsrc), + hlen, n); + if (!locked) + mutex_exit(last->inp_socket->so_lock); + } else { + /* tryenter failed, drop (rare contention) */ + m_freem(n); + } } last = inp; @@ -210,24 +227,51 @@ rip_input(struct mbuf *m, int off, int proto) #if defined(IPSEC) if (ipsec_used && last != NULL && ipsec_in_reject(m, last)) { + pserialize_read_exit(s); m_freem(m); IP_STATDEC(IP_STAT_DELIVERED); /* do not inject data into pcb */ } else #endif if (last != NULL) { - rip_sbappendaddr(last, ip, sintosa(&ripsrc), hlen, m); - } else if (inetsw[ip_protox[ip->ip_p]].pr_input == rip_input) { - net_stat_ref_t ips; - - icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_PROTOCOL, - 0, 0); - ips = IP_STAT_GETREF(); - _NET_STATINC_REF(ips, IP_STAT_NOPROTO); - _NET_STATDEC_REF(ips, IP_STAT_DELIVERED); - IP_STAT_PUTREF(); + bool locked = solocked(last->inp_socket); + if (locked || mutex_tryenter(last->inp_socket->so_lock)) { + rip_sbappendaddr(last, ip, sintosa(&ripsrc), hlen, m); + if (!locked) + mutex_exit(last->inp_socket->so_lock); + pserialize_read_exit(s); + } else if (inpcb_ref_acquire(last)) { + kmutex_t *lock = last->inp_socket->so_lock; + mutex_obj_hold(lock); + 
pserialize_read_exit(s); + mutex_enter(lock); + if (last->inp_state != INP_FREED) + rip_sbappendaddr(last, ip, sintosa(&ripsrc), + hlen, m); + else + m_freem(m); + mutex_exit(lock); + mutex_obj_free(lock); + if (inpcb_ref_release(last)) + inpcb_pool_put(last); + } else { + pserialize_read_exit(s); + m_freem(m); + } } else { - m_freem(m); + pserialize_read_exit(s); + if (inetsw[ip_protox[ip->ip_p]].pr_input == rip_input) { + net_stat_ref_t ips; + + icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_PROTOCOL, + 0, 0); + ips = IP_STAT_GETREF(); + _NET_STATINC_REF(ips, IP_STAT_NOPROTO); + _NET_STATDEC_REF(ips, IP_STAT_DELIVERED); + IP_STAT_PUTREF(); + } else { + m_freem(m); + } } return; @@ -238,21 +282,44 @@ rip_pcbnotify(struct inpcbtable *table, struct in_addr faddr, struct in_addr laddr, int proto, int errno, void (*notify)(struct inpcb *, int)) { - struct inpcb *inp; - int nmatch; + struct inpcb *inp, *next; + int s, nmatch; nmatch = 0; - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { - if (inp->inp_af != AF_INET) + s = pserialize_read_enter(); + for (inp = PSLIST_READER_FIRST(&table->inpt_queue_pslist, + struct inpcb, inp_queue_hash); inp != NULL; inp = next) { + if (inp->inp_af != AF_INET || + (in4p_ip(inp).ip_p && in4p_ip(inp).ip_p != proto) || + !in_hosteq(in4p_faddr(inp), faddr) || + !in_hosteq(in4p_laddr(inp), laddr)) { + next = PSLIST_READER_NEXT(inp, struct inpcb, + inp_queue_hash); continue; - if (in4p_ip(inp).ip_p && in4p_ip(inp).ip_p != proto) - continue; - if (in_hosteq(in4p_faddr(inp), faddr) && - in_hosteq(in4p_laddr(inp), laddr)) { - (*notify)(inp, errno); - nmatch++; + } + next = PSLIST_READER_NEXT(inp, struct inpcb, inp_queue_hash); + while (next != NULL && !inpcb_ref_acquire(next)) + next = PSLIST_READER_NEXT(next, struct inpcb, + inp_queue_hash); + if (!inpcb_ref_acquire(inp)) { + if (next != NULL && inpcb_ref_release(next)) + inpcb_pool_put(next); + break; + } + pserialize_read_exit(s); + + (*notify)(inp, errno); + nmatch++; + if 
(inpcb_ref_release(inp)) + inpcb_pool_put(inp); + + s = pserialize_read_enter(); + if (next != NULL) { + bool last __diagused = inpcb_ref_release(next); + KASSERT(!last); } } + pserialize_read_exit(s); return nmatch; } @@ -513,7 +580,16 @@ rip_attach(struct socket *so, int proto) int error; KASSERT(sotoinpcb(so) == NULL); - sosetlock(so); + + /* + * Assign a per-socket lock. Each raw IP socket gets its own + * mutex so that input on different CPUs can be processed in + * parallel. + */ + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } if (so->so_snd.sb_hiwat == 0 || so->so_rcv.sb_hiwat == 0) { error = soreserve(so, rip_sendspace, rip_recvspace); @@ -829,6 +905,15 @@ rip_purgeif(struct socket *so, struct ifnet *ifp) return 0; } +#ifdef NET_MPSAFE +/* + * Raw IP is fully MP-safe with per-socket locking. All usrreqs run + * under solock (held by the socket layer), so KERNEL_LOCK wrappers + * must NOT be used, acquiring KERNEL_LOCK inside solock inverts + * the lock order vs ctlinput (which takes KERNEL_LOCK then solock). + */ +/* No PR_WRAP_USRREQS all functions used directly */ +#else PR_WRAP_USRREQS(rip) #define rip_attach rip_attach_wrapper #define rip_detach rip_detach_wrapper @@ -849,6 +934,7 @@ PR_WRAP_USRREQS(rip) #define rip_send rip_send_wrapper #define rip_sendoob rip_sendoob_wrapper #define rip_purgeif rip_purgeif_wrapper +#endif const struct pr_usrreqs rip_usrreqs = { .pr_attach = rip_attach, diff --git a/sys/netinet/sctp_input.c b/sys/netinet/sctp_input.c index b01e834bb264..034cc8cecba1 100644 --- a/sys/netinet/sctp_input.c +++ b/sys/netinet/sctp_input.c @@ -4253,9 +4253,18 @@ sctp_input(struct mbuf *m, int off, int proto) offset -= sizeof(struct sctp_chunkhdr); ecn_bits = ip->ip_tos; + /* + * Acquire solock for the duration of input processing. + * sbappend/sorwakeup require solock on NetBSD. SCTP internal + * locks are no-ops on NetBSD, so solock is the sole serializer. 
+ */ + if (inp->sctp_socket) + solock(inp->sctp_socket); sctp_common_input_processing(&m, iphlen, offset, length, sh, ch, inp, stcb, net, ecn_bits); /* inp's ref-count reduced && stcb unlocked */ + if (inp->sctp_socket) + sounlock(inp->sctp_socket); sctp_m_freem(m); sctp_m_freem(opts); diff --git a/sys/netinet/sctp_usrreq.c b/sys/netinet/sctp_usrreq.c index ad6ccf1c3c8f..8f2098ce7476 100644 --- a/sys/netinet/sctp_usrreq.c +++ b/sys/netinet/sctp_usrreq.c @@ -512,7 +512,10 @@ sctp_attach(struct socket *so, int proto) #endif int error; - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } inp = (struct sctp_inpcb *)so->so_pcb; if (inp != 0) { return EINVAL; @@ -571,14 +574,14 @@ sctp_bind(struct socket *so, struct sockaddr *nam, struct lwp *l) } -static int +static void sctp_detach(struct socket *so) { struct sctp_inpcb *inp; inp = (struct sctp_inpcb *)so->so_pcb; if (inp == 0) - return EINVAL; + return; if (((so->so_options & SO_LINGER) && (so->so_linger == 0)) || (so->so_rcv.sb_cc > 0)) { @@ -586,7 +589,6 @@ sctp_detach(struct socket *so) } else { sctp_inpcb_free(inp, 0); } - return 0; } static int @@ -4009,6 +4011,15 @@ sysctl_net_inet_sctp_setup(struct sysctllog **clog) #endif } +#ifdef NET_MPSAFE +/* + * SCTP uses per-socket locks. The input path acquires solock before + * calling sctp_common_input_processing (which calls sbappend/sorwakeup). + * SCTP internal locks are currently no-ops on NetBSD, so solock is the + * sole serializer. No KERNEL_LOCK wrappers needed. 
+ */ +/* No PR_WRAP_USRREQS all functions used directly */ +#else PR_WRAP_USRREQS(sctp) #define sctp_attach sctp_attach_wrapper #define sctp_detach sctp_detach_wrapper @@ -4029,6 +4040,7 @@ PR_WRAP_USRREQS(sctp) #define sctp_send sctp_send_wrapper #define sctp_sendoob sctp_sendoob_wrapper #define sctp_purgeif sctp_purgeif_wrapper +#endif const struct pr_usrreqs sctp_usrreqs = { .pr_attach = sctp_attach, diff --git a/sys/netinet/tcp_input.c b/sys/netinet/tcp_input.c index 31bba8251b99..997d5cd49d00 100644 --- a/sys/netinet/tcp_input.c +++ b/sys/netinet/tcp_input.c @@ -1197,6 +1197,7 @@ tcp_input(struct mbuf *m, int off, int proto) struct tcpcb *tp = NULL; int tiflags; struct socket *so = NULL; + kmutex_t *inp_lock = NULL; /* saved per-socket lock */ int todrop, acked, ourfinisacked, needoutput = 0; bool dupseg; #ifdef TCP_DEBUG @@ -1412,8 +1413,10 @@ findpcb: } #if defined(IPSEC) if (ipsec_used) { - if (inp && ipsec_in_reject(m, inp)) + if (inp && ipsec_in_reject(m, inp)) { + sounlock(inp->inp_socket); goto drop; + } } #endif /*IPSEC*/ break; @@ -1444,8 +1447,10 @@ findpcb: goto dropwithreset_ratelim; } #if defined(IPSEC) - if (ipsec_used && inp && ipsec_in_reject(m, inp)) + if (ipsec_used && inp && ipsec_in_reject(m, inp)) { + sounlock(inp->inp_socket); goto drop; + } #endif break; } @@ -1463,12 +1468,24 @@ findpcb: tp = NULL; so = NULL; if (inp) { + /* + * inp is returned with solock held from lookup. + */ + so = inp->inp_socket; + /* Check the minimum TTL for socket. */ - if (inp->inp_af == AF_INET && ip->ip_ttl < in4p_ip_minttl(inp)) + if (inp->inp_af == AF_INET && ip->ip_ttl < in4p_ip_minttl(inp)) { + sounlock(so); goto drop; + } tp = intotcpcb(inp); - so = inp->inp_socket; + /* + * Save the per-socket lock pointer for the final + * unlock because tcp_close may free the socket. + */ + inp_lock = so->so_lock; + mutex_obj_hold(inp_lock); } else if (vestige.valid) { /* We do not support the resurrection of vtw tcpcps. 
*/ tcp_vtw_input(th, &vestige, m, tlen); @@ -1481,7 +1498,6 @@ findpcb: if (tp->t_state == TCPS_CLOSED) goto drop; - KASSERT(so->so_lock == softnet_lock); KASSERT(solocked(so)); /* Unscale the window into a 32-bit value. */ @@ -1901,10 +1917,12 @@ after_listen: sowwakeup(so); if (so->so_snd.sb_cc) { - KERNEL_LOCK(1, NULL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); (void)tcp_output(tp); - KERNEL_UNLOCK_ONE(NULL); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); } + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); m_freem(tcp_saveti); return; } @@ -2011,10 +2029,12 @@ after_listen: sorwakeup(so); tcp_setup_ack(tp, th); if (tp->t_flags & TF_ACKNOW) { - KERNEL_LOCK(1, NULL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); (void)tcp_output(tp); - KERNEL_UNLOCK_ONE(NULL); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); } + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); m_freem(tcp_saveti); return; } @@ -2063,8 +2083,10 @@ after_listen: SEQ_GT(th->th_ack, tp->snd_max))) goto dropwithreset; if (tiflags & TH_RST) { - if (tiflags & TH_ACK) + if (tiflags & TH_ACK) { tp = tcp_drop(tp, ECONNREFUSED); + /* tcp_close released the lock */ + } goto drop; } if ((tiflags & TH_SYN) == 0) @@ -2302,6 +2324,11 @@ after_listen: tp->t_state == TCPS_TIME_WAIT && SEQ_GT(th->th_seq, tp->rcv_nxt)) { tp = tcp_close(tp); + /* tcp_close returned with lock held; release it */ + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); + inp_lock = NULL; + so = NULL; tcp_fields_to_net(th); m_freem(tcp_saveti); tcp_saveti = NULL; @@ -2406,6 +2433,8 @@ after_listen: if (tp->rcv_nxt == th->th_seq) { tcp_respond(tp, m, m, th, (tcp_seq)0, th->th_ack - 1, TH_ACK); + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); m_freem(tcp_saveti); return; } @@ -2497,9 +2526,9 @@ after_listen: goto drop; } else if (tp->t_dupacks > tcprexmtthresh) { tp->snd_cwnd += tp->t_segsz; - KERNEL_LOCK(1, NULL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); (void)tcp_output(tp); - KERNEL_UNLOCK_ONE(NULL); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); goto drop; } } else { @@ -2855,9 +2884,9 @@ 
dodata: * Return any desired output. */ if (needoutput || (tp->t_flags & TF_ACKNOW)) { - KERNEL_LOCK(1, NULL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); (void)tcp_output(tp); - KERNEL_UNLOCK_ONE(NULL); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); } m_freem(tcp_saveti); @@ -2867,8 +2896,13 @@ dodata: && ((af == AF_INET ? tcp4_vtw_enable : tcp6_vtw_enable) & 1) != 0 && TAILQ_EMPTY(&tp->segq) && vtw_add(af, tp)) { - ; + /* vtw_add -> tcp_close returned with lock held; release */ + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); + return; } + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); return; badsyn: @@ -2901,9 +2935,11 @@ dropafterack_ratelim: dropafterack2: m_freem(m); tp->t_flags |= TF_ACKNOW; - KERNEL_LOCK(1, NULL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); (void)tcp_output(tp); - KERNEL_UNLOCK_ONE(NULL); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); m_freem(tcp_saveti); return; @@ -2935,6 +2971,10 @@ dropwithreset: (void)tcp_respond(tp, m, m, th, th->th_seq + tlen, (tcp_seq)0, TH_RST|TH_ACK); } + if (inp_lock != NULL) { + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); + } m_freem(tcp_saveti); return; @@ -2944,12 +2984,15 @@ drop: * Drop space held by incoming segment and return. */ if (tp) { - so = tp->t_inpcb->inp_socket; #ifdef TCP_DEBUG if (so && (so->so_options & SO_DEBUG) != 0) tcp_trace(TA_DROP, ostate, tp, tcp_saveti, 0); #endif } + if (inp_lock != NULL) { + mutex_exit(inp_lock); + mutex_obj_free(inp_lock); + } m_freem(tcp_saveti); m_freem(m); return; diff --git a/sys/netinet/tcp_subr.c b/sys/netinet/tcp_subr.c index 38dc53a1d1b0..edd6d63eb135 100644 --- a/sys/netinet/tcp_subr.c +++ b/sys/netinet/tcp_subr.c @@ -959,6 +959,7 @@ tcp_newtcpcb(int family, struct inpcb *inp) LIST_INIT(&tp->t_sc); /* XXX can template this */ /* Don't sweat this loop; hopefully the compiler will unroll it. 
*/ + tp->t_timer_armed = 0; for (i = 0; i < TCPT_NTIMERS; i++) { callout_init(&tp->t_timer[i], CALLOUT_MPSAFE); TCP_TIMER_INIT(tp, i); @@ -1151,20 +1152,63 @@ tcp_close(struct tcpcb *tp) tp->t_flags |= TF_DEAD; inp->inp_ppcb = NULL; soisdisconnected(so); - inpcb_destroy(inp); + + /* + * Sever the socket->PCB link before halting timers. + * callout_halt drops so->so_lock temporarily, during which + * a concurrent soclose on another CPU could run. With + * so_pcb == NULL, soclose's pr_detach (tcp_detach) will + * see no PCB and return immediately. + * + * Hold a reference on the lock to keep it alive if soclose + * frees the socket (and its lock reference) during the + * callout_halt window. + */ + so->so_pcb = NULL; + inp->inp_socket = NULL; + { + kmutex_t *lock = so->so_lock; + mutex_obj_hold(lock); + /* - * pcb is no longer visble elsewhere, so we can safely release - * the lock in callout_halt() if needed. + * Halt all timers BEFORE inpcb_destroy. Timer callbacks + * use the same lock (tp->t_lock), so callout_halt correctly + * waits for any in-progress callback to complete. + * + * Note: callout_halt drops the lock while waiting. A concurrent + * soclose may run and free the socket during this window. + * Our mutex_obj_hold keeps the lock object alive regardless. */ TCP_STATINC(TCP_STAT_CLOSED); for (j = 0; j < TCPT_NTIMERS; j++) { - callout_halt(&tp->t_timer[j], softnet_lock); + callout_halt(&tp->t_timer[j], lock); callout_destroy(&tp->t_timer[j]); } - callout_halt(&tp->t_delack_ch, softnet_lock); + callout_halt(&tp->t_delack_ch, lock); callout_destroy(&tp->t_delack_ch); + + /* + * Destroy the inpcb. inp->inp_socket is NULL (set above), + * so inpcb_destroy skips all socket interaction. + */ + inpcb_destroy(inp); pool_put(&tcpcb_pool, tp); + /* + * Return with lock held. Our mutex_obj_hold keeps the lock + * alive even if concurrent soclose freed the socket during + * callout_halt. Release the extra reference but keep the + * lock held. 
+ * + * Callers handle unlock: + * - soclose (pr_detach): lock held as expected; continues + * to set SS_NOFDREF and call sofree. + * - tcp_input: checks tp == NULL, skips sounlock. + * - timers: checks tp == NULL, skips mutex_exit/free. + */ + mutex_obj_free(lock); + } + return NULL; } @@ -1212,6 +1256,7 @@ tcp_drain(void) { struct inpcb *inp; struct tcpcb *tp; + int s; mutex_enter(softnet_lock); KERNEL_LOCK(1, NULL); @@ -1219,7 +1264,9 @@ tcp_drain(void) /* * Free the sequence queue of all TCP connections. */ - TAILQ_FOREACH(inp, &tcbtable.inpt_queue, inp_queue) { + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, &tcbtable.inpt_queue_pslist, + struct inpcb, inp_queue_hash) { tp = intotcpcb(inp); if (tp != NULL) { /* @@ -1233,6 +1280,7 @@ tcp_drain(void) TCP_REASS_UNLOCK(tp); } } + pserialize_read_exit(s); KERNEL_UNLOCK_ONE(NULL); mutex_exit(softnet_lock); @@ -1249,6 +1297,14 @@ tcp_notify(struct inpcb *inp, int error) struct tcpcb *tp = (struct tcpcb *)inp->inp_ppcb; struct socket *so = inp->inp_socket; + /* + * Acquire the per-socket lock. TCP sockets use per-socket + * locks, so we must hold it for sorwakeup/sowwakeup which + * assert solocked(so). This is called from both direct + * lookup paths and inpcb_notifyall iterators. + */ + solock(so); + /* * Ignore some errors if we are hooked up. 
* If connection hasn't completed, has retransmitted several times, @@ -1259,6 +1315,7 @@ tcp_notify(struct inpcb *inp, int error) if (tp->t_state == TCPS_ESTABLISHED && (error == EHOSTUNREACH || error == ENETUNREACH || error == EHOSTDOWN)) { + sounlock(so); return; } else if (TCPS_HAVEESTABLISHED(tp->t_state) == 0 && tp->t_rxtshift > 3 && tp->t_softerror) @@ -1268,6 +1325,7 @@ tcp_notify(struct inpcb *inp, int error) cv_broadcast(&so->so_cv); sorwakeup(so); sowwakeup(so); + sounlock(so); } #ifdef INET6 @@ -1276,7 +1334,7 @@ tcp6_ctlinput(int cmd, const struct sockaddr *sa, void *d) { struct tcphdr th; void (*notify)(struct inpcb *, int) = tcp_notify; - int nmatch; + struct inpcb *inp; struct ip6_hdr *ip6; const struct sockaddr_in6 *sa6_src = NULL; const struct sockaddr_in6 *sa6 = (const struct sockaddr_in6 *)sa; @@ -1336,11 +1394,14 @@ tcp6_ctlinput(int cmd, const struct sockaddr *sa, void *d) * corresponding to the address in the ICMPv6 message * payload. */ - if (in6pcb_lookup(&tcbtable, &sa6->sin6_addr, + inp = in6pcb_lookup(&tcbtable, &sa6->sin6_addr, th.th_dport, (const struct in6_addr *)&sa6_src->sin6_addr, - th.th_sport, 0, 0)) + th.th_sport, 0, NULL); + if (inp != NULL) { + sounlock(inp->inp_socket); valid++; + } /* * Depending on the value of "valid" and routing table @@ -1358,14 +1419,26 @@ tcp6_ctlinput(int cmd, const struct sockaddr *sa, void *d) return NULL; } - nmatch = in6pcb_notify(&tcbtable, sa, th.th_dport, - (const struct sockaddr *)sa6_src, th.th_sport, cmd, NULL, notify); - if (nmatch == 0 && syn_cache_count && + inp = in6pcb_lookup(&tcbtable, &sa6->sin6_addr, + th.th_dport, + (const struct in6_addr *)&sa6_src->sin6_addr, + th.th_sport, 0, NULL); + if (inp != NULL) { + kmutex_t *lock = inp->inp_socket->so_lock; + mutex_obj_hold(lock); + (void)inpcb_ref_acquire(inp); + mutex_exit(lock); + (*notify)(inp, inet6ctlerrmap[cmd]); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + mutex_obj_free(lock); + } else if (syn_cache_count && 
(inet6ctlerrmap[cmd] == EHOSTUNREACH || inet6ctlerrmap[cmd] == ENETUNREACH || - inet6ctlerrmap[cmd] == EHOSTDOWN)) + inet6ctlerrmap[cmd] == EHOSTDOWN)) { syn_cache_unreach((const struct sockaddr *)sa6_src, sa, &th); + } } else { (void) in6pcb_notify(&tcbtable, sa, 0, (const struct sockaddr *)sa6_src, 0, cmd, NULL, notify); @@ -1385,7 +1458,6 @@ tcp_ctlinput(int cmd, const struct sockaddr *sa, void *v) extern const int inetctlerrmap[]; void (*notify)(struct inpcb *, int) = tcp_notify; int errno; - int nmatch; struct tcpcb *tp; u_int mtu; tcp_seq seq; @@ -1422,16 +1494,18 @@ tcp_ctlinput(int cmd, const struct sockaddr *sa, void *v) in6_in_2_v4mapin6(&ip->ip_dst, &dst6); #endif if ((inp = inpcb_lookup(&tcbtable, ip->ip_dst, - th->th_dport, ip->ip_src, th->th_sport, 0)) != NULL) + th->th_dport, ip->ip_src, th->th_sport, NULL)) != NULL) ; #ifdef INET6 else if ((inp = in6pcb_lookup(&tcbtable, &dst6, - th->th_dport, &src6, th->th_sport, 0, 0)) != NULL) + th->th_dport, &src6, th->th_sport, 0, NULL)) != NULL) ; #endif else return NULL; + /* solock already held from lookup */ + /* * Now that we've validated that we are actually communicating * with the host indicated in the ICMP message, locate the @@ -1441,26 +1515,39 @@ tcp_ctlinput(int cmd, const struct sockaddr *sa, void *v) icp = (struct icmp *)((char *)ip - offsetof(struct icmp, icmp_ip)); tp = intotcpcb(inp); - if (tp == NULL) + if (tp == NULL) { + sounlock(inp->inp_socket); return NULL; + } seq = ntohl(th->th_seq); - if (SEQ_LT(seq, tp->snd_una) || SEQ_GT(seq, tp->snd_max)) + if (SEQ_LT(seq, tp->snd_una) || SEQ_GT(seq, tp->snd_max)) { + sounlock(inp->inp_socket); return NULL; + } /* * If the ICMP message advertises a Next-Hop MTU * equal or larger than the maximum packet size we have * ever sent, drop the message. 
*/ mtu = (u_int)ntohs(icp->icmp_nextmtu); - if (mtu >= tp->t_pmtud_mtu_sent) + if (mtu >= tp->t_pmtud_mtu_sent) { + sounlock(inp->inp_socket); return NULL; + } if (mtu >= tcp_hdrsz(tp) + tp->t_pmtud_mss_acked) { /* * Calculate new MTU, and create corresponding * route (traditional PMTUD). */ + struct in_addr dst = ip->ip_dst; tp->t_flags &= ~TF_PMTUD_PEND; - icmp_mtudisc(icp, ip->ip_dst); + sounlock(inp->inp_socket); + /* + * Call icmp_mtudisc outside solock to avoid + * nested lock: solock -> icmp_mtx -> solock + * (via tcp_mtudisc_callback -> inpcb_notifyall). + */ + icmp_mtudisc(icp, dst); } else { /* * Record the information got in the ICMP @@ -1470,14 +1557,17 @@ tcp_ctlinput(int cmd, const struct sockaddr *sa, void *v) * refers to an older TCP segment */ if (tp->t_flags & TF_PMTUD_PEND) { - if (SEQ_LT(tp->t_pmtud_th_seq, seq)) + if (SEQ_LT(tp->t_pmtud_th_seq, seq)) { + sounlock(inp->inp_socket); return NULL; + } } else tp->t_flags |= TF_PMTUD_PEND; tp->t_pmtud_th_seq = seq; tp->t_pmtud_nextmtu = icp->icmp_nextmtu; tp->t_pmtud_ip_len = icp->icmp_ip.ip_len; tp->t_pmtud_ip_hl = icp->icmp_ip.ip_hl; + sounlock(inp->inp_socket); } return NULL; } else if (cmd == PRC_HOSTDEAD) @@ -1486,9 +1576,19 @@ tcp_ctlinput(int cmd, const struct sockaddr *sa, void *v) return NULL; if (ip && ip->ip_v == 4 && sa->sa_family == AF_INET) { th = (struct tcphdr *)((char *)ip + (ip->ip_hl << 2)); - nmatch = inpcb_notify(&tcbtable, satocsin(sa)->sin_addr, - th->th_dport, ip->ip_src, th->th_sport, errno, notify); - if (nmatch == 0 && syn_cache_count && + inp = inpcb_lookup(&tcbtable, satocsin(sa)->sin_addr, + th->th_dport, ip->ip_src, th->th_sport, NULL); + if (inp != NULL) { + kmutex_t *lock = inp->inp_socket->so_lock; + mutex_obj_hold(lock); + bool acquired __diagused = inpcb_ref_acquire(inp); + KASSERT(acquired); + mutex_exit(lock); + (*notify)(inp, errno); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + mutex_obj_free(lock); + } else if (syn_cache_count && (inetctlerrmap[cmd] 
== EHOSTUNREACH || inetctlerrmap[cmd] == ENETUNREACH || inetctlerrmap[cmd] == EHOSTDOWN)) { @@ -1527,6 +1627,22 @@ tcp_quench(struct inpcb *inp) /* * Path MTU Discovery handlers. */ + +/* + * Wrapper for tcp_mtudisc used as inpcb_notifyall callback. + * Acquires the per-socket lock before calling tcp_mtudisc + * (which asserts solocked). + */ +static void +tcp_mtudisc_locked(struct inpcb *inp, int errno) +{ + struct socket *so = inp->inp_socket; + + solock(so); + tcp_mtudisc(inp, errno); + sounlock(so); +} + void tcp_mtudisc_callback(struct in_addr faddr) { @@ -1534,7 +1650,7 @@ tcp_mtudisc_callback(struct in_addr faddr) struct in6_addr in6; #endif - inpcb_notifyall(&tcbtable, faddr, EMSGSIZE, tcp_mtudisc); + inpcb_notifyall(&tcbtable, faddr, EMSGSIZE, tcp_mtudisc_locked); #ifdef INET6 in6_in_2_v4mapin6(&faddr, &in6); tcp6_mtudisc_callback(&in6); @@ -1550,11 +1666,13 @@ void tcp_mtudisc(struct inpcb *inp, int errno) { struct tcpcb *tp = intotcpcb(inp); + struct socket *so = inp->inp_socket; struct rtentry *rt; if (tp == NULL) return; + KASSERT(solocked(so)); rt = inpcb_rtentry(inp); if (rt != NULL) { /* @@ -1593,6 +1711,16 @@ tcp_mtudisc(struct inpcb *inp, int errno) /* * Path MTU Discovery handlers. 
*/ +static void +tcp6_mtudisc_locked(struct inpcb *inp, int errno) +{ + struct socket *so = inp->inp_socket; + + solock(so); + tcp6_mtudisc(inp, errno); + sounlock(so); +} + void tcp6_mtudisc_callback(struct in6_addr *faddr) { @@ -1603,18 +1731,20 @@ tcp6_mtudisc_callback(struct in6_addr *faddr) sin6.sin6_len = sizeof(struct sockaddr_in6); sin6.sin6_addr = *faddr; (void) in6pcb_notify(&tcbtable, (struct sockaddr *)&sin6, 0, - (const struct sockaddr *)&sa6_any, 0, PRC_MSGSIZE, NULL, tcp6_mtudisc); + (const struct sockaddr *)&sa6_any, 0, PRC_MSGSIZE, NULL, tcp6_mtudisc_locked); } void tcp6_mtudisc(struct inpcb *inp, int errno) { struct tcpcb *tp = intotcpcb(inp); + struct socket *so = inp->inp_socket; struct rtentry *rt; if (tp == NULL) return; + KASSERT(solocked(so)); rt = in6pcb_rtentry(inp); if (rt != NULL) { /* diff --git a/sys/netinet/tcp_syncache.c b/sys/netinet/tcp_syncache.c index db5b0c39f106..d981f19d810f 100644 --- a/sys/netinet/tcp_syncache.c +++ b/sys/netinet/tcp_syncache.c @@ -164,9 +164,11 @@ __KERNEL_RCSID(0, "$NetBSD: tcp_syncache.c,v 1.7 2024/06/29 12:59:08 riastradh E #include #include #include +#include #include #include #include /* for lwp0 */ +#include #include #include @@ -277,16 +279,23 @@ syn_cache_timer_arm(struct syn_cache *sc) #define SYN_CACHE_TIMESTAMP(sc) (tcp_now - (sc)->sc_timebase) +/* + * Remove a syn cache entry from its bucket and tcpcb list. + * Caller must hold the bucket mutex (sch_mtx). + */ static inline void syn_cache_rm(struct syn_cache *sc) { + + KASSERT(mutex_owned(&tcp_syn_cache[sc->sc_bucketidx].sch_mtx)); + TAILQ_REMOVE(&tcp_syn_cache[sc->sc_bucketidx].sch_bucket, sc, sc_bucketq); sc->sc_tp = NULL; LIST_REMOVE(sc, sc_tpq); tcp_syn_cache[sc->sc_bucketidx].sch_length--; callout_stop(&sc->sc_timer); - syn_cache_count--; + atomic_dec_ulong(&syn_cache_count); } static inline void @@ -309,8 +318,10 @@ syn_cache_init(void) "synpl", NULL, IPL_SOFTNET); /* Initialize the hash buckets. 
*/ - for (i = 0; i < tcp_syn_cache_size; i++) + for (i = 0; i < tcp_syn_cache_size; i++) { + mutex_init(&tcp_syn_cache[i].sch_mtx, MUTEX_DEFAULT, IPL_SOFTNET); TAILQ_INIT(&tcp_syn_cache[i].sch_bucket); + } } void @@ -318,13 +329,13 @@ syn_cache_insert(struct syn_cache *sc, struct tcpcb *tp) { struct syn_cache_head *scp; struct syn_cache *sc2; - int s; /* * If there are no entries in the hash table, reinitialize - * the hash secrets. + * the hash secrets. This is racy but harmless worst case + * we reinitialize twice at startup. */ - if (syn_cache_count == 0) { + if (atomic_load_relaxed(&syn_cache_count) == 0) { syn_hash1 = cprng_fast32(); syn_hash2 = cprng_fast32(); } @@ -337,7 +348,7 @@ syn_cache_insert(struct syn_cache *sc, struct tcpcb *tp) * Make sure that we don't overflow the per-bucket * limit or the total cache size limit. */ - s = splsoftnet(); + mutex_enter(&scp->sch_mtx); if (scp->sch_length >= tcp_syn_bucket_limit) { TCP_STATINC(TCP_STAT_SC_BUCKETOVERFLOW); /* @@ -354,8 +365,9 @@ syn_cache_insert(struct syn_cache *sc, struct tcpcb *tp) panic("syn_cache_insert: bucketoverflow: impossible"); #endif syn_cache_rm(sc2); - syn_cache_put(sc2); /* calls pool_put but see spl above */ - } else if (syn_cache_count >= tcp_syn_cache_limit) { + syn_cache_put(sc2); + } else if (atomic_load_relaxed(&syn_cache_count) >= + (u_long)tcp_syn_cache_limit) { struct syn_cache_head *scp2, *sce; TCP_STATINC(TCP_STAT_SC_OVERFLOWED); @@ -366,29 +378,38 @@ syn_cache_insert(struct syn_cache *sc, struct tcpcb *tp) * XXX We would really like to toss the oldest * entry in the cache, but we hope that this * condition doesn't happen very often. + * + * Use trylock when scanning other buckets to avoid + * lock-order issues between bucket mutexes. */ - scp2 = scp; - if (TAILQ_EMPTY(&scp2->sch_bucket)) { + sc2 = NULL; + if (!TAILQ_EMPTY(&scp->sch_bucket)) { + /* Our own bucket is non-empty and already locked. 
*/ + sc2 = TAILQ_FIRST(&scp->sch_bucket); + syn_cache_rm(sc2); + syn_cache_put(sc2); + } else { sce = &tcp_syn_cache[tcp_syn_cache_size]; - for (++scp2; scp2 != scp; scp2++) { + for (scp2 = scp + 1; scp2 != scp; scp2++) { if (scp2 >= sce) scp2 = &tcp_syn_cache[0]; - if (! TAILQ_EMPTY(&scp2->sch_bucket)) + if (!mutex_tryenter(&scp2->sch_mtx)) + continue; + if (!TAILQ_EMPTY(&scp2->sch_bucket)) { + sc2 = TAILQ_FIRST(&scp2->sch_bucket); + syn_cache_rm(sc2); + mutex_exit(&scp2->sch_mtx); + syn_cache_put(sc2); break; + } + mutex_exit(&scp2->sch_mtx); } -#ifdef DIAGNOSTIC /* - * This should never happen; we should always find a - * non-empty bucket. + * If all other buckets were empty or locked, + * proceed anyway. The global count is approximate + * under per-bucket locking and will self-correct. */ - if (scp2 == scp) - panic("syn_cache_insert: cacheoverflow: " - "impossible"); -#endif } - sc2 = TAILQ_FIRST(&scp2->sch_bucket); - syn_cache_rm(sc2); - syn_cache_put(sc2); /* calls pool_put but see spl above */ } /* @@ -404,10 +425,10 @@ syn_cache_insert(struct syn_cache *sc, struct tcpcb *tp) /* Put it into the bucket. 
*/ TAILQ_INSERT_TAIL(&scp->sch_bucket, sc, sc_bucketq); scp->sch_length++; - syn_cache_count++; + atomic_inc_ulong(&syn_cache_count); TCP_STATINC(TCP_STAT_SC_ADDED); - splx(s); + mutex_exit(&scp->sch_mtx); } /* @@ -419,9 +440,10 @@ static void syn_cache_timer(void *arg) { struct syn_cache *sc = arg; + struct syn_cache_head *scp; - mutex_enter(softnet_lock); - KERNEL_LOCK(1, NULL); + scp = &tcp_syn_cache[sc->sc_bucketidx]; + mutex_enter(&scp->sch_mtx); callout_ack(&sc->sc_timer); @@ -465,22 +487,26 @@ syn_cache_timer(void *arg) pool_put(&syn_cache_pool, sc); out: - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + mutex_exit(&scp->sch_mtx); } /* * Remove syn cache created by the specified tcb entry, * because this does not make sense to keep them * (if there's no tcb entry, syn cache entry will never be used) + * + * XXX The tp->t_sc list is not independently locked. This is safe + * because syn_cache_cleanup is called with the socket lock held and + * syn_cache_timer also acquires the per-bucket lock before modifying + * the list. However, entries in different buckets could theoretically + * race on the t_sc list. In practice, the tcpcb is being torn down + * so no new entries will be added. */ void syn_cache_cleanup(struct tcpcb *tp) { struct syn_cache *sc, *nsc; - int s; - - s = splsoftnet(); + struct syn_cache_head *scp; for (sc = LIST_FIRST(&tp->t_sc); sc != NULL; sc = nsc) { nsc = LIST_NEXT(sc, sc_tpq); @@ -489,17 +515,22 @@ syn_cache_cleanup(struct tcpcb *tp) if (sc->sc_tp != tp) panic("invalid sc_tp in syn_cache_cleanup"); #endif + scp = &tcp_syn_cache[sc->sc_bucketidx]; + mutex_enter(&scp->sch_mtx); syn_cache_rm(sc); - syn_cache_put(sc); /* calls pool_put but see spl above */ + mutex_exit(&scp->sch_mtx); + syn_cache_put(sc); } /* just for safety */ LIST_INIT(&tp->t_sc); - - splx(s); } /* * Find an entry in the syn cache. + * + * Returns with the bucket mutex held if a match is found. + * The caller is responsible for releasing scp->sch_mtx. 
+ * If no match is found, returns NULL with no lock held. */ static struct syn_cache * syn_cache_lookup(const struct sockaddr *src, const struct sockaddr *dst, @@ -508,24 +539,21 @@ syn_cache_lookup(const struct sockaddr *src, const struct sockaddr *dst, struct syn_cache *sc; struct syn_cache_head *scp; u_int32_t hash; - int s; SYN_HASHALL(hash, src, dst); scp = &tcp_syn_cache[hash % tcp_syn_cache_size]; *headp = scp; - s = splsoftnet(); + mutex_enter(&scp->sch_mtx); for (sc = TAILQ_FIRST(&scp->sch_bucket); sc != NULL; sc = TAILQ_NEXT(sc, sc_bucketq)) { if (sc->sc_hash != hash) continue; if (!memcmp(&sc->sc_src, src, src->sa_len) && - !memcmp(&sc->sc_dst, dst, dst->sa_len)) { - splx(s); - return (sc); - } + !memcmp(&sc->sc_dst, dst, dst->sa_len)) + return (sc); /* returns with sch_mtx held */ } - splx(s); + mutex_exit(&scp->sch_mtx); return (NULL); } @@ -559,12 +587,10 @@ syn_cache_get(struct sockaddr *src, struct sockaddr *dst, struct syn_cache_head *scp; struct inpcb *inp = NULL; struct tcpcb *tp; - int s; struct socket *oso; - s = splsoftnet(); + /* syn_cache_lookup returns with scp->sch_mtx held on match */ if ((sc = syn_cache_lookup(src, dst, &scp)) == NULL) { - splx(s); return NULL; } @@ -577,13 +603,13 @@ syn_cache_get(struct sockaddr *src, struct sockaddr *dst, SEQ_GT(th->th_seq, sc->sc_irs + 1 + sc->sc_win)) { m_freem(m); (void)syn_cache_respond(sc); - splx(s); + mutex_exit(&scp->sch_mtx); return ((struct socket *)(-1)); } /* Remove this cache entry */ syn_cache_rm(sc); - splx(s); + mutex_exit(&scp->sch_mtx); /* * Ok, create the full blown connection, and set things up @@ -613,7 +639,9 @@ syn_cache_get(struct sockaddr *src, struct sockaddr *dst, in4p_laddr(inp) = ((struct sockaddr_in *)dst)->sin_addr; inp->inp_lport = ((struct sockaddr_in *)dst)->sin_port; inp->inp_options = ip_srcroute(m); + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); if (inp->inp_options == NULL) { inp->inp_options = sc->sc_ipopts; 
sc->sc_ipopts = NULL; @@ -633,7 +661,9 @@ syn_cache_get(struct sockaddr *src, struct sockaddr *dst, inp->inp_flags |= IN6P_IPV6_V6ONLY; else inp->inp_flags &= ~IN6P_IPV6_V6ONLY; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); } #endif break; @@ -642,7 +672,9 @@ syn_cache_get(struct sockaddr *src, struct sockaddr *dst, if (inp->inp_af == AF_INET6) { in6p_laddr(inp) = ((struct sockaddr_in6 *)dst)->sin6_addr; inp->inp_lport = ((struct sockaddr_in6 *)dst)->sin6_port; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); } break; #endif @@ -773,9 +805,7 @@ syn_cache_get(struct sockaddr *src, struct sockaddr *dst, tp->t_dupacks = 0; TCP_STATINC(TCP_STAT_SC_COMPLETED); - s = splsoftnet(); syn_cache_put(sc); - splx(s); return so; resetandabort: @@ -784,11 +814,14 @@ abort: if (so != NULL) { (void) soqremque(so, 1); (void) soabort(so); - mutex_enter(softnet_lock); + /* + * soabort drops the socket lock (shared with listener + * via sonewconn). Re-acquire the listener's lock + * for the caller (tcp_input). 
+ */ + solock(oso); } - s = splsoftnet(); syn_cache_put(sc); - splx(s); TCP_STATINC(TCP_STAT_SC_ABORTED); return ((struct socket *)(-1)); } @@ -804,21 +837,19 @@ syn_cache_reset(struct sockaddr *src, struct sockaddr *dst, struct tcphdr *th) { struct syn_cache *sc; struct syn_cache_head *scp; - int s = splsoftnet(); - if ((sc = syn_cache_lookup(src, dst, &scp)) == NULL) { - splx(s); + /* syn_cache_lookup returns with scp->sch_mtx held on match */ + if ((sc = syn_cache_lookup(src, dst, &scp)) == NULL) return; - } if (SEQ_LT(th->th_seq, sc->sc_irs) || SEQ_GT(th->th_seq, sc->sc_irs+1)) { - splx(s); + mutex_exit(&scp->sch_mtx); return; } syn_cache_rm(sc); + mutex_exit(&scp->sch_mtx); TCP_STATINC(TCP_STAT_SC_RESET); - syn_cache_put(sc); /* calls pool_put but see spl above */ - splx(s); + syn_cache_put(sc); } void @@ -827,16 +858,13 @@ syn_cache_unreach(const struct sockaddr *src, const struct sockaddr *dst, { struct syn_cache *sc; struct syn_cache_head *scp; - int s; - s = splsoftnet(); - if ((sc = syn_cache_lookup(src, dst, &scp)) == NULL) { - splx(s); + /* syn_cache_lookup returns with scp->sch_mtx held on match */ + if ((sc = syn_cache_lookup(src, dst, &scp)) == NULL) return; - } /* If the sequence number != sc_iss, then it's a bogus ICMP msg */ if (ntohl(th->th_seq) != sc->sc_iss) { - splx(s); + mutex_exit(&scp->sch_mtx); return; } @@ -850,14 +878,14 @@ syn_cache_unreach(const struct sockaddr *src, const struct sockaddr *dst, */ if ((sc->sc_flags & SCF_UNREACH) == 0 || sc->sc_rxtshift < 3) { sc->sc_flags |= SCF_UNREACH; - splx(s); + mutex_exit(&scp->sch_mtx); return; } syn_cache_rm(sc); + mutex_exit(&scp->sch_mtx); TCP_STATINC(TCP_STAT_SC_UNREACH); - syn_cache_put(sc); /* calls pool_put but see spl above */ - splx(s); + syn_cache_put(sc); } /* @@ -883,7 +911,6 @@ syn_cache_add(struct sockaddr *src, struct sockaddr *dst, struct tcphdr *th, struct syn_cache *sc; struct syn_cache_head *scp; struct mbuf *ipopts; - int s; tp = sototcpcb(so); @@ -924,6 +951,7 @@ 
syn_cache_add(struct sockaddr *src, struct sockaddr *dst, struct tcphdr *th, * If we do, resend the SYN,ACK. We do not count this * as a retransmission (XXX though maybe we should). */ + /* syn_cache_lookup returns with scp->sch_mtx held on match */ if ((sc = syn_cache_lookup(src, dst, &scp)) != NULL) { TCP_STATINC(TCP_STAT_SC_DUPESYN); if (ipopts) { @@ -943,12 +971,11 @@ syn_cache_add(struct sockaddr *src, struct sockaddr *dst, struct tcphdr *th, _NET_STATINC_REF(tcps, TCP_STAT_SNDTOTAL); TCP_STAT_PUTREF(); } + mutex_exit(&scp->sch_mtx); return 1; } - s = splsoftnet(); sc = pool_get(&syn_cache_pool, PR_NOWAIT); - splx(s); if (sc == NULL) { if (ipopts) (void)m_free(ipopts); @@ -1052,14 +1079,12 @@ syn_cache_add(struct sockaddr *src, struct sockaddr *dst, struct tcphdr *th, TCP_STAT_PUTREF(); syn_cache_insert(sc, tp); } else { - s = splsoftnet(); /* * syn_cache_put() will try to schedule the timer, so * we need to initialize it */ syn_cache_timer_arm(sc); syn_cache_put(sc); - splx(s); TCP_STATINC(TCP_STAT_SC_DROPPED); } return 1; diff --git a/sys/netinet/tcp_syncache.h b/sys/netinet/tcp_syncache.h index 5e5f824dcdb6..da7623a6c907 100644 --- a/sys/netinet/tcp_syncache.h +++ b/sys/netinet/tcp_syncache.h @@ -143,6 +143,7 @@ #ifdef _KERNEL #include #include +#include #include #include @@ -194,6 +195,7 @@ struct syn_cache { }; struct syn_cache_head { + kmutex_t sch_mtx; /* per-bucket lock */ TAILQ_HEAD(, syn_cache) sch_bucket; /* bucket entries */ u_short sch_length; /* # entries in bucket */ }; diff --git a/sys/netinet/tcp_timer.c b/sys/netinet/tcp_timer.c index 1b53830acff9..ea36eca36ad7 100644 --- a/sys/netinet/tcp_timer.c +++ b/sys/netinet/tcp_timer.c @@ -224,6 +224,7 @@ void tcp_delack(void *arg) { struct tcpcb *tp = arg; + kmutex_t *lock = tp->t_lock; /* * If tcp_output() wasn't able to transmit the ACK @@ -231,21 +232,19 @@ tcp_delack(void *arg) * ACK callout. 
*/ - mutex_enter(softnet_lock); + mutex_obj_hold(lock); + mutex_enter(lock); if ((tp->t_flags & (TF_DEAD | TF_DELACK)) != TF_DELACK) { - mutex_exit(softnet_lock); - return; - } - if (!callout_expired(&tp->t_delack_ch)) { - mutex_exit(softnet_lock); + mutex_exit(lock); + mutex_obj_free(lock); return; } - tp->t_flags |= TF_ACKNOW; - KERNEL_LOCK(1, NULL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); (void) tcp_output(tp); - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(lock); + mutex_obj_free(lock); } /* @@ -301,48 +300,45 @@ void tcp_timer_rexmt(void *arg) { struct tcpcb *tp = arg; + kmutex_t *lock = tp->t_lock; uint32_t rto; #ifdef TCP_DEBUG struct socket *so = NULL; short ostate; #endif - mutex_enter(softnet_lock); + mutex_obj_hold(lock); + mutex_enter(lock); if ((tp->t_flags & TF_DEAD) != 0) { - mutex_exit(softnet_lock); - return; - } - if (!callout_expired(&tp->t_timer[TCPT_REXMT])) { - mutex_exit(softnet_lock); + mutex_exit(lock); + mutex_obj_free(lock); return; } - - KERNEL_LOCK(1, NULL); + tp->t_timer_armed &= ~(1 << TCPT_REXMT); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); if ((tp->t_flags & TF_PMTUD_PEND) && tp->t_inpcb && SEQ_GEQ(tp->t_pmtud_th_seq, tp->snd_una) && SEQ_LT(tp->t_pmtud_th_seq, (int)(tp->snd_una + tp->t_ourmss))) { - extern struct sockaddr_in icmpsrc; struct icmp icmp; + struct in_addr dst; tp->t_flags &= ~TF_PMTUD_PEND; - /* XXX create fake icmp message with relevant entries */ + /* Build fake ICMP message from stored PMTUD parameters */ + memset(&icmp, 0, sizeof(icmp)); icmp.icmp_nextmtu = tp->t_pmtud_nextmtu; icmp.icmp_ip.ip_len = tp->t_pmtud_ip_len; icmp.icmp_ip.ip_hl = tp->t_pmtud_ip_hl; - icmpsrc.sin_addr = in4p_faddr(tp->t_inpcb); - icmp_mtudisc(&icmp, icmpsrc.sin_addr); + dst = in4p_faddr(tp->t_inpcb); /* - * Notify all connections to the same peer about - * new mss and trigger retransmit. + * Update the route MTU and notify this connection. 
+ * icmp_mtudisc uses a local sockaddr (no global state), + * safe to call under per-socket lock. */ - inpcb_notifyall(&tcbtable, icmpsrc.sin_addr, EMSGSIZE, - tcp_mtudisc); - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); - return; - } + icmp_mtudisc(&icmp, dst); + tcp_mtudisc(tp->t_inpcb, 0); + } #ifdef TCP_DEBUG so = tp->t_inpcb->inp_socket; ostate = tp->t_state; @@ -440,31 +436,31 @@ tcp_timer_rexmt(void *arg) tcp_trace(TA_USER, ostate, tp, NULL, PRU_SLOWTIMO | (TCPT_REXMT << 8)); #endif - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(lock); + mutex_obj_free(lock); } void tcp_timer_persist(void *arg) { struct tcpcb *tp = arg; + kmutex_t *lock = tp->t_lock; uint32_t rto; #ifdef TCP_DEBUG struct socket *so = NULL; short ostate; #endif - mutex_enter(softnet_lock); + mutex_obj_hold(lock); + mutex_enter(lock); if ((tp->t_flags & TF_DEAD) != 0) { - mutex_exit(softnet_lock); + mutex_exit(lock); + mutex_obj_free(lock); return; } - if (!callout_expired(&tp->t_timer[TCPT_PERSIST])) { - mutex_exit(softnet_lock); - return; - } - - KERNEL_LOCK(1, NULL); + tp->t_timer_armed &= ~(1 << TCPT_PERSIST); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); #ifdef TCP_DEBUG so = tp->t_inpcb->inp_socket; ostate = tp->t_state; @@ -492,6 +488,13 @@ tcp_timer_persist(void *arg) tp = tcp_drop(tp, ETIMEDOUT); goto out; } + /* + * If rexmt is now armed, the connection recovered while + * we waited for solock (tcp_input opened the window and + * armed rexmt). Skip persist, rexmt handles it. 
+ */ + if (TCP_TIMER_ISARMED(tp, TCPT_REXMT)) + goto out; TCP_STATINC(TCP_STAT_PERSISTTIMEO); tcp_setpersist(tp); tp->t_force = 1; @@ -504,30 +507,30 @@ tcp_timer_persist(void *arg) tcp_trace(TA_USER, ostate, tp, NULL, PRU_SLOWTIMO | (TCPT_PERSIST << 8)); #endif - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(lock); + mutex_obj_free(lock); } void tcp_timer_keep(void *arg) { struct tcpcb *tp = arg; + kmutex_t *lock = tp->t_lock; struct socket *so = NULL; /* Quell compiler warning */ #ifdef TCP_DEBUG short ostate; #endif - mutex_enter(softnet_lock); + mutex_obj_hold(lock); + mutex_enter(lock); if ((tp->t_flags & TF_DEAD) != 0) { - mutex_exit(softnet_lock); - return; - } - if (!callout_expired(&tp->t_timer[TCPT_KEEP])) { - mutex_exit(softnet_lock); + mutex_exit(lock); + mutex_obj_free(lock); return; } - - KERNEL_LOCK(1, NULL); + tp->t_timer_armed &= ~(1 << TCPT_KEEP); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); #ifdef TCP_DEBUG ostate = tp->t_state; @@ -576,41 +579,42 @@ tcp_timer_keep(void *arg) tcp_trace(TA_USER, ostate, tp, NULL, PRU_SLOWTIMO | (TCPT_KEEP << 8)); #endif - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(lock); + mutex_obj_free(lock); return; dropit: TCP_STATINC(TCP_STAT_KEEPDROPS); (void) tcp_drop(tp, ETIMEDOUT); - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(lock); + mutex_obj_free(lock); } void tcp_timer_2msl(void *arg) { struct tcpcb *tp = arg; + kmutex_t *lock = tp->t_lock; #ifdef TCP_DEBUG struct socket *so = NULL; short ostate; #endif - mutex_enter(softnet_lock); + mutex_obj_hold(lock); + mutex_enter(lock); if ((tp->t_flags & TF_DEAD) != 0) { - mutex_exit(softnet_lock); + mutex_exit(lock); + mutex_obj_free(lock); return; } - if (!callout_expired(&tp->t_timer[TCPT_2MSL])) { - mutex_exit(softnet_lock); - return; - } - /* * 2 MSL timeout went off, clear the SACK scoreboard, reset * the 
FACK estimate. */ - KERNEL_LOCK(1, NULL); + tp->t_timer_armed &= ~(1 << TCPT_2MSL); + KERNEL_LOCK_UNLESS_NET_MPSAFE(); tcp_free_sackholes(tp); tp->snd_fack = tp->snd_una; @@ -637,6 +641,7 @@ tcp_timer_2msl(void *arg) tcp_trace(TA_USER, ostate, tp, NULL, PRU_SLOWTIMO | (TCPT_2MSL << 8)); #endif - KERNEL_UNLOCK_ONE(NULL); - mutex_exit(softnet_lock); + KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); + mutex_exit(lock); + mutex_obj_free(lock); } diff --git a/sys/netinet/tcp_timer.h b/sys/netinet/tcp_timer.h index ebc0bb0dd765..13bded2e5382 100644 --- a/sys/netinet/tcp_timer.h +++ b/sys/netinet/tcp_timer.h @@ -155,15 +155,25 @@ const char *tcptimers[] = * nticks is given in units of slow timeouts, * typically 500 ms (with PR_SLOWHZ at 2). */ -#define TCP_TIMER_ARM(tp, timer, nticks) \ +/* + * TCP-layer armed state, decoupled from callout PENDING/FIRED. + * Set on arm, cleared on disarm and at the top of each callback. + * Avoids both the stale-FIRED problem (callout_active) and the + * dispatched-but-not-yet-run problem (callout_pending). + */ +#define TCP_TIMER_ARM(tp, timer, nticks) do { \ + (tp)->t_timer_armed |= (1 << (timer)); \ callout_schedule(&(tp)->t_timer[(timer)], \ - (nticks) * (hz / PR_SLOWHZ)) + (nticks) * (hz / PR_SLOWHZ)); \ +} while (0) -#define TCP_TIMER_DISARM(tp, timer) \ - callout_stop(&(tp)->t_timer[(timer)]) +#define TCP_TIMER_DISARM(tp, timer) do { \ + (tp)->t_timer_armed &= ~(1 << (timer)); \ + callout_stop(&(tp)->t_timer[(timer)]); \ +} while (0) #define TCP_TIMER_ISARMED(tp, timer) \ - callout_active(&(tp)->t_timer[(timer)]) + ((tp)->t_timer_armed & (1 << (timer))) #define TCP_TIMER_MAXTICKS \ (INT_MAX / (hz / PR_SLOWHZ)) diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c index 2ee356376477..050f20687882 100644 --- a/sys/netinet/tcp_usrreq.c +++ b/sys/netinet/tcp_usrreq.c @@ -448,9 +448,16 @@ tcp_attach(struct socket *so, int proto) struct inpcb *inp; int s, error, family; - /* Assign the lock (must happen even if we will error out). 
*/ + /* + * Assign a per-socket lock. Each TCP socket gets its own mutex + * so that connections on different CPUs can be processed in parallel. + * sonewconn() inherits the listener's lock for accepted connections. + */ s = splsoftnet(); - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } KASSERT(solocked(so)); KASSERT(sotoinpcb(so) == NULL); @@ -489,6 +496,7 @@ tcp_attach(struct socket *so, int proto) goto out; } tp->t_state = TCPS_CLOSED; + tp->t_lock = so->so_lock; if ((so->so_options & SO_LINGER) && so->so_linger == 0) { so->so_linger = TCP_LINGERTIME; } @@ -547,6 +555,58 @@ tcp_accept(struct socket *so, struct sockaddr *nam) tcp_debug_trace(so, tp, ostate, PRU_ACCEPT); splx(s); + /* + * Migrate the accepted socket from the listener's shared lock + * to its own per-connection lock. This allows connections + * accepted on the same port to be processed in parallel. + * Pattern follows Unix domain sockets (uipc_usrreq.c). + */ + { + kmutex_t *old_lock = so->so_lock; + kmutex_t *new_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + int i; + + /* + * Halt all TCP timers before migrating the lock. + * Timer callbacks read tp->t_lock unsynchronized, + * so no callback can be in-flight during migration. + * callout_halt releases/reacquires old_lock as needed. 
+ */ + for (i = 0; i < TCPT_NTIMERS; i++) + callout_halt(&tp->t_timer[i], old_lock); + callout_halt(&tp->t_delack_ch, old_lock); + + /* Migrate to per-connection lock */ + mutex_enter(new_lock); + membar_release(); + solockreset(so, new_lock); + tp->t_lock = new_lock; + + /* Re-arm timers under the new lock */ + if (TCPS_HAVEESTABLISHED(tp->t_state)) + TCP_TIMER_ARM(tp, TCPT_KEEP, tp->t_keepidle); + if (tp->snd_nxt != tp->snd_una) { + TCP_TIMER_ARM(tp, TCPT_REXMT, tp->t_rxtcur); + } else if (so->so_snd.sb_cc > 0) { + /* + * Data in send buffer but snd_nxt == snd_una: + * connection is in persist state (receiver window + * closed). Must re-arm persist timer or the + * connection will be stuck permanently. + */ + tcp_setpersist(tp); + } + if (tp->t_flags & TF_DELACK) + TCP_RESTART_DELACK(tp); + + mutex_exit(new_lock); + mutex_obj_free(old_lock); + /* + * Listener's lock (old_lock) remains held and the caller + * (do_sys_accept) releases it via sounlock(listener). + */ + } + return 0; } @@ -1421,26 +1481,33 @@ inet4_ident_core(struct in_addr raddr, u_int rport, inp = inpcb_lookup(&tcbtable, raddr, rport, laddr, lport, 0); - if (inp == NULL || (sockp = inp->inp_socket) == NULL) + if (inp == NULL) return ESRCH; + sockp = inp->inp_socket; if (dodrop) { struct tcpcb *tp; int error; - if (inp == NULL || (tp = intotcpcb(inp)) == NULL || - (inp->inp_socket->so_options & SO_ACCEPTCONN) != 0) + if ((tp = intotcpcb(inp)) == NULL || + (inp->inp_socket->so_options & SO_ACCEPTCONN) != 0) { + sounlock(sockp); return ESRCH; + } error = kauth_authorize_network(l->l_cred, KAUTH_NETWORK_SOCKET, KAUTH_REQ_NETWORK_SOCKET_DROP, inp->inp_socket, tp, NULL); - if (error) + if (error) { + sounlock(sockp); return error; + } (void)tcp_drop(tp, ECONNABORTED); + sounlock(sockp); return 0; } + sounlock(sockp); return copyout_uid(sockp, oldp, oldlenp); } @@ -1456,26 +1523,33 @@ inet6_ident_core(struct in6_addr *raddr, u_int rport, inp = in6pcb_lookup(&tcbtable, raddr, rport, laddr, lport, 0, 0); - 
if (inp == NULL || (sockp = inp->inp_socket) == NULL) + if (inp == NULL) return ESRCH; + sockp = inp->inp_socket; if (dodrop) { struct tcpcb *tp; int error; - if (inp == NULL || (tp = intotcpcb(inp)) == NULL || - (inp->inp_socket->so_options & SO_ACCEPTCONN) != 0) + if ((tp = intotcpcb(inp)) == NULL || + (inp->inp_socket->so_options & SO_ACCEPTCONN) != 0) { + sounlock(sockp); return ESRCH; + } error = kauth_authorize_network(l->l_cred, KAUTH_NETWORK_SOCKET, KAUTH_REQ_NETWORK_SOCKET_DROP, inp->inp_socket, tp, NULL); - if (error) + if (error) { + sounlock(sockp); return error; + } (void)tcp_drop(tp, ECONNABORTED); + sounlock(sockp); return 0; } + sounlock(sockp); return copyout_uid(sockp, oldp, oldlenp); } #endif @@ -1653,7 +1727,9 @@ sysctl_inpcblist(SYSCTLFN_ARGS) mutex_enter(softnet_lock); - TAILQ_FOREACH(inp, &pcbtbl->inpt_queue, inp_queue) { + INP_HASH_LOCK(pcbtbl); + PSLIST_WRITER_FOREACH(inp, &pcbtbl->inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != pf) continue; @@ -1762,6 +1838,7 @@ sysctl_inpcblist(SYSCTLFN_ARGS) if (len >= elem_size && elem_count > 0) { error = copyout(&pcb, dp, out_size); if (error) { + INP_HASH_UNLOCK(pcbtbl); mutex_exit(softnet_lock); return error; } @@ -1772,6 +1849,7 @@ sysctl_inpcblist(SYSCTLFN_ARGS) if (elem_count > 0 && elem_count != INT_MAX) elem_count--; } + INP_HASH_UNLOCK(pcbtbl); *oldlenp = needed; if (oldp == NULL) @@ -2330,6 +2408,15 @@ tcp_usrreq_init(void) #endif } +#ifdef NET_MPSAFE +/* + * TCP is fully MP-safe with per-socket locking. All usrreqs run + * under solock (held by the socket layer), so KERNEL_LOCK wrappers + * must NOT be used, acquiring KERNEL_LOCK inside solock inverts + * the lock order vs ctlinput (which takes KERNEL_LOCK then solock). 
+ */ +/* No PR_WRAP_USRREQS: all functions are used directly. */ +#else PR_WRAP_USRREQS(tcp) #define tcp_attach tcp_attach_wrapper #define tcp_detach tcp_detach_wrapper @@ -2350,6 +2437,7 @@ PR_WRAP_USRREQS(tcp) #define tcp_send tcp_send_wrapper #define tcp_sendoob tcp_sendoob_wrapper #define tcp_purgeif tcp_purgeif_wrapper +#endif const struct pr_usrreqs tcp_usrreqs = { .pr_attach = tcp_attach, diff --git a/sys/netinet/tcp_var.h b/sys/netinet/tcp_var.h index 7e2f8576eafc..da654f2afbad 100644 --- a/sys/netinet/tcp_var.h +++ b/sys/netinet/tcp_var.h @@ -215,6 +215,7 @@ struct tcpcb { struct ipqehead segq; /* sequencing queue */ int t_segqlen; /* length of the above */ callout_t t_timer[TCPT_NTIMERS];/* tcp timers */ + uint8_t t_timer_armed; /* bitmask: TCP-layer armed state */ short t_state; /* state of this connection */ short t_rxtshift; /* log(2) of rexmt exp. backoff */ uint32_t t_rxtcur; /* current retransmit value */ @@ -253,6 +254,7 @@ struct tcpcb { struct mbuf *t_template; /* skeletal packet for transmit */ struct inpcb *t_inpcb; /* back pointer to internet pcb */ + kmutex_t *t_lock; /* cached so->so_lock for timers */ callout_t t_delack_ch; /* delayed ACK callout */ /* * The following fields are used as in the protocol specification. diff --git a/sys/netinet/tcp_vtw.c b/sys/netinet/tcp_vtw.c index 94f574a8342f..5c8073939016 100644 --- a/sys/netinet/tcp_vtw.c +++ b/sys/netinet/tcp_vtw.c @@ -133,33 +133,16 @@ vtw_ctl_t vtw_tcpv4[VTW_NCLASS]; vtw_ctl_t vtw_tcpv6[VTW_NCLASS]; vtw_stats_t vtw_stats; -/* We provide state for the lookup_ports iterator. - * As currently we are netlock-protected, there is one. - * If we were finer-grain, we would have one per CPU. - * I do not want to be in the business of alloc/free. - * The best alternate would be allocate on the caller's - * stack, but that would require them to know the struct, - * or at least the size. - * See how she goes. 
+/* + * VTW lock protects all VTW global data structures (fat pointer + * hash tables, clock-hand allocator, VTW entries). Replaces the + * former dependency on softnet_lock. */ -struct tcp_ports_iterator { - union { - struct in_addr v4; - struct in6_addr v6; - } addr; - u_int port; - - uint32_t wild : 1; - - vtw_ctl_t *ctl; - fatp_t *fp; +static kmutex_t vtw_lock; - uint16_t slot_idx; - uint16_t ctl_idx; -}; - -static struct tcp_ports_iterator tcp_ports_iterator_v4; -static struct tcp_ports_iterator tcp_ports_iterator_v6; +/* struct tcp_ports_iterator is defined in tcp_vtw.h so that + * callers can stack-allocate the iterator state. + */ static int vtw_age(vtw_ctl_t *, struct timeval *); @@ -675,7 +658,7 @@ vtw_unhash(vtw_ctl_t *ctl, vtw_t *vtw) void vtw_del(vtw_ctl_t *ctl, vtw_t *vtw) { - KASSERT(mutex_owned(softnet_lock)); + KASSERT(mutex_owned(&vtw_lock)); if (vtw->hashed) { ++vtw_stats.del; @@ -708,7 +691,7 @@ vtw_inshash_v4(vtw_ctl_t *ctl, vtw_t *vtw) uint32_t tag; vtw_v4_t *v4 = (void*)vtw; - KASSERT(mutex_owned(softnet_lock)); + KASSERT(mutex_owned(&vtw_lock)); KASSERT(!vtw->hashed); KASSERT(ctl->clidx == vtw->msl_class); @@ -747,7 +730,7 @@ vtw_inshash_v6(vtw_ctl_t *ctl, vtw_t *vtw) uint32_t tag; vtw_v6_t *v6 = (void*)vtw; - KASSERT(mutex_owned(softnet_lock)); + KASSERT(mutex_owned(&vtw_lock)); KASSERT(!vtw->hashed); KASSERT(ctl->clidx == vtw->msl_class); @@ -1380,7 +1363,7 @@ vtw_alloc(vtw_ctl_t *ctl) int avail = ctl ? (ctl->nalloc + ctl->nfree) : 0; int msl; - KASSERT(mutex_owned(softnet_lock)); + KASSERT(mutex_owned(&vtw_lock)); /* If no resources, we will not get far. 
*/ @@ -1552,7 +1535,7 @@ vtw_tick(void *arg) db_trace(KTR_VTW, (arg, "vtk: tick - now %8.8x:%8.8x" , now.tv_sec, now.tv_usec)); - mutex_enter(softnet_lock); + mutex_enter(&vtw_lock); for (i = 0; i < VTW_NCLASS; ++i) { cnt += vtw_age(&vtw_tcpv4[i], &now); @@ -1567,16 +1550,15 @@ vtw_tick(void *arg) tcp_vtw_was_enabled = 0; tcbtable.vestige = 0; } - mutex_exit(softnet_lock); + mutex_exit(&vtw_lock); } /* inpcb_lookup_locals assist for handling vestigial entries. */ static void * -tcp_init_ports_v4(struct in_addr addr, u_int port, int wild) +tcp_init_ports_v4(struct in_addr addr, u_int port, int wild, + struct tcp_ports_iterator *it) { - struct tcp_ports_iterator *it = &tcp_ports_iterator_v4; - bzero(it, sizeof (*it)); /* Note: the reference to vtw_tcpv4[0] is fine. @@ -1655,27 +1637,29 @@ tcp_lookup_v4(struct in_addr faddr, uint16_t fport, { vtw_t *vtw; vtw_ctl_t *ctl; - + int rc; db_trace(KTR_VTW , (res, "vtw: lookup %A:%P %A:%P" , faddr, fport , laddr, lport)); + mutex_enter(&vtw_lock); vtw = vtw_lookup_hash_v4((ctl = &vtw_tcpv4[0]) , faddr.s_addr, fport , laddr.s_addr, lport, 0); - return vtw_export_v4(ctl, vtw, res); + rc = vtw_export_v4(ctl, vtw, res); + mutex_exit(&vtw_lock); + return rc; } /* inpcb_lookup_locals assist for handling vestigial entries. */ static void * -tcp_init_ports_v6(const struct in6_addr *addr, u_int port, int wild) +tcp_init_ports_v6(const struct in6_addr *addr, u_int port, int wild, + struct tcp_ports_iterator *it) { - struct tcp_ports_iterator *it = &tcp_ports_iterator_v6; - bzero(it, sizeof (*it)); /* Note: the reference to vtw_tcpv6[0] is fine. 
@@ -1755,17 +1739,21 @@ tcp_lookup_v6(const struct in6_addr *faddr, uint16_t fport, { vtw_ctl_t *ctl; vtw_t *vtw; + int rc; db_trace(KTR_VTW , (res, "vtw: lookup %6A:%P %6A:%P" , db_store(faddr, sizeof (*faddr)), fport , db_store(laddr, sizeof (*laddr)), lport)); + mutex_enter(&vtw_lock); vtw = vtw_lookup_hash_v6((ctl = &vtw_tcpv6[0]) , faddr, fport , laddr, lport, 0); - return vtw_export_v6(ctl, vtw, res); + rc = vtw_export_v6(ctl, vtw, res); + mutex_exit(&vtw_lock); + return rc; } static vestigial_hooks_t tcp_hooks = { @@ -1882,11 +1870,13 @@ vtw_add(int af, struct tcpcb *tp) vtw_ctl_t *ctl; vtw_t *vtw; - KASSERT(mutex_owned(softnet_lock)); + mutex_enter(&vtw_lock); ctl = vtw_control(af, tp->t_msl); - if (!ctl) + if (!ctl) { + mutex_exit(&vtw_lock); return 0; + } #ifdef VTW_DEBUG enable = (af == AF_INET) ? tcp4_vtw_enable : tcp6_vtw_enable; @@ -1938,14 +1928,14 @@ vtw_add(int af, struct tcpcb *tp) /* Immediate port iterator functionality check: not wild */ if (enable & 8) { - struct tcp_ports_iterator *it; + struct tcp_ports_iterator it; struct vestigial_inpcb res; int cnt = 0; - it = tcp_init_ports_v4(in4p_laddr(inp) - , inp->inp_lport, 0); + tcp_init_ports_v4(in4p_laddr(inp) + , inp->inp_lport, 0, &it); - while (tcp_next_port_v4(it, &res)) { + while (tcp_next_port_v4(&it, &res)) { ++cnt; } KASSERT(cnt); @@ -1953,16 +1943,16 @@ vtw_add(int af, struct tcpcb *tp) /* Immediate port iterator functionality check: wild */ if (enable & 16) { - struct tcp_ports_iterator *it; + struct tcp_ports_iterator it; struct vestigial_inpcb res; struct in_addr any; int cnt = 0; any.s_addr = htonl(INADDR_ANY); - it = tcp_init_ports_v4(any, inp->inp_lport, 1); + tcp_init_ports_v4(any, inp->inp_lport, 1, &it); - while (tcp_next_port_v4(it, &res)) { + while (tcp_next_port_v4(&it, &res)) { ++cnt; } KASSERT(cnt); @@ -2008,14 +1998,14 @@ vtw_add(int af, struct tcpcb *tp) /* Immediate port iterator functionality check: not wild */ if (enable & 8) { - struct tcp_ports_iterator *it; + struct 
tcp_ports_iterator it; struct vestigial_inpcb res; int cnt = 0; - it = tcp_init_ports_v6(&in6p_laddr(inp) - , inp->inp_lport, 0); + tcp_init_ports_v6(&in6p_laddr(inp) + , inp->inp_lport, 0, &it); - while (tcp_next_port_v6(it, &res)) { + while (tcp_next_port_v6(&it, &res)) { ++cnt; } KASSERT(cnt); @@ -2023,15 +2013,15 @@ vtw_add(int af, struct tcpcb *tp) /* Immediate port iterator functionality check: wild */ if (enable & 16) { - struct tcp_ports_iterator *it; + struct tcp_ports_iterator it; struct vestigial_inpcb res; static struct in6_addr any = IN6ADDR_ANY_INIT; int cnt = 0; - it = tcp_init_ports_v6(&any - , inp->inp_lport, 1); + tcp_init_ports_v6(&any + , inp->inp_lport, 1, &it); - while (tcp_next_port_v6(it, &res)) { + while (tcp_next_port_v6(&it, &res)) { ++cnt; } KASSERT(cnt); @@ -2042,12 +2032,14 @@ vtw_add(int af, struct tcpcb *tp) } tcp_canceltimers(tp); + mutex_exit(&vtw_lock); tp = tcp_close(tp); KASSERT(!tp); return 1; } + mutex_exit(&vtw_lock); return 0; } @@ -2061,7 +2053,7 @@ vtw_restart_v4(vestigial_inpcb_t *vp) vtw_t *cp = ©.common; vtw_ctl_t *ctl; - KASSERT(mutex_owned(softnet_lock)); + KASSERT(mutex_owned(&vtw_lock)); db_trace(KTR_VTW , (vp->vtw, "vtw: restart %A:%P %A:%P" @@ -2109,7 +2101,7 @@ vtw_restart_v6(vestigial_inpcb_t *vp) vtw_t *cp = ©.common; vtw_ctl_t *ctl; - KASSERT(mutex_owned(softnet_lock)); + KASSERT(mutex_owned(&vtw_lock)); db_trace(KTR_VTW , (vp->vtw, "vtw: restart %6A:%P %6A:%P" @@ -2157,10 +2149,12 @@ vtw_restart(vestigial_inpcb_t *vp) if (!vp || !vp->valid) return; + mutex_enter(&vtw_lock); if (vp->v4) vtw_restart_v4(vp); else vtw_restart_v6(vp); + mutex_exit(&vtw_lock); } int @@ -2198,7 +2192,8 @@ vtw_earlyinit(void) { int i, rc; - callout_init(&vtw_cs, 0); + mutex_init(&vtw_lock, MUTEX_DEFAULT, IPL_SOFTNET); + callout_init(&vtw_cs, CALLOUT_MPSAFE); callout_setfunc(&vtw_cs, vtw_tick, 0); for (i = 0; i < VTW_NCLASS; ++i) { @@ -2293,7 +2288,7 @@ vtw_debug_process(vtw_sysargs_t *ap) struct vestigial_inpcb vestige; int rc = 0; - 
mutex_enter(softnet_lock); + mutex_enter(&vtw_lock); switch (ap->op) { case 0: // insert @@ -2308,28 +2303,36 @@ vtw_debug_process(vtw_sysargs_t *ap) case 2: // restart switch (ap->la.sin_family) { case AF_INET: - if (tcp_lookup_v4(ap->fa.sin_addr.v4, ap->fa.sin_port, - ap->la.sin_addr.v4, ap->la.sin_port, - &vestige)) { - if (ap->op == 2) { - vtw_restart(&vestige); - } + { + vtw_t *vtw; + vtw = vtw_lookup_hash_v4(&vtw_tcpv4[0], + ap->fa.sin_addr.v4.s_addr, ap->fa.sin_port, + ap->la.sin_addr.v4.s_addr, ap->la.sin_port, 0); + if (vtw && vtw_export_v4(&vtw_tcpv4[0], vtw, + &vestige)) { + if (ap->op == 2) + vtw_restart_v4(&vestige); rc = 0; } else rc = ESRCH; break; + } case AF_INET6: - if (tcp_lookup_v6(&ap->fa.sin_addr.v6, ap->fa.sin_port, - &ap->la.sin_addr.v6, ap->la.sin_port, - &vestige)) { - if (ap->op == 2) { - vtw_restart(&vestige); - } + { + vtw_t *vtw; + vtw = vtw_lookup_hash_v6(&vtw_tcpv6[0], + &ap->fa.sin_addr.v6, ap->fa.sin_port, + &ap->la.sin_addr.v6, ap->la.sin_port, 0); + if (vtw && vtw_export_v6(&vtw_tcpv6[0], vtw, + &vestige)) { + if (ap->op == 2) + vtw_restart_v6(&vestige); rc = 0; } else rc = ESRCH; break; + } default: rc = EINVAL; } @@ -2339,7 +2342,7 @@ vtw_debug_process(vtw_sysargs_t *ap) rc = EINVAL; } - mutex_exit(softnet_lock); + mutex_exit(&vtw_lock); return rc; } diff --git a/sys/netinet/tcp_vtw.h b/sys/netinet/tcp_vtw.h index 237db807c043..f364e4811e14 100644 --- a/sys/netinet/tcp_vtw.h +++ b/sys/netinet/tcp_vtw.h @@ -390,6 +390,22 @@ typedef struct vestigial_inpcb { struct vtw_ctl *ctl; } vestigial_inpcb_t; +struct tcp_ports_iterator { + union { + struct in_addr v4; + struct in6_addr v6; + } addr; + u_int port; + + uint32_t wild : 1; + + vtw_ctl_t *ctl; + fatp_t *fp; + + uint16_t slot_idx; + uint16_t ctl_idx; +}; + #ifdef _KERNEL void vtw_restart(vestigial_inpcb_t*); int vtw_earlyinit(void); diff --git a/sys/netinet/udp_usrreq.c b/sys/netinet/udp_usrreq.c index 59e9559be7e0..3b5c9df7c085 100644 --- a/sys/netinet/udp_usrreq.c +++ 
b/sys/netinet/udp_usrreq.c @@ -524,8 +524,14 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, */ /* * Locate pcb(s) for datagram. + * + * Hold INP_HASH_LOCK for structural stability during + * PSLIST iteration. Per-socket locks are acquired + * individually for each sendup (blocking, so WRITER). */ - TAILQ_FOREACH(inp, &udbtable.inpt_queue, inp_queue) { + INP_HASH_LOCK(&udbtable); + PSLIST_WRITER_FOREACH(inp, &udbtable.inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET) continue; @@ -541,8 +547,10 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, continue; } + solock(inp->inp_socket); udp4_sendup(m, off, (struct sockaddr *)src, inp->inp_socket); + sounlock(inp->inp_socket); rcvcnt++; /* @@ -557,6 +565,7 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, (SO_REUSEPORT|SO_REUSEADDR)) == 0) break; } + INP_HASH_UNLOCK(&udbtable); } else { /* * Locate pcb for datagram. @@ -570,6 +579,11 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, return rcvcnt; } + /* + * Lookup returned with solock held. Hold it throughout + * processing so that inp/socket fields remain stable. 
+ */ + #ifdef IPSEC /* Handle ESP over UDP */ if (inp->inp_flags & INP_ESPINUDP) { @@ -577,11 +591,13 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, case -1: /* Error, m was freed */ KASSERT(*mp == NULL); rcvcnt = -1; + sounlock(inp->inp_socket); goto bad; case 1: /* ESP over UDP */ KASSERT(*mp == NULL); rcvcnt++; + sounlock(inp->inp_socket); goto bad; case 0: /* plain UDP */ @@ -603,11 +619,13 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, case -1: /* Error, m was freed */ KASSERT(*mp == NULL); rcvcnt = -1; + sounlock(inp->inp_socket); goto bad; case 1: /* Foo over UDP */ KASSERT(*mp == NULL); rcvcnt++; + sounlock(inp->inp_socket); goto bad; case 0: /* plain UDP */ @@ -624,11 +642,14 @@ udp4_realinput(struct sockaddr_in *src, struct sockaddr_in *dst, /* * Check the minimum TTL for socket. */ - if (mtod(m, struct ip *)->ip_ttl < in4p_ip_minttl(inp)) + if (mtod(m, struct ip *)->ip_ttl < in4p_ip_minttl(inp)) { + sounlock(inp->inp_socket); goto bad; + } udp4_sendup(m, off, (struct sockaddr *)src, inp->inp_socket); rcvcnt++; + sounlock(inp->inp_socket); } bad: @@ -644,9 +665,13 @@ bad: static void udp_notify(struct inpcb *inp, int errno) { - inp->inp_socket->so_error = errno; - sorwakeup(inp->inp_socket); - sowwakeup(inp->inp_socket); + struct socket *so = inp->inp_socket; + + solock(so); + so->so_error = errno; + sorwakeup(so); + sowwakeup(so); + sounlock(so); } void * @@ -674,9 +699,22 @@ udp_ctlinput(int cmd, const struct sockaddr *sa, void *v) } if (ip) { + struct inpcb *inp; + uh = (struct udphdr *)((char *)ip + (ip->ip_hl << 2)); - inpcb_notify(&udbtable, satocsin(sa)->sin_addr, uh->uh_dport, - ip->ip_src, uh->uh_sport, errno, notify); + inp = inpcb_lookup(&udbtable, satocsin(sa)->sin_addr, + uh->uh_dport, ip->ip_src, uh->uh_sport, NULL); + if (inp != NULL) { + kmutex_t *lock = inp->inp_socket->so_lock; + mutex_obj_hold(lock); + bool acquired __diagused = inpcb_ref_acquire(inp); + KASSERT(acquired); + mutex_exit(lock); + 
(*notify)(inp, errno); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + mutex_obj_free(lock); + } /* XXX mapped address case */ } else { inpcb_notifyall(&udbtable, satocsin(sa)->sin_addr, errno, @@ -860,8 +898,14 @@ udp_attach(struct socket *so, int proto) KASSERT(sotoinpcb(so) == NULL); - /* Assign the lock (must happen even if we will error out). */ - sosetlock(so); + /* + * Assign a per-socket lock. Each UDP socket gets its own mutex + * so that datagrams on different CPUs can be processed in parallel. + */ + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } #ifdef MBUFTRACE so->so_mowner = &udp_mowner; @@ -975,7 +1019,9 @@ udp_disconnect(struct socket *so) so->so_state &= ~SS_ISCONNECTED; /* XXX */ inpcb_disconnect(inp); in4p_laddr(inp) = zeroin_addr; /* XXX */ + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); /* XXX */ + INP_HASH_UNLOCK(inp->inp_table); splx(s); return 0; @@ -1105,7 +1151,9 @@ udp_send(struct socket *so, struct mbuf *m, struct sockaddr *nam, if (nam) { inpcb_disconnect(inp); in4p_laddr(inp) = laddr; /* XXX */ + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); /* XXX */ + INP_HASH_UNLOCK(inp->inp_table); } die: m_freem(m); @@ -1344,6 +1392,15 @@ udp4_espinudp(struct mbuf **mp, int off) } #endif +#ifdef NET_MPSAFE +/* + * UDP is fully MP-safe with per-socket locking. All usrreqs run + * under solock (held by the socket layer), so KERNEL_LOCK wrappers + * must NOT be used, acquiring KERNEL_LOCK inside solock inverts + * the lock order vs ctlinput (which takes KERNEL_LOCK then solock). 
+ */ +/* No PR_WRAP_USRREQS: all functions are used directly. */ +#else PR_WRAP_USRREQS(udp) #define udp_attach udp_attach_wrapper #define udp_detach udp_detach_wrapper @@ -1364,6 +1421,7 @@ PR_WRAP_USRREQS(udp) #define udp_send udp_send_wrapper #define udp_sendoob udp_sendoob_wrapper #define udp_purgeif udp_purgeif_wrapper +#endif const struct pr_usrreqs udp_usrreqs = { .pr_attach = udp_attach, diff --git a/sys/netinet6/dccp6_usrreq.c b/sys/netinet6/dccp6_usrreq.c index 3e0d5be9ceff..fede31b2e856 100644 --- a/sys/netinet6/dccp6_usrreq.c +++ b/sys/netinet6/dccp6_usrreq.c @@ -340,10 +340,10 @@ dccp6_attach(struct socket *so, int proto) return dccp_attach(so, proto); } -static int +static void dccp6_detach(struct socket *so) { - return dccp_detach(so); + dccp_detach(so); } static int diff --git a/sys/netinet6/icmp6.c b/sys/netinet6/icmp6.c index c72fe762ee8b..fc3f5c560ad4 100644 --- a/sys/netinet6/icmp6.c +++ b/sys/netinet6/icmp6.c @@ -1945,6 +1945,7 @@ icmp6_rip6_input(struct mbuf **mp, int off) struct sockaddr_in6 rip6src; struct icmp6_hdr *icmp6; struct mbuf *n, *opts = NULL; + int s; IP6_EXTHDR_GET(icmp6, struct icmp6_hdr *, m, off, sizeof(*icmp6)); if (icmp6 == NULL) { @@ -1962,7 +1963,9 @@ icmp6_rip6_input(struct mbuf **mp, int off) return IPPROTO_DONE; } - TAILQ_FOREACH(inp, &raw6cbtable.inpt_queue, inp_queue) { + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, &raw6cbtable.inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET6) continue; if (in6p_ip6(inp).ip6_nxt != IPPROTO_ICMPV6) @@ -2008,6 +2011,7 @@ icmp6_rip6_input(struct mbuf **mp, int off) #ifdef IPSEC if (ipsec_used && last && ipsec_in_reject(m, last)) { + pserialize_read_exit(s); m_freem(m); IP6_STATDEC(IP6_STAT_DELIVERED); /* do not inject data into pcb */ @@ -2027,7 +2031,9 @@ icmp6_rip6_input(struct mbuf **mp, int off) } else { sorwakeup(last->inp_socket); } + pserialize_read_exit(s); } else { + pserialize_read_exit(s); m_freem(m); IP6_STATDEC(IP6_STAT_DELIVERED); } 
diff --git a/sys/netinet6/in6.c b/sys/netinet6/in6.c index c71cb6edbba3..8d42c3fdd4a6 100644 --- a/sys/netinet6/in6.c +++ b/sys/netinet6/in6.c @@ -2374,8 +2374,6 @@ in6_tunnel_validate(const struct ip6_hdr *ip6, const struct in6_addr *src, } #define IN6_LLTBL_DEFAULT_HSIZE 32 -#define IN6_LLTBL_HASH(k, h) \ - (((((((k >> 8) ^ k) >> 8) ^ k) >> 8) ^ k) & ((h) - 1)) /* * Do actual deallocation of @lle. @@ -2463,13 +2461,6 @@ in6_lltable_rtcheck(struct ifnet *ifp, u_int flags, return 0; } -static inline uint32_t -in6_lltable_hash_dst(const struct in6_addr *dst, uint32_t hsize) -{ - - return IN6_LLTBL_HASH(dst->s6_addr32[3], hsize); -} - static uint32_t in6_lltable_hash(const struct llentry *lle, uint32_t hsize) { @@ -2493,12 +2484,12 @@ static inline struct llentry * in6_lltable_find_dst(struct lltable *llt, const struct in6_addr *dst) { struct llentry *lle; - struct llentries *lleh; + struct pslist_head *lleh; u_int hashidx; hashidx = in6_lltable_hash_dst(dst, llt->llt_hsize); lleh = &llt->lle_head[hashidx]; - LIST_FOREACH(lle, lleh, lle_next) { + PSLIST_WRITER_FOREACH(lle, lleh, struct llentry, lle_next) { if (lle->la_flags & LLE_DELETED) continue; if (IN6_ARE_ADDR_EQUAL(&lle->r_l3addr.addr6, dst)) diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c index 6ce9af6ef01b..029e3aceddcc 100644 --- a/sys/netinet6/in6_pcb.c +++ b/sys/netinet6/in6_pcb.c @@ -71,6 +71,7 @@ __KERNEL_RCSID(0, "$NetBSD: in6_pcb.c,v 1.178 2026/03/01 20:42:03 andvar Exp $") #include #include +#include #include #include #include @@ -117,12 +118,12 @@ const struct in6_addr zeroin6_addr; (laddr)->s6_addr32[2] ^ (laddr)->s6_addr32[3]) + ntohs(lport)) & \ (table)->inpt_bindhash] #define IN6PCBHASH_CONNECT(table, faddr, fport, laddr, lport) \ - &(table)->inpt_bindhashtbl[ \ + &(table)->inpt_connecthashtbl[ \ ((((faddr)->s6_addr32[0] ^ (faddr)->s6_addr32[1] ^ \ (faddr)->s6_addr32[2] ^ (faddr)->s6_addr32[3]) + ntohs(fport)) + \ (((laddr)->s6_addr32[0] ^ (laddr)->s6_addr32[1] ^ \ 
(laddr)->s6_addr32[2] ^ (laddr)->s6_addr32[3]) + \ - ntohs(lport))) & (table)->inpt_bindhash] + ntohs(lport))) & (table)->inpt_connecthash] int ip6_anonportmin = IPV6PORT_ANONMIN; int ip6_anonportmax = IPV6PORT_ANONMAX; @@ -279,8 +280,12 @@ in6pcb_bind_port(struct inpcb *inp, struct sockaddr_in6 *sin6, struct lwp *l) t = inpcb_lookup_local(table, *(struct in_addr *)&sin6->sin6_addr.s6_addr32[3], sin6->sin6_port, wild, &vestige); - if (t && (reuseport & t->inp_socket->so_options) == 0) + if (t && (reuseport & t->inp_socket->so_options) == 0) { + inp_lookup_unlock(t); return EADDRINUSE; + } + if (t) + inp_lookup_unlock(t); if (!t && vestige.valid && !(reuseport && vestige.reuse_port)) @@ -296,8 +301,12 @@ in6pcb_bind_port(struct inpcb *inp, struct sockaddr_in6 *sin6, struct lwp *l) t = in6pcb_lookup_local(table, &sin6->sin6_addr, sin6->sin6_port, wild, &vestige); - if (t && (reuseport & t->inp_socket->so_options) == 0) + if (t && (reuseport & t->inp_socket->so_options) == 0) { + inp_lookup_unlock(t); return EADDRINUSE; + } + if (t) + inp_lookup_unlock(t); if (!t && vestige.valid && !(reuseport && vestige.reuse_port)) @@ -312,12 +321,17 @@ in6pcb_bind_port(struct inpcb *inp, struct sockaddr_in6 *sin6, struct lwp *l) return e; } else { inp->inp_lport = sin6->sin6_port; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); } - LIST_REMOVE(inp, inp_lhash); - LIST_INSERT_HEAD(IN6PCBHASH_PORT(table, inp->inp_lport), - inp, inp_lhash); + INP_HASH_LOCK(table); + PSLIST_WRITER_REMOVE(inp, inp_port_hash); + PSLIST_ENTRY_INIT(inp, inp_port_hash); + PSLIST_WRITER_INSERT_HEAD(IN6PCBHASH_PORT(table, inp->inp_lport), + inp, inp_port_hash); + INP_HASH_UNLOCK(table); return 0; } @@ -395,6 +409,7 @@ in6pcb_connect(void *v, struct sockaddr_in6 *sin6, struct lwp *l) struct in6_addr mapped; #endif struct sockaddr_in6 tmp; + struct inpcb *t; struct vestigial_inpcb vestige; struct psref psref; int bound; @@ -504,11 +519,16 @@ 
in6pcb_connect(void *v, struct sockaddr_in6 *sin6, struct lwp *l) in6p_ip6(inp).ip6_hlim = (u_int8_t)in6pcb_selecthlim_rt(inp); curlwp_bindx(bound); - if (in6pcb_lookup(inp->inp_table, &sin6->sin6_addr, + t = in6pcb_lookup(inp->inp_table, &sin6->sin6_addr, sin6->sin6_port, - IN6_IS_ADDR_UNSPECIFIED(&in6p_laddr(inp)) ? in6a : &in6p_laddr(inp), - inp->inp_lport, 0, &vestige) - || vestige.valid) + IN6_IS_ADDR_UNSPECIFIED(&in6p_laddr(inp)) ? + in6a : &in6p_laddr(inp), + inp->inp_lport, 0, &vestige); + if (t != NULL) { + inp_lookup_unlock(t); + return EADDRINUSE; + } + if (vestige.valid) return EADDRINUSE; if (IN6_IS_ADDR_UNSPECIFIED(&in6p_laddr(inp)) || (IN6_IS_ADDR_V4MAPPED(&in6p_laddr(inp)) && @@ -535,7 +555,9 @@ in6pcb_connect(void *v, struct sockaddr_in6 *sin6, struct lwp *l) return error; } + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_CONNECTED); + INP_HASH_UNLOCK(inp->inp_table); in6p_flowinfo(inp) &= ~IPV6_FLOWLABEL_MASK; if (ip6_auto_flowlabel) in6p_flowinfo(inp) |= @@ -552,7 +574,9 @@ in6pcb_disconnect(struct inpcb *inp) { memset((void *)&in6p_faddr(inp), 0, sizeof(in6p_faddr(inp))); inp->inp_fport = 0; + INP_HASH_LOCK(inp->inp_table); inpcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); in6p_flowinfo(inp) &= ~IPV6_FLOWLABEL_MASK; #if defined(IPSEC) if (ipsec_enabled) @@ -642,11 +666,28 @@ in6pcb_notify(struct inpcbtable *table, const struct sockaddr *dst, } errno = inet6ctlerrmap[cmd]; - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { + + /* + * Hand-over-hand refcount iteration (like FreeBSD's inp_next). + * Route cache reads (rtcache_validate/getdst/unref) are + * non-blocking and safe inside pserialize. The blocking notify + * callback and ip6_notify_pmtu are called outside pserialize. 
+ */ + { + struct inpcb *next; + int s; + bool matched, do_pmtu; + + s = pserialize_read_enter(); + for (inp = PSLIST_READER_FIRST(&table->inpt_queue_pslist, + struct inpcb, inp_queue_hash); inp != NULL; inp = next) { struct rtentry *rt = NULL; - if (inp->inp_af != AF_INET6) + if (inp->inp_af != AF_INET6) { + next = PSLIST_READER_NEXT(inp, struct inpcb, + inp_queue_hash); continue; + } /* * Under the following condition, notify of redirects @@ -679,6 +720,8 @@ in6pcb_notify(struct inpcbtable *table, const struct sockaddr *dst, * pmtud) may be a good idea. netbsd/openbsd has it. see * icmp6_mtudisc_update(). */ + matched = false; + do_pmtu = false; if ((PRC_IS_REDIRECT(cmd) || cmd == PRC_HOSTDEAD) && IN6_IS_ADDR_UNSPECIFIED(&in6p_laddr(inp)) && (rt = rtcache_validate(&inp->inp_route)) != NULL && @@ -690,56 +733,116 @@ in6pcb_notify(struct inpcbtable *table, const struct sockaddr *dst, if (dst6 == NULL) ; else if (IN6_ARE_ADDR_EQUAL(&dst6->sin6_addr, - &sa6_dst->sin6_addr)) { - rtcache_unref(rt, &inp->inp_route); - goto do_notify; - } + &sa6_dst->sin6_addr)) + matched = true; } rtcache_unref(rt, &inp->inp_route); /* - * If the error designates a new path MTU for a destination - * and the application (associated with this socket) wanted to - * know the value, notify. Note that we notify for all - * disconnected sockets if the corresponding application - * wanted. This is because some UDP applications keep sending - * sockets disconnected. - * XXX: should we avoid to notify the value to TCP sockets? + * If the error designates a new path MTU for a + * destination and the application wanted to know, + * note it for ip6_notify_pmtu (called outside + * pserialize). This is independent of the notify + * callback -- a socket can get PMTU notification + * without matching the address/port check below. 
*/ - if (cmd == PRC_MSGSIZE && (inp->inp_flags & IN6P_MTU) != 0 && + if (cmd == PRC_MSGSIZE && + (inp->inp_flags & IN6P_MTU) != 0 && (IN6_IS_ADDR_UNSPECIFIED(&in6p_faddr(inp)) || - IN6_ARE_ADDR_EQUAL(&in6p_faddr(inp), &sa6_dst->sin6_addr))) { - ip6_notify_pmtu(inp, (const struct sockaddr_in6 *)dst, - (u_int32_t *)cmdarg); + IN6_ARE_ADDR_EQUAL(&in6p_faddr(inp), + &sa6_dst->sin6_addr))) { + do_pmtu = true; + } + + if (!matched) { + /* + * Detect if we should notify the error. If no + * source and destination ports are specified, but + * non-zero flowinfo and local address match, notify + * the error. This is the case when the error is + * delivered with an encrypted buffer by ESP. + * Otherwise, just compare addresses and ports as + * usual. + */ + if (lport == 0 && fport == 0 && flowinfo && + inp->inp_socket != NULL && + flowinfo == + (in6p_flowinfo(inp) & IPV6_FLOWLABEL_MASK) && + IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), + &sa6_src.sin6_addr)) + matched = true; + else if (!IN6_ARE_ADDR_EQUAL(&in6p_faddr(inp), + &sa6_dst->sin6_addr) || + inp->inp_socket == NULL || + (lport && inp->inp_lport != lport) || + (!IN6_IS_ADDR_UNSPECIFIED(&sa6_src.sin6_addr) && + !IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), + &sa6_src.sin6_addr)) || + (fport && inp->inp_fport != fport)) { + if (!do_pmtu) { + next = PSLIST_READER_NEXT(inp, + struct inpcb, inp_queue_hash); + continue; + } + /* PMTU side-effect only, no notify */ + } else { + matched = true; + } + } + + if (matched) + nmatch++; + + if (!matched && !do_pmtu) { + next = PSLIST_READER_NEXT(inp, struct inpcb, + inp_queue_hash); + continue; } /* - * Detect if we should notify the error. If no source and - * destination ports are specified, but non-zero flowinfo and - * local address match, notify the error. This is the case - * when the error is delivered with an encrypted buffer - * by ESP. Otherwise, just compare addresses and ports - * as usual. + * Hand-over-hand: pin next element, then ref current. 
+ * Both happen inside pserialize so the list is stable. */ - if (lport == 0 && fport == 0 && flowinfo && - inp->inp_socket != NULL && - flowinfo == (in6p_flowinfo(inp) & IPV6_FLOWLABEL_MASK) && - IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), &sa6_src.sin6_addr)) - goto do_notify; - else if (!IN6_ARE_ADDR_EQUAL(&in6p_faddr(inp), - &sa6_dst->sin6_addr) || - inp->inp_socket == NULL || - (lport && inp->inp_lport != lport) || - (!IN6_IS_ADDR_UNSPECIFIED(&sa6_src.sin6_addr) && - !IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), - &sa6_src.sin6_addr)) || - (fport && inp->inp_fport != fport)) - continue; + next = PSLIST_READER_NEXT(inp, struct inpcb, + inp_queue_hash); + while (next != NULL && !inpcb_ref_acquire(next)) + next = PSLIST_READER_NEXT(next, struct inpcb, + inp_queue_hash); + if (!inpcb_ref_acquire(inp)) { + /* + * Current is dying -- cannot safely deliver. + * Release next's pin and stop iterating. + */ + if (next != NULL) { + if (inpcb_ref_release(next)) + inpcb_pool_put(next); + } + break; + } + pserialize_read_exit(s); - do_notify: - if (notify) + /* + * Blocking work outside pserialize: ip6_notify_pmtu + * calls sbappendaddr (needs solock), and the notify + * callback (e.g. in6pcb_rtchange) may call rtcache_free. 
+ */ + if (do_pmtu) + ip6_notify_pmtu(inp, + (const struct sockaddr_in6 *)dst, + (u_int32_t *)cmdarg); + if (matched && notify) (*notify)(inp, errno); - nmatch++; + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + + s = pserialize_read_enter(); + /* Release next's pin; still safe, we are in pserialize */ + if (next != NULL) { + bool last __diagused = inpcb_ref_release(next); + KASSERT(!last); + } + } + pserialize_read_exit(s); } return nmatch; } @@ -753,7 +856,9 @@ in6pcb_purgeif0(struct inpcbtable *table, struct ifnet *ifp) KASSERT(ifp != NULL); - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { + INP_HASH_LOCK(table); + PSLIST_WRITER_FOREACH(inp, &table->inpt_queue_pslist, + struct inpcb, inp_queue_hash) { bool need_unlock = false; if (inp->inp_af != AF_INET6) continue; @@ -792,6 +897,7 @@ in6pcb_purgeif0(struct inpcbtable *table, struct ifnet *ifp) if (need_unlock) inp_unlock(inp); } + INP_HASH_UNLOCK(table); } void @@ -800,7 +906,9 @@ in6pcb_purgeif(struct inpcbtable *table, struct ifnet *ifp) struct rtentry *rt; struct inpcb *inp; - TAILQ_FOREACH(inp, &table->inpt_queue, inp_queue) { + INP_HASH_LOCK(table); + PSLIST_WRITER_FOREACH(inp, &table->inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET6) continue; if ((rt = rtcache_validate(&inp->inp_route)) != NULL && @@ -810,6 +918,7 @@ in6pcb_purgeif(struct inpcbtable *table, struct ifnet *ifp) } else rtcache_unref(rt, &inp->inp_route); } + INP_HASH_UNLOCK(table); } /* @@ -823,26 +932,26 @@ in6pcb_rtchange(struct inpcb *inp, int errno) return; rtcache_free(&inp->inp_route); - /* - * A new route can be allocated the next time - * output is attempted. 
- */ } struct inpcb * in6pcb_lookup_local(struct inpcbtable *table, struct in6_addr *laddr6, u_int lport_arg, int lookup_wildcard, struct vestigial_inpcb *vp) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp, *match = NULL; int matchwild = 3, wildcard; in_port_t lport = lport_arg; + int s; if (vp) vp->valid = 0; head = IN6PCBHASH_PORT(table, lport); - LIST_FOREACH(inp, head, inp_lhash) { + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_port_hash) { + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (inp->inp_af != AF_INET6) continue; @@ -907,17 +1016,33 @@ in6pcb_lookup_local(struct inpcbtable *table, struct in6_addr *laddr6, break; } } + if (match != NULL) { + struct socket *so = match->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &match->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + match = NULL; + } + } else { + match = NULL; + } + } + pserialize_read_exit(s); if (match && matchwild == 0) return match; if (vp && table->vestige && table->vestige->init_ports6) { + struct tcp_ports_iterator ports_it; struct vestigial_inpcb better; bool has_better = false; void *state; state = (*table->vestige->init_ports6)(laddr6, lport_arg, - lookup_wildcard); + lookup_wildcard, + &ports_it); while (table->vestige && (*table->vestige->next_port6)(state, vp)) { @@ -954,6 +1079,8 @@ in6pcb_lookup_local(struct inpcbtable *table, struct in6_addr *laddr6, } if (has_better) { + if (match != NULL) + sounlock(match->inp_socket); *vp = better; return 0; } @@ -1045,17 +1172,21 @@ in6pcb_lookup(struct inpcbtable *table, const struct in6_addr *faddr6, int faith, struct vestigial_inpcb *vp) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp; in_port_t fport = fport_arg, lport = lport_arg; + int s; if (vp) vp->valid = 0; head = IN6PCBHASH_CONNECT(table, faddr6, fport, laddr6, lport); - 
LIST_FOREACH(inp, head, inp_hash) { + s = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_connect_hash) { if (inp->inp_af != AF_INET6) continue; + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; /* find exact match on both source and dest */ if (inp->inp_fport != fport) @@ -1074,8 +1205,60 @@ in6pcb_lookup(struct inpcbtable *table, const struct in6_addr *faddr6, IN6_IS_ADDR_V4MAPPED(faddr6)) && (inp->inp_flags & IN6P_IPV6_V6ONLY)) continue; - return inp; + { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + inp = NULL; + } + break; + } + if (so != NULL) { + /* + * Grab a refcount on the inpcb to keep + * it alive across the blocking lock + * acquire outside pserialize. + */ + if (!inpcb_ref_acquire(inp)) + continue; + { + kmutex_t *lock = so->so_lock; + bool last; + mutex_obj_hold(lock); + pserialize_read_exit(s); + mutex_enter(lock); + if (__predict_false(lock != so->so_lock)) { + mutex_exit(lock); + mutex_obj_free(lock); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + return 0; + } + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(lock); + mutex_obj_free(lock); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + return 0; + } + mutex_obj_free(lock); + last = inpcb_ref_release(inp); + KASSERT(!last); + return inp; + } + } + continue; + } } + pserialize_read_exit(s); + + if (inp != NULL) + return inp; + if (vp && table->vestige) { if ((*table->vestige->lookup6)(faddr6, fport_arg, laddr6, lport_arg, vp)) @@ -1089,17 +1272,21 @@ struct inpcb * in6pcb_lookup_bound(struct inpcbtable *table, const struct in6_addr *laddr6, u_int lport_arg, int faith) { - struct inpcbhead *head; + struct pslist_head *head; struct inpcb *inp; in_port_t lport = lport_arg; + int s; #ifdef INET struct in6_addr zero_mapped; #endif 
+ s = pserialize_read_enter(); head = IN6PCBHASH_BIND(table, laddr6, lport); - LIST_FOREACH(inp, head, inp_hash) { + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_bind_hash) { if (inp->inp_af != AF_INET6) continue; + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (faith && (inp->inp_flags & IN6P_FAITH) == 0) continue; @@ -1110,17 +1297,30 @@ in6pcb_lookup_bound(struct inpcbtable *table, const struct in6_addr *laddr6, if (IN6_IS_ADDR_V4MAPPED(laddr6) && (inp->inp_flags & IN6P_IPV6_V6ONLY) != 0) continue; - if (IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), laddr6)) - goto out; + if (IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), laddr6)) { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + continue; + } + goto out; + } + continue; + } } #ifdef INET if (IN6_IS_ADDR_V4MAPPED(laddr6)) { memset(&zero_mapped, 0, sizeof(zero_mapped)); zero_mapped.s6_addr16[5] = 0xffff; head = IN6PCBHASH_BIND(table, &zero_mapped, lport); - LIST_FOREACH(inp, head, inp_hash) { + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_bind_hash) { if (inp->inp_af != AF_INET6) continue; + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (faith && (inp->inp_flags & IN6P_FAITH) == 0) continue; @@ -1130,15 +1330,28 @@ in6pcb_lookup_bound(struct inpcbtable *table, const struct in6_addr *laddr6, continue; if ((inp->inp_flags & IN6P_IPV6_V6ONLY) != 0) continue; - if (IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), &zero_mapped)) - goto out; + if (IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), &zero_mapped)) { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + continue; + } + goto out; + } + continue; + } } } #endif head = IN6PCBHASH_BIND(table, &zeroin6_addr, lport); - 
LIST_FOREACH(inp, head, inp_hash) { + PSLIST_READER_FOREACH(inp, head, struct inpcb, inp_bind_hash) { if (inp->inp_af != AF_INET6) continue; + if (atomic_load_acquire(&inp->inp_state) == INP_FREED) + continue; if (faith && (inp->inp_flags & IN6P_FAITH) == 0) continue; @@ -1149,16 +1362,24 @@ in6pcb_lookup_bound(struct inpcbtable *table, const struct in6_addr *laddr6, if (IN6_IS_ADDR_V4MAPPED(laddr6) && (inp->inp_flags & IN6P_IPV6_V6ONLY) != 0) continue; - if (IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), &zeroin6_addr)) - goto out; + if (IN6_ARE_ADDR_EQUAL(&in6p_laddr(inp), &zeroin6_addr)) { + struct socket *so = inp->inp_socket; + if (__predict_true(so != NULL && + mutex_tryenter(so->so_lock))) { + if (__predict_false(atomic_load_acquire( + &inp->inp_state) == INP_FREED)) { + mutex_exit(so->so_lock); + inp = NULL; + goto out; + } + goto out; + } + continue; + } } - return NULL; - + inp = NULL; out: - if (inp != LIST_FIRST(head)) { - LIST_REMOVE(inp, inp_hash); - LIST_INSERT_HEAD(head, inp, inp_hash); - } + pserialize_read_exit(s); return inp; } @@ -1169,20 +1390,32 @@ in6pcb_set_state(struct inpcb *inp, int state) if (inp->inp_af != AF_INET6) return; - if (inp->inp_state > INP_ATTACHED) - LIST_REMOVE(inp, inp_hash); + KASSERT(INP_HASH_LOCKED(inp->inp_table)); + + /* Remove from current hash (uses separate entry per hash table) */ + switch (inp->inp_state) { + case INP_BOUND: + PSLIST_WRITER_REMOVE(inp, inp_bind_hash); + PSLIST_ENTRY_INIT(inp, inp_bind_hash); + break; + case INP_CONNECTED: + PSLIST_WRITER_REMOVE(inp, inp_connect_hash); + PSLIST_ENTRY_INIT(inp, inp_connect_hash); + break; + } + /* Insert into new hash */ switch (state) { case INP_BOUND: - LIST_INSERT_HEAD(IN6PCBHASH_BIND(inp->inp_table, + PSLIST_WRITER_INSERT_HEAD(IN6PCBHASH_BIND(inp->inp_table, &in6p_laddr(inp), inp->inp_lport), inp, - inp_hash); + inp_bind_hash); break; case INP_CONNECTED: - LIST_INSERT_HEAD(IN6PCBHASH_CONNECT(inp->inp_table, + PSLIST_WRITER_INSERT_HEAD(IN6PCBHASH_CONNECT(inp->inp_table, 
&in6p_faddr(inp), inp->inp_fport, &in6p_laddr(inp), inp->inp_lport), inp, - inp_hash); + inp_connect_hash); break; } diff --git a/sys/netinet6/in6_proto.c b/sys/netinet6/in6_proto.c index 2fc720b35ed1..21216f778dd2 100644 --- a/sys/netinet6/in6_proto.c +++ b/sys/netinet6/in6_proto.c @@ -138,54 +138,22 @@ __KERNEL_RCSID(0, "$NetBSD: in6_proto.c,v 1.131 2024/02/09 22:08:37 andvar Exp $ DOMAIN_DEFINE(inet6domain); /* forward declare and add to link set */ -/* Wrappers to acquire kernel_lock. */ - -PR_WRAP_CTLINPUT(rip6_ctlinput) +/* + * Per-socket-lock protocols use internal locking for ctlinput. + * Only encap6 (IPsec) still needs the KERNEL_LOCK wrapper. + */ PR_WRAP_CTLINPUT(encap6_ctlinput) -PR_WRAP_CTLINPUT(udp6_ctlinput) -PR_WRAP_CTLINPUT(tcp6_ctlinput) -#define rip6_ctlinput rip6_ctlinput_wrapper #define encap6_ctlinput encap6_ctlinput_wrapper -#define udp6_ctlinput udp6_ctlinput_wrapper -#define tcp6_ctlinput tcp6_ctlinput_wrapper -PR_WRAP_CTLOUTPUT(rip6_ctloutput) -PR_WRAP_CTLOUTPUT(tcp_ctloutput) -PR_WRAP_CTLOUTPUT(udp6_ctloutput) +/* icmp6_ctloutput still needs KERNEL_LOCK (not per-socket lock) */ PR_WRAP_CTLOUTPUT(icmp6_ctloutput) -#define rip6_ctloutput rip6_ctloutput_wrapper -#define tcp_ctloutput tcp_ctloutput_wrapper -#define udp6_ctloutput udp6_ctloutput_wrapper #define icmp6_ctloutput icmp6_ctloutput_wrapper -#if defined(DCCP) -PR_WRAP_CTLINPUT(dccp6_ctlinput) -PR_WRAP_CTLOUTPUT(dccp_ctloutput) - -#define dccp6_ctlinput dccp6_ctlinput_wrapper -#define dccp_ctloutput dccp_ctloutput_wrapper -#endif - -#if defined(SCTP) -PR_WRAP_CTLINPUT(sctp6_ctlinput) -PR_WRAP_CTLOUTPUT(sctp_ctloutput) - -#define sctp6_ctlinput sctp6_ctlinput_wrapper -#define sctp_ctloutput sctp_ctloutput_wrapper -#endif - #ifdef NET_MPSAFE -PR_WRAP_INPUT6(udp6_input) -PR_WRAP_INPUT6(tcp6_input) -#ifdef DCCP -PR_WRAP_INPUT6(dccp6_input) -#endif -#ifdef SCTP -PR_WRAP_INPUT6(sctp6_input) -#endif -PR_WRAP_INPUT6(rip6_input) +/* All per-socket-lock protocols acquire solock internally, no 
input wrappers. */ +/* rip6_input uses pserialize for TAILQ iteration + tryenter for solock */ PR_WRAP_INPUT6(dest6_input) PR_WRAP_INPUT6(route6_input) PR_WRAP_INPUT6(frag6_input) @@ -194,14 +162,12 @@ PR_WRAP_INPUT6(pfsync_input) #endif PR_WRAP_INPUT6(pim6_input) -#define udp6_input udp6_input_wrapper -#define tcp6_input tcp6_input_wrapper -#define dccp6_input dccp6_input_wrapper -#define sctp6_input sctp6_input_wrapper -#define rip6_input rip6_input_wrapper #define dest6_input dest6_input_wrapper #define route6_input route6_input_wrapper #define frag6_input frag6_input_wrapper +#if NPFSYNC > 0 +#define pfsync_input pfsync_input_wrapper +#endif #define pim6_input pim6_input_wrapper #endif diff --git a/sys/netinet6/in6_src.c b/sys/netinet6/in6_src.c index 725e04d8d29d..abdb5be14a1c 100644 --- a/sys/netinet6/in6_src.c +++ b/sys/netinet6/in6_src.c @@ -878,7 +878,9 @@ in6pcb_set_port(struct sockaddr_in6 *sin6, struct inpcb *inp, struct lwp *l) inp->inp_flags |= IN6P_ANONPORT; *lastport = lport; inp->inp_lport = htons(lport); + INP_HASH_LOCK(inp->inp_table); in6pcb_set_state(inp, INP_BOUND); + INP_HASH_UNLOCK(inp->inp_table); return (0); /* success */ } diff --git a/sys/netinet6/in6_var.h b/sys/netinet6/in6_var.h index 470916b17394..ca6ca4334175 100644 --- a/sys/netinet6/in6_var.h +++ b/sys/netinet6/in6_var.h @@ -609,6 +609,16 @@ struct in6pcb; #define LLTABLE6(ifp) (((struct in6_ifextra *)(ifp)->if_afdata[AF_INET6])->lltable) +#define IN6_LLTBL_HASH(k, h) \ + (((((((k >> 8) ^ k) >> 8) ^ k) >> 8) ^ k) & ((h) - 1)) + +static inline uint32_t +in6_lltable_hash_dst(const struct in6_addr *dst, uint32_t hsize) +{ + + return IN6_LLTBL_HASH(dst->s6_addr32[3], hsize); +} + void in6_sysctl_multicast_setup(struct sysctllog **); #endif /* _KERNEL */ diff --git a/sys/netinet6/nd6.c b/sys/netinet6/nd6.c index 91ad16738563..2c2f9cea390a 100644 --- a/sys/netinet6/nd6.c +++ b/sys/netinet6/nd6.c @@ -129,7 +129,7 @@ struct nd_domain nd6_nd_domain = { .nd_maxretrans = MAX_RETRANS_TIMER, 
.nd_maxnudhint = 0, /* max # of subsequent upper layer hints */ .nd_gctimer = 24*60*60, /* stale neighbor GC timer duration */ - .nd_maxqueuelen = 1, /* max # of packets in unresolved ND entries */ + .nd_maxqueuelen = 16, /* max # of packets in unresolved ND entries */ .nd_nud_enabled = nd6_nud_enabled, .nd_reachable = nd6_llinfo_reachable, .nd_retrans = nd6_llinfo_retrans, @@ -1399,6 +1399,7 @@ nd6_cache_lladdr( * XXX is it dependent to ifp->if_type? */ memcpy(&ln->ll_addr, lladdr, ifp->if_addrlen); + membar_release(); ln->la_flags |= LLE_VALID; ln->la_flags &= ~LLE_UNRESOLVED; } @@ -1560,6 +1561,27 @@ nd6_slowtimo(void *ignored_arg) SOFTNET_KERNEL_UNLOCK_UNLESS_NET_MPSAFE(); } +/* + * Lock-free NDP lookup for the fast path. Must be called inside a + * pserialize read section. Returns the llentry if found, NULL otherwise. + */ +static struct llentry * +nd6_lookup_psz(struct ifnet *ifp, const struct in6_addr *addr6) +{ + struct lltable *llt = LLTABLE6(ifp); + uint32_t hashidx = in6_lltable_hash_dst(addr6, llt->llt_hsize); + struct pslist_head *head = &llt->lle_head[hashidx]; + struct llentry *lle; + + PSLIST_READER_FOREACH(lle, head, struct llentry, lle_next) { + if (lle->la_flags & LLE_DELETED) + continue; + if (IN6_ARE_ADDR_EQUAL(&lle->r_l3addr.addr6, addr6)) + return lle; + } + return NULL; +} + /* * Return 0 if a neighbor cache is found. Return EWOULDBLOCK if a cache is not * found and trying to resolve a neighbor; in this case the mbuf is queued in @@ -1573,6 +1595,7 @@ nd6_resolve(struct ifnet *ifp, const struct rtentry *rt, struct mbuf *m, bool created = false; const struct sockaddr_in6 *dst = satocsin6(_dst); int error; + int s; struct nd_kifinfo *ndi = ND_IFINFO(ifp); /* discard the packet if IPv6 operation is disabled on the interface */ @@ -1582,29 +1605,28 @@ nd6_resolve(struct ifnet *ifp, const struct rtentry *rt, struct mbuf *m, } /* - * Address resolution or Neighbor Unreachability Detection - * for the next hop. 
- * At this point, the destination of the packet must be a unicast - * or an anycast address(i.e. not a multicast). + * Fast path: pserialize read section, no locks. + * If the entry exists, is VALID, and REACHABLE/DELAY/PROBE, + * just copy the MAC address. + * Pserialize + deferred free guarantees the entry won't be + * freed under us. */ - - /* Look up the neighbor cache for the nexthop */ - ln = nd6_lookup(&dst->sin6_addr, ifp, false); - + s = pserialize_read_enter(); + ln = nd6_lookup_psz(ifp, &dst->sin6_addr); if (ln != NULL && (ln->la_flags & LLE_VALID) != 0 && - /* Only STALE needs to go the slow path to change its state. */ (ln->ln_state == ND_LLINFO_REACHABLE || ln->ln_state == ND_LLINFO_DELAY || ln->ln_state == ND_LLINFO_PROBE)) { - /* Fast path */ + membar_acquire(); memcpy(lldst, &ln->ll_addr, MIN(dstsize, ifp->if_addrlen)); - LLE_RUNLOCK(ln); + pserialize_read_exit(s); return 0; } - if (ln != NULL) - LLE_RUNLOCK(ln); + pserialize_read_exit(s); - /* Slow path */ + /* + * Slow path: take write lock directly, create or update entry. + */ ln = nd6_lookup(&dst->sin6_addr, ifp, true); if (ln == NULL && nd6_is_addr_neighbor(dst, ifp)) { /* diff --git a/sys/netinet6/raw_ip6.c b/sys/netinet6/raw_ip6.c index dfcfd28dcea1..68763c578b7d 100644 --- a/sys/netinet6/raw_ip6.c +++ b/sys/netinet6/raw_ip6.c @@ -170,6 +170,7 @@ rip6_input(struct mbuf **mp, int *offp, int proto) struct inpcb *last = NULL; struct sockaddr_in6 rip6src; struct mbuf *n; + int psz; RIP6_STATINC(RIP6_STAT_IPACKETS); @@ -188,7 +189,15 @@ rip6_input(struct mbuf **mp, int *offp, int proto) return IPPROTO_DONE; } - TAILQ_FOREACH(inp, &raw6cbtable.inpt_queue, inp_queue) { + /* + * pserialize protects the PSLIST iteration from concurrent + * insert/remove in inpcb_create/inpcb_destroy. Use tryenter + * for solock to avoid blocking inside the pserialize section. + * On tryenter failure, drop (rare contention). 
+ */ + psz = pserialize_read_enter(); + PSLIST_READER_FOREACH(inp, &raw6cbtable.inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET6) continue; if (in6p_ip6(inp).ip6_nxt && @@ -227,8 +236,14 @@ rip6_input(struct mbuf **mp, int *offp, int proto) } #endif else if ((n = m_copypacket(m, M_DONTWAIT)) != NULL) { - rip6_sbappendaddr(last, ip6, sin6tosa(&rip6src), - *offp, n); + if (mutex_tryenter(last->inp_socket->so_lock)) { + rip6_sbappendaddr(last, ip6, + sin6tosa(&rip6src), *offp, n); + mutex_exit(last->inp_socket->so_lock); + } else { + /* tryenter failed, drop (rare contention) */ + m_freem(n); + } } last = inp; @@ -236,14 +251,38 @@ rip6_input(struct mbuf **mp, int *offp, int proto) #ifdef IPSEC if (ipsec_used && last && ipsec_in_reject(m, last)) { + pserialize_read_exit(psz); m_freem(m); IP6_STATDEC(IP6_STAT_DELIVERED); /* do not inject data into pcb */ } else #endif if (last != NULL) { - rip6_sbappendaddr(last, ip6, sin6tosa(&rip6src), *offp, m); + if (mutex_tryenter(last->inp_socket->so_lock)) { + rip6_sbappendaddr(last, ip6, sin6tosa(&rip6src), + *offp, m); + mutex_exit(last->inp_socket->so_lock); + pserialize_read_exit(psz); + } else if (inpcb_ref_acquire(last)) { + kmutex_t *lock = last->inp_socket->so_lock; + mutex_obj_hold(lock); + pserialize_read_exit(psz); + mutex_enter(lock); + if (last->inp_state != INP_FREED) + rip6_sbappendaddr(last, ip6, + sin6tosa(&rip6src), *offp, m); + else + m_freem(m); + mutex_exit(lock); + mutex_obj_free(lock); + if (inpcb_ref_release(last)) + inpcb_pool_put(last); + } else { + pserialize_read_exit(psz); + m_freem(m); + } } else { + pserialize_read_exit(psz); RIP6_STATINC(RIP6_STAT_NOSOCK); if (m->m_flags & M_MCAST) RIP6_STATINC(RIP6_STAT_NOSOCKMCAST); @@ -261,6 +300,7 @@ rip6_input(struct mbuf **mp, int *offp, int proto) } IP6_STATDEC(IP6_STAT_DELIVERED); } + return IPPROTO_DONE; } @@ -318,6 +358,8 @@ rip6_ctlinput(int cmd, const struct sockaddr *sa, void *d) inp = NULL; inp = 
in6pcb_lookup(&raw6cbtable, &sa6->sin6_addr, 0, (const struct in6_addr *)&sa6_src->sin6_addr, 0, 0, 0); + if (inp != NULL) + inp_lookup_unlock(inp); #if 0 if (!inp) { /* @@ -590,7 +632,16 @@ rip6_attach(struct socket *so, int proto) int s, error; KASSERT(sotoinpcb(so) == NULL); - sosetlock(so); + + /* + * Assign a per-socket lock. Each raw IPv6 socket gets its own + * mutex so that input on different CPUs can be processed in + * parallel. + */ + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } error = kauth_authorize_network(kauth_cred_get(), KAUTH_NETWORK_SOCKET, KAUTH_REQ_NETWORK_SOCKET_RAWSOCK, @@ -980,6 +1031,15 @@ sysctl_net_inet6_raw6_setup(struct sysctllog **clog) CTL_EOL); } +#ifdef NET_MPSAFE +/* + * Raw IPv6 is fully MP-safe with per-socket locking. All usrreqs run + * under solock (held by the socket layer), so KERNEL_LOCK wrappers + * must NOT be used, acquiring KERNEL_LOCK inside solock inverts + * the lock order vs ctlinput (which takes KERNEL_LOCK then solock). 
+ */ +/* No PR_WRAP_USRREQS all functions used directly */ +#else PR_WRAP_USRREQS(rip6) #define rip6_attach rip6_attach_wrapper #define rip6_detach rip6_detach_wrapper @@ -1000,6 +1060,7 @@ PR_WRAP_USRREQS(rip6) #define rip6_send rip6_send_wrapper #define rip6_sendoob rip6_sendoob_wrapper #define rip6_purgeif rip6_purgeif_wrapper +#endif const struct pr_usrreqs rip6_usrreqs = { .pr_attach = rip6_attach, diff --git a/sys/netinet6/sctp6_usrreq.c b/sys/netinet6/sctp6_usrreq.c index 7d62d2409f89..0e32062c405d 100644 --- a/sys/netinet6/sctp6_usrreq.c +++ b/sys/netinet6/sctp6_usrreq.c @@ -103,7 +103,7 @@ __KERNEL_RCSID(0, "$NetBSD: sctp6_usrreq.c,v 1.27 2025/11/19 22:31:52 andvar Exp extern u_int32_t sctp_debug_on; #endif -static int sctp6_detach(struct socket *so); +static void sctp6_detach(struct socket *so); extern int sctp_no_csum_on_loopback; @@ -275,9 +275,17 @@ sctp_skip_csum: offset -= sizeof(*ch); ecn_bits = ((ntohl(ip6->ip6_flow) >> 20) & 0x000000ff); s = splsoftnet(); + /* + * Acquire solock for the duration of input processing. + * sbappend/sorwakeup require solock on NetBSD. + */ + if (in6p && in6p->sctp_socket) + solock(in6p->sctp_socket); (void)sctp_common_input_processing(&m, iphlen, offset, length, sh, ch, in6p, stcb, net, ecn_bits); /* inp's ref-count reduced && stcb unlocked */ + if (in6p && in6p->sctp_socket) + sounlock(in6p->sctp_socket); splx(s); /* XXX this stuff below gets moved to appropriate parts later... 
*/ m_freem(m); @@ -578,7 +586,10 @@ sctp6_attach(struct socket *so, int proto) int error; struct sctp_inpcb *inp; - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } inp = (struct sctp_inpcb *)so->so_pcb; if (inp != NULL) return EINVAL; @@ -680,21 +691,20 @@ sctp6_bind(struct socket *so, struct sockaddr *nam, struct lwp *l) } /*This could be made common with sctp_detach() since they are identical */ -static int +static void sctp6_detach(struct socket *so) { struct sctp_inpcb *inp; inp = (struct sctp_inpcb *)so->so_pcb; if (inp == 0) - return EINVAL; + return; if (((so->so_options & SO_LINGER) && (so->so_linger == 0)) || (so->so_rcv.sb_cc > 0)) sctp_inpcb_free(inp, 1); else sctp_inpcb_free(inp, 0); - return 0; } static int @@ -1315,6 +1325,9 @@ sctp6_purgeif(struct socket *so, struct ifnet *ifp) return 0; } +#ifdef NET_MPSAFE +/* No PR_WRAP_USRREQS: SCTP uses per-socket locks */ +#else PR_WRAP_USRREQS(sctp6) #define sctp6_attach sctp6_attach_wrapper #define sctp6_detach sctp6_detach_wrapper @@ -1335,6 +1348,7 @@ PR_WRAP_USRREQS(sctp6) #define sctp6_send sctp6_send_wrapper #define sctp6_sendoob sctp6_sendoob_wrapper #define sctp6_purgeif sctp6_purgeif_wrapper +#endif const struct pr_usrreqs sctp6_usrreqs = { .pr_attach = sctp6_attach, diff --git a/sys/netinet6/udp6_usrreq.c b/sys/netinet6/udp6_usrreq.c index 5ed063c0b806..6e2b3e328f81 100644 --- a/sys/netinet6/udp6_usrreq.c +++ b/sys/netinet6/udp6_usrreq.c @@ -177,9 +177,13 @@ udp6_init(void) static void udp6_notify(struct inpcb *inp, int errno) { - inp->inp_socket->so_error = errno; - sorwakeup(inp->inp_socket); - sowwakeup(inp->inp_socket); + struct socket *so = inp->inp_socket; + + solock(so); + so->so_error = errno; + sorwakeup(so); + sowwakeup(so); + sounlock(so); } void * @@ -244,6 +248,7 @@ udp6_ctlinput(int cmd, const struct sockaddr *sa, void *d) m_copydata(m, off, sizeof(*uhp), (void *)&uh); if (cmd == PRC_MSGSIZE) { + 
struct inpcb *t; int valid = 0; /* @@ -251,10 +256,14 @@ udp6_ctlinput(int cmd, const struct sockaddr *sa, void *d) * corresponding to the address in the ICMPv6 message * payload. */ - if (in6pcb_lookup(&udbtable, &sa6->sin6_addr, - uh.uh_dport, (const struct in6_addr *)&sa6_src->sin6_addr, - uh.uh_sport, 0, 0)) + t = in6pcb_lookup(&udbtable, + &sa6->sin6_addr, uh.uh_dport, + (const struct in6_addr *)&sa6_src->sin6_addr, + uh.uh_sport, 0, 0); + if (t != NULL) { + inp_lookup_unlock(t); valid++; + } #if 0 /* * As the use of sendto(2) is fairly popular, @@ -287,9 +296,34 @@ udp6_ctlinput(int cmd, const struct sockaddr *sa, void *d) */ } - (void)in6pcb_notify(&udbtable, sa, uh.uh_dport, - sin6tocsa(sa6_src), uh.uh_sport, cmd, cmdarg, - notify); + if (cmd == PRC_MSGSIZE) { + /* + * Use in6pcb_notify for PRC_MSGSIZE because + * unconnected sockets may share the same + * destination and all need path MTU updates. + */ + (void)in6pcb_notify(&udbtable, sa, uh.uh_dport, + sin6tocsa(sa6_src), uh.uh_sport, cmd, cmdarg, + notify); + } else { + struct inpcb *inp; + + inp = in6pcb_lookup(&udbtable, + &sa6->sin6_addr, uh.uh_dport, + &sa6_src->sin6_addr, uh.uh_sport, 0, NULL); + if (inp != NULL) { + kmutex_t *lock = inp->inp_socket->so_lock; + mutex_obj_hold(lock); + bool acquired __diagused = + inpcb_ref_acquire(inp); + KASSERT(acquired); + mutex_exit(lock); + (*notify)(inp, inet6ctlerrmap[cmd]); + if (inpcb_ref_release(inp)) + inpcb_pool_put(inp); + mutex_obj_free(lock); + } + } } else { (void)in6pcb_notify(&udbtable, sa, 0, sin6tocsa(sa6_src), 0, cmd, cmdarg, notify); @@ -466,8 +500,13 @@ udp6_realinput(int af, struct sockaddr_in6 *src, struct sockaddr_in6 *dst, */ /* * Locate pcb(s) for datagram. + * Hold INP_HASH_LOCK for structural stability during + * PSLIST iteration. Per-socket locks are acquired + * individually for each sendup (blocking, so WRITER). 
*/ - TAILQ_FOREACH(inp, &udbtable.inpt_queue, inp_queue) { + INP_HASH_LOCK(&udbtable); + PSLIST_WRITER_FOREACH(inp, &udbtable.inpt_queue_pslist, + struct inpcb, inp_queue_hash) { if (inp->inp_af != AF_INET6) continue; @@ -492,7 +531,9 @@ udp6_realinput(int af, struct sockaddr_in6 *src, struct sockaddr_in6 *dst, continue; } + solock(inp->inp_socket); udp6_sendup(m, off, sin6tosa(src), inp->inp_socket); + sounlock(inp->inp_socket); rcvcnt++; /* @@ -507,9 +548,12 @@ udp6_realinput(int af, struct sockaddr_in6 *src, struct sockaddr_in6 *dst, (SO_REUSEPORT|SO_REUSEADDR)) == 0) break; } + INP_HASH_UNLOCK(&udbtable); } else { /* * Locate pcb for datagram. + * Lookup returns with solock held (per-socket lock). + * Hold it through processing and sendup, then release. */ inp = in6pcb_lookup(&udbtable, &src6, sport, dst6, dport, 0, 0); @@ -527,11 +571,13 @@ udp6_realinput(int af, struct sockaddr_in6 *src, struct sockaddr_in6 *dst, case -1: /* Error, m was freed */ KASSERT(*mp == NULL); rcvcnt = -1; + inp_lookup_unlock(inp); goto bad; case 1: /* ESP over UDP */ KASSERT(*mp == NULL); rcvcnt++; + inp_lookup_unlock(inp); goto bad; case 0: /* plain UDP */ @@ -554,11 +600,13 @@ udp6_realinput(int af, struct sockaddr_in6 *src, struct sockaddr_in6 *dst, case -1: /* Error, m was freed */ KASSERT(*mp == NULL); rcvcnt = -1; + inp_lookup_unlock(inp); goto bad; case 1: /* Foo over UDP */ KASSERT(*mp == NULL); rcvcnt++; + inp_lookup_unlock(inp); goto bad; case 0: /* plain UDP */ @@ -574,6 +622,7 @@ udp6_realinput(int af, struct sockaddr_in6 *src, struct sockaddr_in6 *dst, udp6_sendup(m, off, sin6tosa(src), inp->inp_socket); rcvcnt++; + inp_lookup_unlock(inp); } bad: @@ -1067,7 +1116,16 @@ udp6_attach(struct socket *so, int proto) int s, error; KASSERT(sotoinpcb(so) == NULL); - sosetlock(so); + + /* + * Allocate a per-socket lock for MP-safe operation. + * If a lock was already assigned (e.g., by socreate for + * SS_PRIV sockets), we use that one. 
+ */ + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } error = soreserve(so, udp6_sendspace, udp6_recvspace); if (error) { @@ -1186,7 +1244,9 @@ udp6_disconnect(struct socket *so) splx(s); so->so_state &= ~SS_ISCONNECTED; /* XXX */ + INP_HASH_LOCK(inp->inp_table); in6pcb_set_state(inp, INP_BOUND); /* XXX */ + INP_HASH_UNLOCK(inp->inp_table); return 0; } @@ -1513,6 +1573,15 @@ udp6_espinudp(struct mbuf **mp, int off) } #endif /* IPSEC */ +#ifdef NET_MPSAFE +/* + * UDP6 is fully MP-safe with per-socket locking. All usrreqs run + * under solock (held by the socket layer), so KERNEL_LOCK wrappers + * must NOT be used; acquiring KERNEL_LOCK inside solock inverts + * the lock order vs ctlinput (which takes KERNEL_LOCK then solock). + */ +/* No PR_WRAP_USRREQS: all functions used directly */ +#else PR_WRAP_USRREQS(udp6) #define udp6_attach udp6_attach_wrapper #define udp6_detach udp6_detach_wrapper @@ -1533,6 +1602,7 @@ PR_WRAP_USRREQS(udp6) #define udp6_send udp6_send_wrapper #define udp6_sendoob udp6_sendoob_wrapper #define udp6_purgeif udp6_purgeif_wrapper +#endif const struct pr_usrreqs udp6_usrreqs = { .pr_attach = udp6_attach, diff --git a/sys/netmpls/mpls_proto.c b/sys/netmpls/mpls_proto.c index 716b59f44eb0..2e9d1ea5c41e 100644 --- a/sys/netmpls/mpls_proto.c +++ b/sys/netmpls/mpls_proto.c @@ -85,7 +85,10 @@ mpls_attach(struct socket *so, int proto) { int error = EOPNOTSUPP; - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } if (so->so_snd.sb_hiwat == 0 || so->so_rcv.sb_hiwat == 0) { error = soreserve(so, 8192, 8192); } @@ -313,6 +316,9 @@ sysctl_net_mpls_setup(struct sysctllog **clog) DOMAIN_DEFINE(mplsdomain); +#ifdef NET_MPSAFE +/* No PR_WRAP_USRREQS: per-socket lock, no KERNEL_LOCK needed */ +#else PR_WRAP_USRREQS(mpls) #define mpls_attach mpls_attach_wrapper #define mpls_detach mpls_detach_wrapper @@ -333,6 
+339,7 @@ PR_WRAP_USRREQS(mpls) #define mpls_send mpls_send_wrapper #define mpls_sendoob mpls_sendoob_wrapper #define mpls_purgeif mpls_purgeif_wrapper +#endif static const struct pr_usrreqs mpls_usrreqs = { .pr_attach = mpls_attach, diff --git a/sys/rump/net/lib/libsockin/sockin.c b/sys/rump/net/lib/libsockin/sockin.c index 4198f6131dc5..cbc0cf839b50 100644 --- a/sys/rump/net/lib/libsockin/sockin.c +++ b/sys/rump/net/lib/libsockin/sockin.c @@ -441,7 +441,10 @@ sockin_attach(struct socket *so, int proto) const int type = so->so_proto->pr_type; int error, news, family; - sosetlock(so); + if (so->so_lock == NULL) { + so->so_lock = mutex_obj_alloc(MUTEX_DEFAULT, IPL_NONE); + mutex_enter(so->so_lock); + } if (so->so_snd.sb_hiwat == 0 || so->so_rcv.sb_hiwat == 0) { error = soreserve(so, SOCKIN_SBSIZE, SOCKIN_SBSIZE); if (error)